Skip to content

Commit

Permalink
A better way to handle the read_csv function's mock
Browse files Browse the repository at this point in the history
  • Loading branch information
wenyikuang committed Mar 11, 2024
1 parent 8ae96bf commit ece3ded
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 28 deletions.
6 changes: 2 additions & 4 deletions postprocessing/comstockpostproc/cbecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,17 @@ def download_data(self):
s3_file_path = f'truth_data/{self.truth_data_version}/EIA/CBECS/{file_name}'
self.read_delimited_truth_data_file_from_S3(s3_file_path, ',')

def _read_csv(self, file_path, low_memory, na_values, index_col=None):
return pd.read_csv(file_path, low_memory=low_memory, na_values=na_values, index_col=index_col)

def load_data(self):
# Load raw microdata and codebook and decode numeric keys to strings using codebook

# Load microdata
file_path = os.path.join(self.truth_data_dir, self.data_file_name)
self.data = self._read_csv(file_path=file_path, low_memory=False, na_values=['.'])
self.data = pd.read_csv(file_path, low_memory=False, na_values=['.'])

# Load microdata codebook
file_path = os.path.join(self.truth_data_dir, self.data_codebook_file_name)
codebook = self._read_csv(file_path=file_path, index_col='File order', low_memory=False)
codebook = pd.read_csv(file_path, index_col='File order', low_memory=False)
# Make a dict of column names (e.g. PBA) to labels (e.g. Principal building activity)
# and a dict of numeric enumerations to strings for non-numeric variables
var_name_to_label = {}
Expand Down
6 changes: 2 additions & 4 deletions postprocessing/comstockpostproc/comstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,6 @@ def add_geospatial_columns(self):
# Show the dataset size
logger.debug(f'Memory after add_geospatial_columns: {self.data.estimated_size()}')

def _read_csv(self, path, columns, dtypes):
    """Thin wrapper around ``polars.read_csv`` so tests can patch CSV loading.

    Reads only ``columns`` from ``path``, applying the ``dtypes`` overrides,
    and returns the polars DataFrame unchanged.
    """
    read_options = {'columns': columns, 'dtypes': dtypes}
    return pl.read_csv(path, **read_options)

def add_ejscreen_columns(self):
# Add the EJ Screen data
if not 'nhgis_tract_gisjoin' in self.data:
Expand All @@ -749,7 +747,7 @@ def add_ejscreen_columns(self):
# Read the buildstock.csv and join columns onto annual results by building ID
file_name = 'EJSCREEN_Tract_2020_USPR.csv'
file_path = os.path.join(self.truth_data_dir, file_name)
ejscreen = self._read_csv(path=file_path, columns=col_def_names, dtypes={'ID': str})
ejscreen = pl.read_csv(file_path, columns=col_def_names, dtypes={'ID': str})

# Convert EJSCREEN census tract ID to gisjoin format
@lru_cache()
Expand Down Expand Up @@ -809,7 +807,7 @@ def add_cejst_columns(self):
# Read the buildstock.csv and join columns onto annual results by building ID
file_name = self.cejst_file_name
file_path = os.path.join(self.truth_data_dir, file_name)
cejst = self._read_csv(path=file_path, columns=col_def_names, dtypes=col_def_types)
cejst = pl.read_csv(file_path, columns=col_def_names, dtypes=col_def_types)

# Convert CEJST census tract ID to gisjoin format
@lru_cache()
Expand Down
25 changes: 14 additions & 11 deletions postprocessing/test/utility/mock_CBECS.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,27 @@ def __init__(self):
self.mock_read_delimited_truth_data_file_from_S3 = self.patcher_read_delimited_truth_data_file_from_S3.start()
self.mock_read_delimited_truth_data_file_from_S3.side_effect = self.mock_read_delimited_truth_data_file_from_S3_action

self.patcher__read_csv = patch('comstockpostproc.cbecs.CBECS._read_csv')
self.original_read_csv = pd.read_csv
self.patcher__read_csv = patch('pandas.read_csv')
self.mock__read_csv = self.patcher__read_csv.start()
self.mock__read_csv.side_effect = self.mock__read_csv_action

def mock_read_delimited_truth_data_file_from_S3_action(self, s3_file_path, delimiter):
    """Stand-in for the S3 download: log the request, return an empty frame.

    No network access happens; callers get an empty ``pd.DataFrame`` so the
    code under test can proceed without real truth data.
    """
    message = 'reading from path: {} with delimiter {}'.format(s3_file_path, delimiter)
    logging.info(message)
    return pd.DataFrame()

def mock__read_csv_action(self, *args, **kwargs):
    """Side effect for the patched ``pandas.read_csv``.

    Recognized CBECS truth-data file names are redirected to local fixture
    paths; anything else falls through to the saved, unpatched
    ``pd.read_csv`` (``self.original_read_csv``) so unrelated reads still work.

    Returns whatever ``self.original_read_csv`` returns (a DataFrame).
    """
    logging.info('Mocking read_csv from CBECS')
    # pandas.read_csv takes the path as its first positional argument
    # (``filepath_or_buffer``); also accept the keyword spelling so the mock
    # does not raise IndexError on keyword-style callers. str() tolerates
    # pathlib.Path inputs, which do not support substring tests.
    raw_path = args[0] if args else kwargs.get('filepath_or_buffer', '')
    path = str(raw_path)

    filePath = None
    if "CBECS_2018_microdata.csv" in path:
        filePath = "/truth_data/v01/EIA/CBECS/CBECS_2018_microdata.csv"
    elif "CBECS_2018_microdata_codebook.csv" in path:
        filePath = "/truth_data/v01/EIA/CBECS/CBECS_2018_microdata_codebook.csv"

    if filePath is None:
        # Unrecognized path: defer to the real pandas.read_csv untouched.
        return self.original_read_csv(*args, **kwargs)
    # Drop a keyword-supplied path so it is not passed twice.
    kwargs.pop('filepath_or_buffer', None)
    return self.original_read_csv(filePath, **kwargs)

# NOTE(review): duplicate definition — an identical method with this name is
# defined earlier in this class; Python keeps only the later binding, so this
# one silently shadows the first. The bodies are the same, so behavior is
# unaffected, but one copy should be removed.
def mock_read_delimited_truth_data_file_from_S3_action(self, s3_file_path, delimiter):
    # Mock of the S3 reader: logs the requested path/delimiter and returns an
    # empty DataFrame instead of touching the network.
    logging.info('reading from path: {} with delimiter {}'.format(s3_file_path, delimiter))
    return pd.DataFrame()

def stop(self):
self.patcher.stop()
Expand Down
21 changes: 12 additions & 9 deletions postprocessing/test/utility/mock_comstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def __init__(self):
self.patcher_read_delimited_truth_data_file_from_S3 = patch('comstockpostproc.comstock.ComStock.read_delimited_truth_data_file_from_S3')
self.mock_read_delimited_truth_data_file_from_S3 = self.patcher_read_delimited_truth_data_file_from_S3.start()
self.mock_read_delimited_truth_data_file_from_S3.side_effect = self.mock_read_delimited_truth_data_file_from_S3_action

self.patcher__read_csv = patch('comstockpostproc.comstock.ComStock._read_csv')

self.original_read_csv = pl.read_csv
self.patcher__read_csv = patch('polars.read_csv')
self.mock__read_csv = self.patcher__read_csv.start()
self.mock__read_csv.side_effect = self.mock__read_csv_action

Expand All @@ -44,15 +45,17 @@ def mock_isfile_on_S3_action(self, bucket, file_path):
logging.info('Mocking isfile_on_S3')
return True

def mock__read_csv_action(self, *args, **kwargs):
    """Side effect for the patched ``polars.read_csv``.

    Recognized EJSCREEN/CEJST truth-data file names are redirected to local
    fixture paths; anything else falls through to the saved, unpatched
    ``pl.read_csv`` (``self.original_read_csv``) so unrelated reads still work.

    Returns whatever ``self.original_read_csv`` returns (a polars DataFrame).
    """
    logging.info('Mocking read_csv from ComStock')
    # polars.read_csv takes the path as its first positional argument
    # (``source``); also accept the keyword spelling so the mock does not
    # raise IndexError on keyword-style callers. str() tolerates
    # pathlib.Path inputs, which do not support substring tests.
    raw_path = args[0] if args else kwargs.get('source', '')
    path = str(raw_path)

    filePath = None
    if "EJSCREEN" in path:
        filePath = "/truth_data/v01/EPA/EJSCREEN/EJSCREEN_Tract_2020_USPR.csv"
    elif "1.0-communities.csv" in path:
        filePath = "/truth_data/v01/EPA/CEJST/1.0-communities.csv"

    # Use an explicit None test (matches the CBECS mock) rather than
    # truthiness, so an empty-string redirect could never be misread.
    if filePath is None:
        # Unrecognized path: defer to the real polars.read_csv untouched.
        return self.original_read_csv(*args, **kwargs)
    # Drop a keyword-supplied path so it is not passed twice.
    kwargs.pop('source', None)
    return self.original_read_csv(filePath, **kwargs)

def stop(self):
self.patcher.stop()
Expand Down

0 comments on commit ece3ded

Please sign in to comment.