j40-cejst-2/data/data-pipeline/data_pipeline/tests/sources/example/test_etl.py
Lucas Merrill Brown 3e37d9d1a3
Issue 1075: update snapshots using command-line flag (#1249)
* Adding skippable tests using command-line flag
2022-02-14 12:16:52 -05:00

570 lines
22 KiB
Python

# pylint: disable=protected-access,unsubscriptable-object
import copy
import os
import pathlib
import pytest
import numpy as np
import pandas as pd
from data_pipeline.etl.base import ExtractTransformLoad, ValidGeoLevel
from data_pipeline.tests.conftest import copy_data_files
from data_pipeline.tests.sources.example.etl import ExampleETL
from data_pipeline.utils import get_module_logger
# Module-level logger used by the tests below for diagnostic messages.
logger = get_module_logger(__name__)
class TestETL:
    """A base class that can be inherited by all other ETL tests.

    Note: every method that does *not* need to be reimplemented by child classes has
    the test name pattern of `test_*_base`. All other tests need to be reimplemented.
    """

    # In every child test class, change this to the class of the ETL being tested.
    _ETL_CLASS = ExampleETL

    # The following constants do not need to be updated in child class.
    _INPUT_CSV_FILE_NAME = "input.csv"
    _TRANSFORM_CSV_FILE_NAME = "transform.csv"
    _OUTPUT_CSV_FILE_NAME = "output.csv"

    # Note: We use shared census tract IDs so that later our tests can join all the
    # ETL results together and generate a test score. This join is only possible if
    # we use the same tract IDs across fixtures.
    # The test fixtures may also contain other tract IDs that are not on this list.
    _FIXTURES_SHARED_TRACT_IDS = [
        "06007040300",
        "06001020100",
        "06007040500",
        "15001021010",
        "15001021101",
        "15007040603",
        "15007040700",
        "15009030100",
        "15009030201",
        "15001021402",
        "15001021800",
        "15009030402",
        "15009030800",
        "15003010201",
        "15007040604",
    ]

    # Set before every test by `setup_method`. Annotated with the platform-neutral
    # `pathlib.Path`, because `pathlib.Path(...)` yields `PosixPath` or
    # `WindowsPath` depending on the OS.
    _DATA_DIRECTORY_FOR_TEST: pathlib.Path

    def setup_method(self, _method, filename=__file__):
        """Before every test, set the data directory for the test.

        Uses the directory of the test class to infer the data directory.

        pytest does not support classes with an `__init__`. Instead, we use this
        `setup_method` which pytest will run before every test method is run.

        For now, all child classes inheriting this need to reimplement this (the
        `filename` default is bound to *this* module's `__file__`), but can use the
        same line of code regardless of the child class:

        ```
        def setup_method(self, _method, filename=__file__):
            '''Invoke `setup_method` from Parent, but using the current file name

            This code can be copied identically between all child classes.
            '''
            super().setup_method(_method=_method, filename=filename)
        ```

        Args:
            _method: The test method about to run (supplied by pytest; unused).
            filename: Path of the test module whose sibling `data` directory holds
                the fixture files.
        """
        self._DATA_DIRECTORY_FOR_TEST = pathlib.Path(filename).parent / "data"

    def _get_instance_of_etl_class(self) -> ExtractTransformLoad:
        """Return a new instance of the ETL class under test (`_ETL_CLASS`)."""
        return self._ETL_CLASS()

    def _setup_etl_instance_and_run_extract(
        self, mock_etl, mock_paths
    ) -> ExtractTransformLoad:
        """Method to setup an ETL instance with proper upstream mocks to run extract.

        This must be re-implemented in every child class.

        This method can be used by multiple tests that need to run the same fixtures
        that need these same mocks, and by `test_update_test_fixtures`.

        In order to re-implement this method, usually it will involve a
        decent amount of work to monkeypatch `requests` or another method that's
        used to retrieve data in order to force that method to retrieve the fixture
        data.

        Returns:
            An ETL instance on which `extract()` has already been called.

        Raises:
            NotImplementedError: If called for any ETL class other than
                `ExampleETL` without being re-implemented.
        """
        # When running this in child classes, make sure the child class re-implements
        # this method.
        if self._ETL_CLASS is not ExampleETL:
            raise NotImplementedError(
                "Prepare and run extract method not defined for this class."
            )

        # The rest of this method applies for `ExampleETL` only.
        etl = self._get_instance_of_etl_class()
        etl.extract()

        return etl

    def test_existence_of_test_fixtures_base(self):
        """Every ETL test should have these two test fixture files.

        Can be run without modification for all child classes.
        """
        assert (
            self._DATA_DIRECTORY_FOR_TEST / self._TRANSFORM_CSV_FILE_NAME
        ).exists()

        assert (
            self._DATA_DIRECTORY_FOR_TEST / self._OUTPUT_CSV_FILE_NAME
        ).exists()

    def test_init_base(self, mock_etl, mock_paths):
        """Test whether class has appropriate parameters set.

        Can be run without modification for all child classes.
        """
        # Setup
        etl = self._get_instance_of_etl_class()
        data_path, tmp_path = mock_paths

        # The `mock_etl` fixture is expected to point these at `mock_paths`.
        assert etl.DATA_PATH == data_path
        assert etl.TMP_PATH == tmp_path

        # Also make sure all parameters that need to be non-null are non-null
        assert etl.NAME is not None
        assert etl.LAST_UPDATED_YEAR is not None
        assert etl.SOURCE_URL is not None
        assert etl.GEO_LEVEL is not None
        assert etl.COLUMNS_TO_KEEP is not None
        assert len(etl.COLUMNS_TO_KEEP) > 0
        # No duplicate columns to keep
        assert len(etl.COLUMNS_TO_KEEP) == len(set(etl.COLUMNS_TO_KEEP))

        # Check certain parameters are set.
        assert etl.EXPECTED_MAX_CENSUS_BLOCK_GROUPS == 250000
        assert etl.EXPECTED_MAX_CENSUS_TRACTS == 74160
        assert etl.EXPECTED_CENSUS_TRACTS_CHARACTER_LENGTH == 11
        assert etl.EXPECTED_CENSUS_BLOCK_GROUPS_CHARACTER_LENGTH == 13

    def test_get_output_file_path_base(self, mock_etl, mock_paths):
        """Test file path method.

        Can be run without modification for all child classes.
        """
        etl = self._get_instance_of_etl_class()
        data_path, _ = mock_paths

        actual_file_path = etl._get_output_file_path()

        expected_file_path = (
            data_path
            / "dataset"
            / f"{etl.NAME}_{etl.LAST_UPDATED_YEAR}"
            / "usa.csv"
        )

        logger.info(f"Expected: {expected_file_path}")

        assert actual_file_path == expected_file_path

    def test_fixtures_contain_shared_tract_ids_base(self, mock_etl, mock_paths):
        """Check presence of necessary shared tract IDs.

        Note: We use shared census tract IDs so that later our tests can join all the
        ETL results together and generate a test score. This join is only possible if
        we use the same tract IDs across fixtures.

        Can be run without modification for all child classes.
        """
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )
        etl.transform()

        # These tests work differently based on the ValidGeoLevel of the ETL class.
        if etl.GEO_LEVEL == ValidGeoLevel.CENSUS_TRACT:
            missing_tract_ids = np.setdiff1d(
                self._FIXTURES_SHARED_TRACT_IDS,
                etl.output_df[ExtractTransformLoad.GEOID_TRACT_FIELD_NAME],
            )
            assert missing_tract_ids.size == 0, (
                "Fixture data is missing the following necessary tract "
                f"IDs: {missing_tract_ids}"
            )
        else:
            raise NotImplementedError("This geo level not tested yet.")

    def test_transform_sets_output_df_base(self, mock_etl, mock_paths):
        """This test ensures that the transform step sets its results to `output_df`.

        Can be run without modification for all child classes.
        """
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )
        etl.transform()

        assert etl.output_df is not None

        # Assert it has some rows
        assert etl.output_df.shape[0] > 0

        # Check that it has all columns
        for col in etl.COLUMNS_TO_KEEP:
            assert col in etl.output_df.columns, f"{col} is missing from output"

    def test_transform_base(self, mock_etl, mock_paths):
        """Tests the transform method.

        Runs the (mocked) extract step first, like the other transform tests, so
        that this test does not rely on another test having run extract before it.

        Can be run without modification for all child classes.
        """
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )
        etl.transform()

        transform_csv_path = (
            self._DATA_DIRECTORY_FOR_TEST / self._TRANSFORM_CSV_FILE_NAME
        )

        # Compare to expected. Geo IDs are read as strings to keep leading zeros.
        expected = pd.read_csv(
            filepath_or_buffer=transform_csv_path,
            dtype={
                ExtractTransformLoad.GEOID_TRACT_FIELD_NAME: "string",
                ExtractTransformLoad.GEOID_FIELD_NAME: "string",
            },
        )

        pd.testing.assert_frame_equal(etl.output_df, expected)

    def test_load_base(self, mock_etl):
        """Test load method.

        Can be run without modification for all child classes.
        """
        # setup - input variables
        etl = self._get_instance_of_etl_class()

        # setup - mock transform step by loading its fixture into `output_df`.
        # Both geo ID columns (tract and block group) are read as strings so
        # leading zeros survive, matching `test_transform_base`.
        etl.output_df = pd.read_csv(
            self._DATA_DIRECTORY_FOR_TEST / self._TRANSFORM_CSV_FILE_NAME,
            dtype={
                etl.GEOID_TRACT_FIELD_NAME: "string",
                etl.GEOID_FIELD_NAME: "string",
            },
        )

        # execution
        etl.load()

        # Make sure it creates the file.
        actual_output_path = etl._get_output_file_path()
        assert actual_output_path.exists()

        # Read actual and expected output with the same dtypes so that the frame
        # comparison below is not tripped up by dtype inference.
        geo_id_dtypes = {
            etl.GEOID_TRACT_FIELD_NAME: str,
            etl.GEOID_FIELD_NAME: str,
        }
        actual_output = pd.read_csv(actual_output_path, dtype=geo_id_dtypes)

        # setup - load expected output
        expected_output = pd.read_csv(
            filepath_or_buffer=self._DATA_DIRECTORY_FOR_TEST
            / self._OUTPUT_CSV_FILE_NAME,
            dtype=geo_id_dtypes,
        )

        # check that the `COLUMNS_TO_KEEP` are in the output
        for col in etl.COLUMNS_TO_KEEP:
            assert col in actual_output.columns, f"{col} is missing from output"

        # validation
        pd.testing.assert_frame_equal(actual_output, expected_output)

    @staticmethod
    def _assert_validate_rejects_bad_geo_field(
        etl: ExtractTransformLoad,
        geo_field: str,
        max_rows_attribute: str,
        wrong_length_id: str,
        duplicate_id: str,
    ) -> None:
        """Corrupt deep copies of `etl` and check that `validate()` rejects each.

        Shared by the census-tract and census-block-group branches of
        `test_validate_base`, which differ only in the arguments below.

        Args:
            etl: An ETL instance whose `transform()` has already set `output_df`.
                It is never modified; every check works on a deep copy.
            geo_field: Name of the geo ID column for the ETL's geo level.
            max_rows_attribute: Name of the ETL attribute holding the maximum
                expected number of rows for this geo level.
            wrong_length_id: A geo ID whose length differs from the expected one.
            duplicate_id: A geo ID written into two rows to create a duplicate.
        """
        output_df = etl.output_df

        # Use the frame's own index labels so the `.loc` writes below overwrite
        # existing rows even when the index is not a 0-based RangeIndex (a literal
        # `.loc[0, ...]` would append a new row if label 0 did not exist).
        first_row = output_df.index[0]
        first_two_rows = output_df.index[:2]

        # Remove geo field and make sure error occurs.
        etl_without_geo_field = copy.deepcopy(etl)
        etl_without_geo_field.output_df = output_df[
            [column for column in output_df.columns if column != geo_field]
        ]
        with pytest.raises(ValueError) as error:
            etl_without_geo_field.validate()
        assert str(error.value).startswith("Missing column:")

        # Make sure too many rows throws error.
        etl_with_too_many_rows = copy.deepcopy(etl)
        setattr(etl_with_too_many_rows, max_rows_attribute, output_df.shape[0] - 1)
        with pytest.raises(ValueError) as error:
            etl_with_too_many_rows.validate()
        assert str(error.value).startswith("Too many rows:")

        # Make sure multiple geo field character length throws error.
        etl_with_multiple_char_lengths = copy.deepcopy(etl)
        etl_with_multiple_char_lengths.output_df = output_df.copy(deep=True)
        etl_with_multiple_char_lengths.output_df.loc[
            first_row, geo_field
        ] = wrong_length_id
        with pytest.raises(ValueError) as error:
            etl_with_multiple_char_lengths.validate()
        assert str(error.value).startswith("Multiple character lengths")

        # Make sure wrong geo field character length throws error.
        etl_with_wrong_geo_field_character_length = copy.deepcopy(etl)
        etl_with_wrong_geo_field_character_length.output_df = output_df.copy(
            deep=True
        )
        etl_with_wrong_geo_field_character_length.output_df[
            geo_field
        ] = wrong_length_id
        with pytest.raises(ValueError) as error:
            etl_with_wrong_geo_field_character_length.validate()
        assert str(error.value).startswith("Wrong character length")

        # Make sure duplicate geo IDs throw error.
        etl_with_duplicate_geo_field = copy.deepcopy(etl)
        etl_with_duplicate_geo_field.output_df = output_df.copy(deep=True)
        etl_with_duplicate_geo_field.output_df.loc[
            first_two_rows, geo_field
        ] = duplicate_id
        with pytest.raises(ValueError) as error:
            etl_with_duplicate_geo_field.validate()
        assert str(error.value).startswith("Duplicate values:")

    def test_validate_base(self, mock_etl, mock_paths):
        """Every ETL class should have proper validation.

        Can be run without modification for all child classes.
        """
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )
        etl.transform()

        # Transform is guaranteed to set a dataframe on etl.output_df.
        # Deep copies of the ETL get modified versions of it to test validation.
        actual_output_df = etl.output_df

        # These tests work differently based on the ValidGeoLevel of the ETL class.
        if etl.GEO_LEVEL == ValidGeoLevel.CENSUS_TRACT:
            self._assert_validate_rejects_bad_geo_field(
                etl=etl,
                geo_field=ExtractTransformLoad.GEOID_TRACT_FIELD_NAME,
                max_rows_attribute="EXPECTED_MAX_CENSUS_TRACTS",
                wrong_length_id="060070403001",
                duplicate_id="06007040300",
            )
        elif etl.GEO_LEVEL == ValidGeoLevel.CENSUS_BLOCK_GROUP:
            self._assert_validate_rejects_bad_geo_field(
                etl=etl,
                geo_field=ExtractTransformLoad.GEOID_FIELD_NAME,
                max_rows_attribute="EXPECTED_MAX_CENSUS_BLOCK_GROUPS",
                wrong_length_id="06007040300123",
                duplicate_id="0600704030012",
            )
        else:
            raise NotImplementedError("This geo level not tested yet.")

        # Remove another column to keep and make sure error occurs. The column is
        # taken from `COLUMNS_TO_KEEP` (excluding geo fields, already covered
        # above), rather than being whatever column happens to be last in the
        # output, so it is always one the ETL declares it must keep.
        geo_fields = {
            ExtractTransformLoad.GEOID_TRACT_FIELD_NAME,
            ExtractTransformLoad.GEOID_FIELD_NAME,
        }
        non_geo_columns_to_keep = [
            column for column in etl.COLUMNS_TO_KEEP if column not in geo_fields
        ]
        if non_geo_columns_to_keep:
            etl_with_missing_column = copy.deepcopy(etl)
            etl_with_missing_column.output_df = actual_output_df.drop(
                columns=[non_geo_columns_to_keep[-1]]
            )
            with pytest.raises(ValueError) as error:
                etl_with_missing_column.validate()
            assert str(error.value).startswith("Missing column:")

        # Test that validation on the original ETL works fine.
        etl.validate()

    def test_full_etl_base(self, mock_etl, mock_paths):
        """Every ETL class should be able to run end-to-end.

        Run extract, transform, validate, load, and get without error.

        Can be run without modification for all child classes.
        """
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )
        etl.transform()
        etl.validate()
        etl.load()
        etl.get_data_frame()

    def test_get_data_frame_base(self, mock_etl, mock_paths):
        """Every ETL class should be able to return its data frame.

        Can be run without modification for all child classes.
        """
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )

        # TODO: look into moving this file deletion to a setup/teardown method that
        # applies to all methods. I struggled to get that to work because I couldn't
        # pass `mock_etl` and `mock_paths`
        # Delete output file, since other tests in this class may have written it.
        output_file_path = etl._get_output_file_path()
        if os.path.exists(output_file_path):
            logger.info("Deleting output file created by other tests.")
            os.remove(output_file_path)

        # Run more steps to generate test data.
        etl.transform()
        etl.validate()

        # At this point, `get_data_frame` should error since file hasn't been written.
        with pytest.raises(ValueError) as error:
            etl.get_data_frame()
        assert str(error.value).startswith("Make sure to run ETL")

        # Run `load` step to write it to disk.
        etl.load()

        output_df = etl.get_data_frame()

        # Check that all columns are present
        for column_to_keep in etl.COLUMNS_TO_KEEP:
            assert (
                column_to_keep in output_df.columns
            ), f"Missing column: `{column_to_keep}` is missing from output"

        # Make sure geo fields are read in as strings:
        if etl.GEO_LEVEL == ValidGeoLevel.CENSUS_TRACT:
            assert pd.api.types.is_string_dtype(
                output_df[ExtractTransformLoad.GEOID_TRACT_FIELD_NAME]
            )
        elif etl.GEO_LEVEL == ValidGeoLevel.CENSUS_BLOCK_GROUP:
            assert pd.api.types.is_string_dtype(
                output_df[ExtractTransformLoad.GEOID_FIELD_NAME]
            )
        else:
            raise NotImplementedError("This geo level not tested yet.")

    # This decorator means that this "test" will only be run by passing that flag to
    # pytest, for instance by running `pytest . -rsx --update_snapshots`.
    @pytest.mark.update_snapshots
    def test_update_test_fixtures(self, mock_etl, mock_paths):
        """Update the test fixtures (the data files) used by the test.

        This needs to be reimplemented for every child class. This is because there
        are not strict contracts on the outputs of the `extract` step so this method
        needs to explicitly define how to update the `input` fixture that comes after
        the extract step.

        Using this method to update fixtures can be helpful if you expect the
        results to change because you changed the logic of the ETL class and need to
        quickly update the fixtures.

        However, note a few things first:

        1. Do *not* update these fixtures if you did not expect the ETL results to
        change!

        2. If the source data itself changes (e.g., the external source renames a
        column), update the "furthest upstream" test fixture which, in many cases,
        is a .zip file. Then running this method will update all subsequent files.

        If you're confused by any of this, ask for help, it's confusing :).
        """
        # When running this in child classes, make sure the child class re-implements
        # this method.
        if self._ETL_CLASS is not ExampleETL:
            raise NotImplementedError(
                "Update fixtures method not defined for this class."
            )

        # The rest of this method applies for `ExampleETL` only.
        etl = self._setup_etl_instance_and_run_extract(
            mock_etl=mock_etl, mock_paths=mock_paths
        )

        # After running extract, write the results as the "input.csv" in the test
        # directory. Assumes `ExampleETL.extract` leaves its result at `input.csv`
        # inside the ETL's tmp directory.
        logger.info(
            f"Writing data to {self._DATA_DIRECTORY_FOR_TEST / self._INPUT_CSV_FILE_NAME}"
        )
        copy_data_files(
            src=etl.get_tmp_path() / "input.csv",
            dst=self._DATA_DIRECTORY_FOR_TEST / self._INPUT_CSV_FILE_NAME,
        )

        # After running transform, write the results as the "transform.csv" in the test
        # directory.
        etl.transform()
        etl.output_df.to_csv(
            path_or_buf=self._DATA_DIRECTORY_FOR_TEST
            / self._TRANSFORM_CSV_FILE_NAME,
            index=False,
        )

        # Run validate, just to check.
        etl.validate()

        # After running load, write the results as the "output.csv" in the test
        # directory.
        etl.load()
        copy_data_files(
            src=etl._get_output_file_path(),
            dst=self._DATA_DIRECTORY_FOR_TEST / self._OUTPUT_CSV_FILE_NAME,
        )