Issue 1075: Add refactored ETL tests to NRI (#1088)

* Adds a substantially refactored ETL test to the National Risk Index, to be used as a model for other tests
Lucas Merrill Brown 2022-02-08 19:05:32 -05:00 committed by GitHub
commit 43e005cc10
41 changed files with 1155 additions and 619 deletions
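For illustration, a child ETL class under the refactored pattern might look like the following sketch. The class name and all values are hypothetical; the class-level parameters (NAME, LAST_UPDATED_YEAR, SOURCE_URL, GEO_LEVEL, COLUMNS_TO_KEEP) and the requirement that `transform` set `self.output_df` come from the base class changes below.

import pandas as pd

from data_pipeline.etl.base import ExtractTransformLoad, ValidGeoLevel


class ExampleDatasetETL(ExtractTransformLoad):
    # Hypothetical values; each real child class supplies its own.
    NAME = "example_dataset"
    LAST_UPDATED_YEAR = 2020
    SOURCE_URL = "https://example.com/example_dataset.zip"
    GEO_LEVEL = ValidGeoLevel.CENSUS_TRACT
    COLUMNS_TO_KEEP = [
        ExtractTransformLoad.GEOID_TRACT_FIELD_NAME,
        "Example output column",
    ]

    def transform(self) -> None:
        # `transform` must set `self.output_df`; the base class's
        # `validate` and `load` steps read it.
        self.output_df = pd.DataFrame(
            {
                self.GEOID_TRACT_FIELD_NAME: ["01001020100"],
                "Example output column": [0.5],
            }
        )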


@ -1,8 +1,9 @@
import enum
import pathlib
import typing
from typing import Optional
import pandas as pd
import yaml
from data_pipeline.config import settings
from data_pipeline.utils import (
@ -14,6 +15,13 @@ from data_pipeline.utils import (
logger = get_module_logger(__name__)
class ValidGeoLevel(enum.Enum):
"""Enum used for indicating output data's geographic resolution."""
CENSUS_TRACT = enum.auto()
CENSUS_BLOCK_GROUP = enum.auto()
class ExtractTransformLoad:
"""
A class used to instantiate an ETL object to retrieve and process data from
@ -26,78 +34,74 @@ class ExtractTransformLoad:
GEOID_TRACT_FIELD_NAME (str): The common column name for a Census Tract identifier
"""
APP_ROOT: pathlib.Path = settings.APP_ROOT
# Directories
DATA_PATH: pathlib.Path = APP_ROOT / "data"
TMP_PATH: pathlib.Path = DATA_PATH / "tmp"
# Parameters
GEOID_FIELD_NAME: str = "GEOID10"
GEOID_TRACT_FIELD_NAME: str = "GEOID10_TRACT"
# Parameters that will be changed by children of the class
# NAME is used to create output path and populate logger info.
NAME: str = None
# LAST_UPDATED_YEAR is used to create output path.
LAST_UPDATED_YEAR: int = None
# SOURCE_URL is used to extract source data in extract().
SOURCE_URL: str = None
# GEO_LEVEL is used to identify whether output data is at the unit of the tract or
# census block group.
# TODO: add tests that enforce seeing the expected geographic identifier field
# in the output file based on this geography level.
GEO_LEVEL: ValidGeoLevel = None
# COLUMNS_TO_KEEP is used to identify which columns to keep in the output df.
COLUMNS_TO_KEEP: typing.List[str] = None
# Thirteen digits in a census block group ID.
EXPECTED_CENSUS_BLOCK_GROUPS_CHARACTER_LENGTH: int = 13
# TODO: investigate. Census says there are only 217,740 CBGs in the US. This might
# be from CBGs at different time periods.
EXPECTED_MAX_CENSUS_BLOCK_GROUPS: int = 250000
# Eleven digits in a census tract ID.
EXPECTED_CENSUS_TRACTS_CHARACTER_LENGTH: int = 11
# TODO: investigate. Census says there are only 74,134 tracts in the US,
# Puerto Rico, and island areas. This might be from tracts at different time
# periods. https://github.com/usds/justice40-tool/issues/964
EXPECTED_MAX_CENSUS_TRACTS: int = 74160
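# For reference (general Census GEOID structure, not from this diff): an
# 11-character tract ID is 2 state digits + 3 county digits + 6 tract digits,
# and a block group ID extends the tract ID with a block group suffix.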
output_df: pd.DataFrame = None
# This is a classmethod so it can be used by `get_data_frame` without
# needing to create an instance of the class. This is a use case in `etl_score`.
@classmethod
def _get_output_file_path(cls) -> pathlib.Path:
"""Generate the output file path."""
if cls.NAME is None:
raise NotImplementedError(
f"Child ETL class needs to specify `cls.NAME` (currently "
f"{cls.NAME}) and `cls.LAST_UPDATED_YEAR` (currently "
f"{cls.LAST_UPDATED_YEAR})."
)
output_file_path = (
cls.DATA_PATH
/ "dataset"
/ f"{cls.NAME}_{cls.LAST_UPDATED_YEAR}"
/ "usa.csv"
)
return output_file_path
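# As a hypothetical example: a child class with NAME = "national_risk_index"
# and LAST_UPDATED_YEAR = 2020 would write to
# DATA_PATH / "dataset" / "national_risk_index_2020" / "usa.csv".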
def extract(
self,
source_url: str = None,
extract_path: pathlib.Path = None,
verify: Optional[bool] = True,
) -> None:
"""Extract the data from a remote source. By default it provides code
@ -107,7 +111,10 @@ class ExtractTransformLoad:
# this can be accessed via super().extract()
if source_url and extract_path:
unzip_file_from_url(
file_url=source_url,
download_path=self.TMP_PATH,
unzipped_file_path=extract_path,
verify=verify,
)
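# Sketch of how a child class might invoke this (values hypothetical):
#   super().extract(
#       source_url=self.SOURCE_URL,
#       extract_path=self.TMP_PATH / "example_dataset",
#   )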
def transform(self) -> None:
@ -116,63 +123,146 @@ class ExtractTransformLoad:
raise NotImplementedError
def validate(self) -> None:
"""Validates the output.
Runs after the `transform` step and before `load`.
"""
# TODO: remove this once all ETL classes are converted to using the new
# base class parameters and patterns.
if self.GEO_LEVEL is None:
logger.info(
"Skipping validation step for this class because it does not "
"seem to be converted to new ETL class patterns."
)
return
if self.COLUMNS_TO_KEEP is None:
raise NotImplementedError(
"`self.COLUMNS_TO_KEEP` must be specified."
)
if self.output_df is None:
raise NotImplementedError(
"The `transform` step must set `self.output_df`."
)
for column_to_keep in self.COLUMNS_TO_KEEP:
if column_to_keep not in self.output_df.columns:
raise ValueError(
f"Missing column: `{column_to_keep}` is missing from "
f"output"
)
for (
geo_level,
geo_field,
expected_geo_field_characters,
expected_rows,
) in [
(
ValidGeoLevel.CENSUS_TRACT,
self.GEOID_TRACT_FIELD_NAME,
self.EXPECTED_CENSUS_TRACTS_CHARACTER_LENGTH,
self.EXPECTED_MAX_CENSUS_TRACTS,
),
(
ValidGeoLevel.CENSUS_BLOCK_GROUP,
self.GEOID_FIELD_NAME,
self.EXPECTED_CENSUS_BLOCK_GROUPS_CHARACTER_LENGTH,
self.EXPECTED_MAX_CENSUS_BLOCK_GROUPS,
),
]:
if self.GEO_LEVEL is geo_level:
if geo_field not in self.COLUMNS_TO_KEEP:
raise ValueError(
f"Must have `{geo_field}` in columns if "
f"specifying geo level as `{geo_level}`."
)
if self.output_df.shape[0] > expected_rows:
raise ValueError(
f"Too many rows: `{self.output_df.shape[0]}` rows in "
f"output exceeds expectation of `{expected_rows}` "
f"rows."
)
if self.output_df[geo_field].str.len().nunique() > 1:
raise ValueError(
f"Multiple character lengths for geo field "
f"present: {self.output_df[geo_field].str.len().unique()}."
)
elif (
len(self.output_df[geo_field].array[0])
!= expected_geo_field_characters
):
raise ValueError(
"Wrong character length: the census geography data "
"has the wrong length."
)
duplicate_geo_field_values = (
self.output_df[geo_field].shape[0]
- self.output_df[geo_field].nunique()
)
if duplicate_geo_field_values > 0:
raise ValueError(
f"Duplicate values: There are {duplicate_geo_field_values} "
f"duplicate values in "
f"`{geo_field}`."
)
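# Taken together, for a tract-level dataset this enforces (summary, not code):
# the tract ID column is in COLUMNS_TO_KEEP, the row count does not exceed
# EXPECTED_MAX_CENSUS_TRACTS, every ID is exactly
# EXPECTED_CENSUS_TRACTS_CHARACTER_LENGTH characters, and IDs are unique.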
def load(self, float_format=None) -> None:
"""Saves the transformed data.
Data is written in the specified local data folder or remote AWS S3 bucket.
Uses the directory and file name from `self._get_output_file_path`.
"""
logger.info(f"Saving `{self.NAME}` CSV")
# Create directory if necessary.
output_file_path = self._get_output_file_path()
output_file_path.parent.mkdir(parents=True, exist_ok=True)
# Write nationwide csv
self.output_df[self.COLUMNS_TO_KEEP].to_csv(
output_file_path, index=False, float_format=float_format
)
logger.info(f"File written to `{output_file_path}`.")
# This is a classmethod so it can be used without needing to create an instance of
# the class. This is a use case in `etl_score`.
@classmethod
def get_data_frame(cls) -> pd.DataFrame:
"""Return the output data frame for this class.
Must be run after a full ETL process has been run for this class.
If the ETL has not been run for this class, this will error.
"""
# Read in output file
output_file_path = cls._get_output_file_path()
if not output_file_path.exists():
raise ValueError(
f"Make sure to run ETL process first for `{cls}`. "
f"No file found at `{output_file_path}`."
)
output_df = pd.read_csv(
output_file_path,
dtype={
# Not all outputs will have both a Census Block Group ID and a
# Tract ID, but these will be ignored if they're not present.
cls.GEOID_FIELD_NAME: "string",
cls.GEOID_TRACT_FIELD_NAME: "string",
},
)
return output_df
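# Hypothetical usage from another module such as `etl_score` (class name
# assumed from the earlier sketch):
#   df = ExampleDatasetETL.get_data_frame()
# No instance is needed because this is a classmethod.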
def cleanup(self) -> None:
"""Clears out any files stored in the TMP folder"""
remove_all_from_dir(self.TMP_PATH)
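A minimal sketch of the full lifecycle under the new pattern. The runner function is hypothetical; the method order (extract, transform, validate, load, cleanup) follows the base class above.

import typing

from data_pipeline.etl.base import ExtractTransformLoad


def run_example_etl(etl_class: typing.Type[ExtractTransformLoad]) -> None:
    """Run one dataset's ETL end to end (hypothetical runner)."""
    # `etl_class` would be a child of ExtractTransformLoad, e.g. the
    # hypothetical ExampleDatasetETL sketched earlier.
    etl = etl_class()
    etl.extract()
    etl.transform()
    etl.validate()
    etl.load()
    etl.cleanup()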