Add ability to cache ETL data sources (#2169)

* Add a rough prototype allowing a developer to pre-download data sources for all ETLs

* Update code to be more production-ish

* Move fetch to Extract part of ETL
* Create a downloader to house all downloading operations
* Remove unnecessary "name" in data source

* Format source files with black

* Fix issues from pylint and get the tests working with the new folder structure

* Clean up files with black

* Fix unzip test

* Add caching notes to README

* Fix tests (linting and case sensitivity bug)

* Address PR comments and add API keys for census where missing

* Merging comparator changes from main into this branch for the sake of the PR

* Add note on using cache (-u) during pipeline
This commit is contained in:
Travis Newby 2023-03-03 12:26:24 -06:00 committed by GitHub
commit 6f39033dde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 1787 additions and 686 deletions

View file

@ -2,7 +2,9 @@ import enum
import pathlib
import sys
import typing
import shutil
from typing import Optional
from abc import ABC, abstractmethod
import pandas as pd
from data_pipeline.config import settings
@ -13,7 +15,7 @@ from data_pipeline.etl.score.schemas.datasets import DatasetsConfig
from data_pipeline.utils import get_module_logger
from data_pipeline.utils import load_yaml_dict_from_file
from data_pipeline.utils import remove_all_from_dir
from data_pipeline.utils import unzip_file_from_url
from data_pipeline.etl.datasource import DataSource
logger = get_module_logger(__name__)
@ -25,7 +27,7 @@ class ValidGeoLevel(enum.Enum):
CENSUS_BLOCK_GROUP = enum.auto()
class ExtractTransformLoad:
class ExtractTransformLoad(ABC):
"""
A class used to instantiate an ETL object to retrieve and process data from
datasets.
@ -45,6 +47,7 @@ class ExtractTransformLoad:
# Directories
DATA_PATH: pathlib.Path = settings.DATA_PATH
TMP_PATH: pathlib.Path = DATA_PATH / "tmp"
SOURCES_PATH: pathlib.Path = DATA_PATH / "sources"
CONTENT_CONFIG: pathlib.Path = APP_ROOT / "content" / "config"
DATASET_CONFIG_PATH: pathlib.Path = APP_ROOT / "etl" / "score" / "config"
DATASET_CONFIG: Optional[dict] = None
@ -177,45 +180,60 @@ class ExtractTransformLoad:
output_file_path = cls.DATA_PATH / "dataset" / f"{cls.NAME}" / "usa.csv"
return output_file_path
def get_sources_path(self) -> pathlib.Path:
    """Returns the sources path associated with this ETL class.

    The sources path is the home for cached data sources used by this
    ETL. It is derived from the concrete (usually child) class name, so
    each ETL gets its own cache directory under SOURCES_PATH.
    """
    sources_path = self.SOURCES_PATH / str(self.__class__.__name__)

    # Create directory if it doesn't exist
    sources_path.mkdir(parents=True, exist_ok=True)

    return sources_path
def extract(
self,
source_url: str = None,
extract_path: pathlib.Path = None,
verify: Optional[bool] = True,
) -> None:
"""Extract the data from a remote source. By default it provides code
to get the file from a source url, unzips it and stores it on an
extract_path."""
@abstractmethod
def get_data_sources(self) -> typing.List[DataSource]:
    """Return the data sources used by this ETL.

    Subclasses must implement this to declare every remote source they
    depend on, so that `extract` can download and cache them.
    """
    # NOTE: the original annotation was the list literal `[DataSource]`,
    # which is not a valid type hint; `typing.List[DataSource]` is.
if source_url is None:
source_url = self.SOURCE_URL
def _fetch(self) -> None:
"""Fetch all data sources for this ETL. When data sources are fetched, they
are stored in a cache directory for consistency between runs."""
for ds in self.get_data_sources():
ds.fetch()
if extract_path is None:
extract_path = self.get_tmp_path()
def clear_data_source_cache(self) -> None:
    """Clear the cached data source(s) for this ETL.

    Removes the entire per-ETL sources directory; it is recreated
    (empty) on the next call to `get_sources_path`.
    """
    shutil.rmtree(self.get_sources_path())
unzip_file_from_url(
file_url=source_url,
download_path=self.get_tmp_path(),
unzipped_file_path=extract_path,
verify=verify,
)
def extract(self, use_cached_data_sources: bool = False) -> None:
    """Extract (download) data from a remote source, and validate that data.

    By default, this method fetches data from the set of data sources
    returned by get_data_sources.

    If use_cached_data_sources is true, this method attempts to use cached
    data rather than re-downloading from the original source. The cache
    algorithm is very simple: it just looks to see if the directory has
    any contents. If so, it uses that content. If not, it downloads all
    data sources.

    Subclasses should call super() before performing any work if they wish
    to take advantage of the automatic downloading and caching ability of
    this superclass.
    """
    if use_cached_data_sources and any(self.get_sources_path().iterdir()):
        # Cache hit: at least one entry exists in this ETL's sources dir.
        # Use lazy %-style args so the message is only built if emitted.
        logger.info(
            "Using cached data sources for %s", self.__class__.__name__
        )
    else:
        # Cache miss (or caching disabled): wipe any partial cache and
        # re-download everything for consistency between runs.
        self.clear_data_source_cache()
        self._fetch()

    # the rest of the work should be performed here
@abstractmethod
def transform(self) -> None:
"""Transform the data extracted into a format that can be consumed by the
score generator"""
raise NotImplementedError
pass
def validate(self) -> None:
"""Validates the output.
@ -380,3 +398,14 @@ class ExtractTransformLoad:
def cleanup(self) -> None:
    """Clears out any files stored in the TMP folder."""
    # remove_all_from_dir is a project utility (from data_pipeline.utils);
    # get_tmp_path() is per-ETL-class, so only this ETL's temporary files
    # are removed — presumably the directory itself is kept (TODO confirm
    # against the helper's implementation).
    remove_all_from_dir(self.get_tmp_path())
def get_tmp_path(self) -> pathlib.Path:
    """Returns the temporary path associated with this ETL class.

    Note: the temporary path is resolved at call time from the runtime
    class of the instance, which is often a child class.
    """
    path = self.DATA_PATH / "tmp" / str(self.__class__.__name__)
    # Make sure the directory exists before handing it to the caller.
    path.mkdir(parents=True, exist_ok=True)
    return path