From ece3ded2eb2c6560961e22991b26abc24f75d19e Mon Sep 17 00:00:00 2001 From: Wenyi Kuang Date: Sun, 10 Mar 2024 20:51:17 -0600 Subject: [PATCH] Better way to handle read_csv function's mock --- postprocessing/comstockpostproc/cbecs.py | 6 ++--- postprocessing/comstockpostproc/comstock.py | 6 ++--- postprocessing/test/utility/mock_CBECS.py | 25 +++++++++++--------- postprocessing/test/utility/mock_comstock.py | 21 +++++++++------- 4 files changed, 30 insertions(+), 28 deletions(-) diff --git a/postprocessing/comstockpostproc/cbecs.py b/postprocessing/comstockpostproc/cbecs.py index be2dc8c48..591d7d299 100644 --- a/postprocessing/comstockpostproc/cbecs.py +++ b/postprocessing/comstockpostproc/cbecs.py @@ -97,19 +97,17 @@ def download_data(self): s3_file_path = f'truth_data/{self.truth_data_version}/EIA/CBECS/{file_name}' self.read_delimited_truth_data_file_from_S3(s3_file_path, ',') - def _read_csv(self, file_path, low_memory, na_values, index_col=None): - return pd.read_csv(file_path, low_memory=low_memory, na_values=na_values, index_col=index_col) def load_data(self): # Load raw microdata and codebook and decode numeric keys to strings using codebook # Load microdata file_path = os.path.join(self.truth_data_dir, self.data_file_name) - self.data = self._read_csv(file_path=file_path, low_memory=False, na_values=['.']) + self.data = pd.read_csv(file_path, low_memory=False, na_values=['.']) # Load microdata codebook file_path = os.path.join(self.truth_data_dir, self.data_codebook_file_name) - codebook = self._read_csv(file_path=file_path, index_col='File order', low_memory=False) + codebook = pd.read_csv(file_path, index_col='File order', low_memory=False) # Make a dict of column names (e.g. PBA) to labels (e.g. Principal building activity) # and a dict of numeric enumerations to strings for non-numeric variables var_name_to_label = {} diff --git a/postprocessing/comstockpostproc/comstock.py b/postprocessing/comstockpostproc/comstock.py index 5e804f9fe..8418e44d4 100644 --- a/postprocessing/comstockpostproc/comstock.py +++ b/postprocessing/comstockpostproc/comstock.py @@ -728,8 +728,6 @@ def add_geospatial_columns(self): # Show the dataset size logger.debug(f'Memory after add_geospatial_columns: {self.data.estimated_size()}') - def _read_csv(self, path, columns, dtypes): return pl.read_csv(path, columns=columns, dtypes=dtypes) - def add_ejscreen_columns(self): # Add the EJ Screen data if not 'nhgis_tract_gisjoin' in self.data: @@ -749,7 +747,7 @@ def add_ejscreen_columns(self): # Read the buildstock.csv and join columns onto annual results by building ID file_name = 'EJSCREEN_Tract_2020_USPR.csv' file_path = os.path.join(self.truth_data_dir, file_name) - ejscreen = self._read_csv(path=file_path, columns=col_def_names, dtypes={'ID': str}) + ejscreen = pl.read_csv(file_path, columns=col_def_names, dtypes={'ID': str}) # Convert EJSCREEN census tract ID to gisjoin format @lru_cache() @@ -809,7 +807,7 @@ def add_cejst_columns(self): # Read the buildstock.csv and join columns onto annual results by building ID file_name = self.cejst_file_name file_path = os.path.join(self.truth_data_dir, file_name) - cejst = self._read_csv(path=file_path, columns=col_def_names, dtypes=col_def_types) + cejst = pl.read_csv(file_path, columns=col_def_names, dtypes=col_def_types) # Convert CEJST census tract ID to gisjoin format @lru_cache() diff --git a/postprocessing/test/utility/mock_CBECS.py b/postprocessing/test/utility/mock_CBECS.py index 9226f8243..e9423a9b8 100644 --- a/postprocessing/test/utility/mock_CBECS.py +++ b/postprocessing/test/utility/mock_CBECS.py @@ -18,24 +18,27 @@ def __init__(self): self.mock_read_delimited_truth_data_file_from_S3 = self.patcher_read_delimited_truth_data_file_from_S3.start() self.mock_read_delimited_truth_data_file_from_S3.side_effect = self.mock_read_delimited_truth_data_file_from_S3_action - self.patcher__read_csv = patch('comstockpostproc.cbecs.CBECS._read_csv') + self.original_read_csv = pd.read_csv + self.patcher__read_csv = patch('pandas.read_csv') self.mock__read_csv = self.patcher__read_csv.start() self.mock__read_csv.side_effect = self.mock__read_csv_action - def mock_read_delimited_truth_data_file_from_S3_action(self, s3_file_path, delimiter): - logging.info('reading from path: {} with delimiter {}'.format(s3_file_path, delimiter)) - return pd.DataFrame() - - def mock__read_csv_action(self, **kwargs): + def mock__read_csv_action(self, *args ,**kwargs): logging.info('Mocking read_csv from CBECS') - path = kwargs["file_path"] + path = args[0] + filePath = None if "CBECS_2018_microdata.csv" in path: - filePath = "/truth_data/v01/EIA/CBECS/CBECS_2018_microdata.csv" + filePath = "/truth_data/v01/EIA/CBECS/CBECS_2018_microdata.csv" elif "CBECS_2018_microdata_codebook.csv" in path: - filePath = "/truth_data/v01/EIA/CBECS/CBECS_2018_microdata_codebook.csv" + filePath = "/truth_data/v01/EIA/CBECS/CBECS_2018_microdata_codebook.csv" - del kwargs["file_path"] - return pd.read_csv(filePath, **kwargs) + if filePath is None: + return self.original_read_csv(*args, **kwargs) + return self.original_read_csv(filePath, **kwargs) + + def mock_read_delimited_truth_data_file_from_S3_action(self, s3_file_path, delimiter): + logging.info('reading from path: {} with delimiter {}'.format(s3_file_path, delimiter)) + return pd.DataFrame() def stop(self): self.patcher.stop() diff --git a/postprocessing/test/utility/mock_comstock.py b/postprocessing/test/utility/mock_comstock.py index ccf0a11a2..ff436ac6c 100644 --- a/postprocessing/test/utility/mock_comstock.py +++ b/postprocessing/test/utility/mock_comstock.py @@ -28,8 +28,9 @@ def __init__(self): self.patcher_read_delimited_truth_data_file_from_S3 = patch('comstockpostproc.comstock.ComStock.read_delimited_truth_data_file_from_S3') self.mock_read_delimited_truth_data_file_from_S3 = self.patcher_read_delimited_truth_data_file_from_S3.start() self.mock_read_delimited_truth_data_file_from_S3.side_effect = self.mock_read_delimited_truth_data_file_from_S3_action - - self.patcher__read_csv = patch('comstockpostproc.comstock.ComStock._read_csv') + + self.original_read_csv = pl.read_csv + self.patcher__read_csv = patch('polars.read_csv') self.mock__read_csv = self.patcher__read_csv.start() self.mock__read_csv.side_effect = self.mock__read_csv_action @@ -44,15 +45,17 @@ def mock_isfile_on_S3_action(self, bucket, file_path): logging.info('Mocking isfile_on_S3') return True - def mock__read_csv_action(self, **kwargs): - path = kwargs["path"] + def mock__read_csv_action(self, *args, **kwargs): + logging.info('Mocking read_csv from ComStock') + filePath = None + path = args[0] if "EJSCREEN" in path: - filePath = "/truth_data/v01/EPA/EJSCREEN/EJSCREEN_Tract_2020_USPR.csv" + filePath = "/truth_data/v01/EPA/EJSCREEN/EJSCREEN_Tract_2020_USPR.csv" elif "1.0-communities.csv" in path: - filePath = "/truth_data/v01/EPA/CEJST/1.0-communities.csv" - del kwargs["path"] - - return pl.read_csv(filePath, **kwargs) + filePath = "/truth_data/v01/EPA/CEJST/1.0-communities.csv" + if not filePath: + return self.original_read_csv(*args, **kwargs) + return self.original_read_csv(filePath, **kwargs) def stop(self): self.patcher.stop()