j40-cejst-2/data/data-pipeline/data_pipeline/etl/base.py


import enum
import pathlib
import sys
import typing
import shutil
from typing import Optional
from abc import ABC, abstractmethod
import pandas as pd
from data_pipeline.config import settings
from data_pipeline.etl.score.etl_utils import (
compare_to_list_of_expected_state_fips_codes,
)
from data_pipeline.etl.score.schemas.datasets import DatasetsConfig
from data_pipeline.utils import get_module_logger
from data_pipeline.utils import load_yaml_dict_from_file
from data_pipeline.utils import remove_all_from_dir
from data_pipeline.etl.datasource import DataSource
logger = get_module_logger(__name__)
class ValidGeoLevel(enum.Enum):
"""Enum used for indicating output data's geographic resolution."""
CENSUS_TRACT = enum.auto()
CENSUS_BLOCK_GROUP = enum.auto()
class ExtractTransformLoad(ABC):
"""
A class used to instantiate an ETL object to retrieve and process data from
datasets.
Attributes:
DATA_PATH (pathlib.Path): Local path where all data will be stored
TMP_PATH (pathlib.Path): Local path where temporary data will be stored
TODO: Fill missing attrs here
GEOID_FIELD_NAME (str): The common column name for a Census Block Group identifier
GEOID_TRACT_FIELD_NAME (str): The common column name for a Census Tract identifier
"""
APP_ROOT: pathlib.Path = settings.APP_ROOT
# Directories
DATA_PATH: pathlib.Path = settings.DATA_PATH
TMP_PATH: pathlib.Path = DATA_PATH / "tmp"
SOURCES_PATH: pathlib.Path = DATA_PATH / "sources"
CONTENT_CONFIG: pathlib.Path = APP_ROOT / "content" / "config"
DATASET_CONFIG_PATH: pathlib.Path = APP_ROOT / "etl" / "score" / "config"
DATASET_CONFIG: Optional[dict] = None
# Parameters
GEOID_FIELD_NAME: str = "GEOID10"
GEOID_TRACT_FIELD_NAME: str = "GEOID10_TRACT"
# Parameters that will be changed by children of the class
# NAME is used to create output path and populate logger info.
NAME: str = None
# LAST_UPDATED_YEAR is used to create output path.
LAST_UPDATED_YEAR: int = None
# SOURCE_URL is used to extract source data in extract().
SOURCE_URL: str = None
# INPUT_EXTRACTED_FILE_NAME is the name of the file after extract().
INPUT_EXTRACTED_FILE_NAME: str = None
# GEO_LEVEL is used to identify whether output data is at the unit of the tract or
# census block group.
# TODO: add tests that enforce seeing the expected geographic identifier field
# in the output file based on this geography level.
GEO_LEVEL: ValidGeoLevel = None
# COLUMNS_TO_KEEP is used to identify which columns to keep in the output df.
COLUMNS_TO_KEEP: typing.List[str] = None
# INPUT_GEOID_TRACT_FIELD_NAME is the field name that identifies the Census Tract ID
# on the input file
INPUT_GEOID_TRACT_FIELD_NAME: str = None
# NULL_REPRESENTATION is how nulls are represented on the input field
NULL_REPRESENTATION: str = None
# Whether this ETL contains data for the continental US (DC and the US states
# except for Alaska and Hawaii)
CONTINENTAL_US_EXPECTED_IN_DATA: bool = True
# Whether this ETL contains data for Alaska and Hawaii
ALASKA_AND_HAWAII_EXPECTED_IN_DATA: bool = True
# Whether this ETL contains data for Puerto Rico
PUERTO_RICO_EXPECTED_IN_DATA: bool = True
# Whether this ETL contains data for the island areas
ISLAND_AREAS_EXPECTED_IN_DATA: bool = False
# Whether this ETL contains known missing data for any additional
# states/territories
EXPECTED_MISSING_STATES: typing.List[str] = []
# Thirteen digits in a census block group ID.
EXPECTED_CENSUS_BLOCK_GROUPS_CHARACTER_LENGTH: int = 13
# TODO: investigate. Census says there are only 217,740 CBGs in the US. This might
# be from CBGs at different time periods.
EXPECTED_MAX_CENSUS_BLOCK_GROUPS: int = 250000
# There should be eleven digits in a census tract ID.
EXPECTED_CENSUS_TRACTS_CHARACTER_LENGTH: int = 11
# TODO: investigate. Census says there are only 74,134 tracts in the United States,
# Puerto Rico, and island areas. This might be from tracts at different time
# periods. https://github.com/usds/justice40-tool/issues/964
EXPECTED_MAX_CENSUS_TRACTS: int = 74160
# Should this dataset load its configuration from
# the YAML files?
LOAD_YAML_CONFIG: bool = False
# We use output_df as the final dataframe to use to write to the CSV
# It is used on the "load" base class method
output_df: pd.DataFrame = None
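# Illustrative sketch (not an actual ETL in this repo) of how a child class
# typically overrides the parameters above; the class name and field values
# are hypothetical:
#
#     class ExampleETL(ExtractTransformLoad):
#         NAME = "example_dataset"
#         GEO_LEVEL = ValidGeoLevel.CENSUS_TRACT
#         SOURCE_URL = "https://example.com/example_dataset.zip"
#         INPUT_GEOID_TRACT_FIELD_NAME = "GEOID10"
#         NULL_REPRESENTATION = "-999"
#         COLUMNS_TO_KEEP = [
#             ExtractTransformLoad.GEOID_TRACT_FIELD_NAME,
#             "Example field (percentile)",
#         ]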
def __init_subclass__(cls) -> None:
if cls.LOAD_YAML_CONFIG:
cls.DATASET_CONFIG = cls.yaml_config_load()
@classmethod
def yaml_config_load(cls) -> dict:
"""Generate config dictionary and set instance variables from YAML dataset."""
# check if the class instance has score YAML definitions
datasets_config = load_yaml_dict_from_file(
cls.DATASET_CONFIG_PATH / "datasets.yml",
DatasetsConfig,
)
# get the config for this dataset
try:
dataset_config = next(
item
for item in datasets_config.get("datasets")
if item["module_name"] == cls.NAME
)
except StopIteration:
# Note: it'd be nice to log the name of the dataframe, but that's not accessible in this scope.
logger.error(
f"Exception encountered while extracting dataset config for dataset {cls.NAME}"
)
sys.exit()
# set some of the basic fields
if "input_geoid_tract_field_name" in dataset_config:
cls.INPUT_GEOID_TRACT_FIELD_NAME = dataset_config[
"input_geoid_tract_field_name"
]
# get the columns to write on the CSV
# and set the constants
cls.COLUMNS_TO_KEEP = [
cls.GEOID_TRACT_FIELD_NAME, # always index with geoid tract id
]
for field in dataset_config["load_fields"]:
cls.COLUMNS_TO_KEEP.append(field["long_name"])
setattr(cls, field["df_field_name"], field["long_name"])
return dataset_config
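# For reference, a sketch of the per-dataset config shape this method consumes
# once `datasets.yml` has been loaded into a dict. Only the keys read above
# are shown, and the values are hypothetical:
#
#     {
#         "module_name": "example_dataset",  # must match cls.NAME
#         "input_geoid_tract_field_name": "GEOID10",
#         "load_fields": [
#             {
#                 "df_field_name": "EXAMPLE_FIELD",
#                 "long_name": "Example field (percentile)",
#             },
#         ],
#     }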
# This is a classmethod so it can be used by `get_data_frame` without
# needing to create an instance of the class. This is a use case in `etl_score`.
@classmethod
def _get_output_file_path(cls) -> pathlib.Path:
"""Generate the output file path."""
if cls.NAME is None:
raise NotImplementedError(
f"Child ETL class needs to specify `cls.NAME` (currently "
f"{cls.NAME})."
)
output_file_path = cls.DATA_PATH / "dataset" / f"{cls.NAME}" / "usa.csv"
return output_file_path
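# For example, a hypothetical child class with NAME = "example_dataset" would
# write its output to DATA_PATH / "dataset" / "example_dataset" / "usa.csv".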
def get_sources_path(self) -> pathlib.Path:
"""Returns the sources path associated with this ETL class. The sources path
is the home for cached data sources used by this ETL."""
sources_path = self.SOURCES_PATH / str(self.__class__.__name__)
# Create directory if it doesn't exist
sources_path.mkdir(parents=True, exist_ok=True)
return sources_path
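# For example, a hypothetical child class named ExampleETL caches its data
# sources under DATA_PATH / "sources" / "ExampleETL".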
@abstractmethod
def get_data_sources(self) -> typing.List[DataSource]:
pass
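# A minimal sketch of how a child class might satisfy this contract. The
# `ExampleZIPDataSource` class and its constructor arguments are hypothetical
# and do not reflect the real DataSource API; the only behavior this base
# class relies on is that each returned object can fetch() itself:
#
#     def get_data_sources(self) -> typing.List[DataSource]:
#         return [
#             ExampleZIPDataSource(  # hypothetical DataSource subclass
#                 source=self.SOURCE_URL,
#                 destination=self.get_sources_path(),
#             )
#         ]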
def _fetch(self) -> None:
"""Fetch all data sources for this ETL. When data sources are fetched, they
are stored in a cache directory for consistency between runs."""
for ds in self.get_data_sources():
ds.fetch()
def clear_data_source_cache(self) -> None:
"""Clears the cache for this ETLs data source(s)"""
shutil.rmtree(self.get_sources_path())
def extract(self, use_cached_data_sources: bool = False) -> None:
"""Extract (download) data from a remote source, and validate
that data. By default, this method fetches data from the set of
data sources returned by get_data_sources.
If use_cached_data_sources is true, this method attempts to use cached data
rather than re-downloading from the original source. The cache algorithm is very
simple: it just looks to see if the directory has any contents. If so, it uses
that content. If not, it downloads all data sources.
Subclasses should call super() before performing any work if they wish to take
advantage of the automatic downloading and caching ability of this superclass.
"""
if use_cached_data_sources and any(self.get_sources_path().iterdir()):
logger.info(
f"Using cached data sources for {self.__class__.__name__}"
)
else:
self.clear_data_source_cache()
self._fetch()
# the rest of the work should be performed here
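# A minimal sketch of a child class `extract` that opts into this caching
# behavior; the `self.df` attribute and the input file name are illustrative
# assumptions, not a fixed convention of this base class:
#
#     def extract(self, use_cached_data_sources: bool = False) -> None:
#         super().extract(use_cached_data_sources)
#         self.df = pd.read_csv(
#             self.get_sources_path() / "example_input.csv",
#             dtype={self.INPUT_GEOID_TRACT_FIELD_NAME: "string"},
#         )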
@abstractmethod
def transform(self) -> None:
"""Transform the data extracted into a format that can be consumed by the
score generator"""
pass
def validate(self) -> None:
"""Validates the output.
Runs after the `transform` step and before `load`.
"""
# TODO: remove this once all ETL classes are converted to using the new
# base class parameters and patterns.
if self.GEO_LEVEL is None:
logger.warning(
f"Skipping validation step for {self.__class__.__name__} because it does not "
"seem to be converted to new ETL class patterns."
)
return
if self.COLUMNS_TO_KEEP is None:
raise NotImplementedError(
"`self.COLUMNS_TO_KEEP` must be specified."
)
if self.output_df is None:
raise NotImplementedError(
"The `transform` step must set `self.output_df`."
)
for column_to_keep in self.COLUMNS_TO_KEEP:
if column_to_keep not in self.output_df.columns:
raise ValueError(
f"Missing column: `{column_to_keep}` is missing from "
f"output"
)
for (
geo_level,
geo_field,
expected_geo_field_characters,
expected_rows,
) in [
(
ValidGeoLevel.CENSUS_TRACT,
self.GEOID_TRACT_FIELD_NAME,
self.EXPECTED_CENSUS_TRACTS_CHARACTER_LENGTH,
self.EXPECTED_MAX_CENSUS_TRACTS,
),
(
ValidGeoLevel.CENSUS_BLOCK_GROUP,
self.GEOID_FIELD_NAME,
self.EXPECTED_CENSUS_BLOCK_GROUPS_CHARACTER_LENGTH,
self.EXPECTED_MAX_CENSUS_BLOCK_GROUPS,
),
]:
if self.GEO_LEVEL is geo_level:
if geo_field not in self.COLUMNS_TO_KEEP:
raise ValueError(
f"Must have `{geo_field}` in columns if "
f"specifying geo level as `{geo_level} "
)
if self.output_df.shape[0] > expected_rows:
raise ValueError(
f"Too many rows: `{self.output_df.shape[0]}` rows in "
f"output exceeds expectation of `{expected_rows}` "
f"rows."
)
if self.output_df[geo_field].str.len().nunique() > 1:
raise ValueError(
f"Multiple character lengths for geo field "
f"present: {self.output_df[geo_field].str.len().unique()}."
)
elif (
len(self.output_df[geo_field].array[0])
!= expected_geo_field_characters
):
raise ValueError(
"Wrong character length: the census geography data "
"has the wrong length."
)
duplicate_geo_field_values = (
self.output_df[geo_field].shape[0]
- self.output_df[geo_field].nunique()
)
if duplicate_geo_field_values > 0:
raise ValueError(
f"Duplicate values: There are {duplicate_geo_field_values} "
f"duplicate values in "
f"`{geo_field}`."
)
# Check whether data contains expected states
states_in_output_df = (
self.output_df[self.GEOID_TRACT_FIELD_NAME]
.str[0:2]
.unique()
.tolist()
)
compare_to_list_of_expected_state_fips_codes(
actual_state_fips_codes=states_in_output_df,
continental_us_expected=self.CONTINENTAL_US_EXPECTED_IN_DATA,
alaska_and_hawaii_expected=self.ALASKA_AND_HAWAII_EXPECTED_IN_DATA,
puerto_rico_expected=self.PUERTO_RICO_EXPECTED_IN_DATA,
island_areas_expected=self.ISLAND_AREAS_EXPECTED_IN_DATA,
additional_fips_codes_not_expected=self.EXPECTED_MISSING_STATES,
dataset_name=self.NAME,
)
def load(self, float_format=None) -> None:
"""Saves the transformed data.
Data is written in the specified local data folder or remote AWS S3 bucket.
Uses the directory and the file name from `self._get_output_file_path`.
"""
logger.debug(f"Saving `{self.NAME}` CSV")
# Create directory if necessary.
output_file_path = self._get_output_file_path()
output_file_path.parent.mkdir(parents=True, exist_ok=True)
# Write nationwide csv
self.output_df[self.COLUMNS_TO_KEEP].to_csv(
output_file_path, index=False, float_format=float_format
)
logger.debug(f"File written to `{output_file_path}`.")
# This is a classmethod so it can be used without needing to create an instance of
# the class. This is a use case in `etl_score`.
@classmethod
def get_data_frame(cls) -> pd.DataFrame:
"""Return the output data frame for this class.
Must be run after a full ETL process has been run for this class.
If the ETL has not been run for this class, this will error.
"""
# Read in output file
output_file_path = cls._get_output_file_path()
if not output_file_path.exists():
raise ValueError(
f"Make sure to run ETL process first for `{cls}`. "
f"No file found at `{output_file_path}`."
)
logger.debug(
f"Reading in CSV `{output_file_path}` for ETL of class `{cls}`."
)
output_df = pd.read_csv(
output_file_path,
dtype={
# Not all outputs will have both a Census Block Group ID and a
# Tract ID, but these will be ignored if they're not present.
cls.GEOID_FIELD_NAME: "string",
cls.GEOID_TRACT_FIELD_NAME: "string",
},
)
return output_df
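# Typical usage (e.g. from the score ETL), assuming a hypothetical child class
# whose ETL has already been run:
#
#     example_df = ExampleETL.get_data_frame()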
def cleanup(self) -> None:
"""Clears out any files stored in the TMP folder"""
remove_all_from_dir(self.get_tmp_path())
def get_tmp_path(self) -> pathlib.Path:
"""Returns the temporary path associated with this ETL class."""
# Note: the temporary path will be defined on `init`, because it uses the class
# of the instance which is often a child class.
tmp_path = self.DATA_PATH / "tmp" / str(self.__class__.__name__)
# Create directory if it doesn't exist
tmp_path.mkdir(parents=True, exist_ok=True)
return tmp_path
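# A minimal sketch of how the lifecycle methods above compose into a full run
# for one dataset; `ExampleETL` is hypothetical and the project's actual
# runner may sequence things differently:
#
#     etl = ExampleETL()
#     etl.extract(use_cached_data_sources=True)
#     etl.transform()
#     etl.validate()
#     etl.load()
#     etl.cleanup()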