Skip to content

Commit

Permalink
Refactor checker class (#1848)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbowen-usds committed Aug 31, 2022
1 parent 8f4346c commit 0387890
Showing 1 changed file with 37 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,59 +118,53 @@ def test_tract_equality(tiles_df, final_score_df):


@dataclass
class DTypeComparison:
final_score_dtype: np.dtype
tile_dtype: np.dtype
class ColumnValueComparison:
final_score_column: pd.Series
tiles_column: pd.Series
col_name: str

def __post_init__(self):
self._is_dtype_ok = self.final_score_dtype == self.tile_dtype
self._is_dtype_ok = (
self.final_score_column.dtype == self.tiles_column.dtype
)
self._is_value_ok = False
if self._is_dtype_ok:
if self.final_score_column.dtype == np.dtype("float64"):
self._is_value_ok = np.allclose(
self.final_score_column,
self.tiles_column,
atol=float(f"1e-{constants.TILES_ROUND_NUM_DECIMALS}"),
equal_nan=True,
)
else:
self._is_value_ok = (
self.final_score_column.dropna()
== self.tiles_column.dropna()
).all()

def __bool__(self) -> bool:
return self._is_dtype_ok
return self._is_dtype_ok and bool(self._is_value_ok)

@property
def error_message(self) -> Optional[str]:
if not self._is_dtype_ok:
return (
f"Column {self.col_name} dtype mismatch: "
f"score_df: {self.final_score_dtype}, "
f"tile_df: {self.tile_dtype}"
)
return None


@dataclass
class ColumnValueComparison:
final_score_column: pd.Series
tiles_column: pd.Series
col_name: str

def __post_init__(self):
if self.final_score_column.dtype == np.dtype("float64"):
self._is_value_ok = np.allclose(
self.final_score_column,
self.tiles_column,
atol=float(f"1e-{constants.TILES_ROUND_NUM_DECIMALS}"),
equal_nan=True,
f"score_df: {self.final_score_column.dtype}, "
f"tile_df: {self.tiles_column.dtype}"
)
else:
self._is_value_ok = (
self.final_score_column.dropna() == self.tiles_column.dropna()
).all()

def __bool__(self) -> bool:
return bool(self._is_value_ok)

@property
def error_message(self) -> Optional[str]:
if not self._is_value_ok:
elif not self._is_value_ok:
return f"Column {self.col_name} value mismatch"
return None


def test_for_column_fidelitiy_from_score(tiles_df, final_score_df):

# Verify the following:
# * Shape and tracts match between score csv and tile csv
# * If you rename score CSV columns, you are able to make the tile csv
# * The dtype and values of every renamed score column are "equal" to
# those of the tile column with the same name
# * Because tiles use rounded floats, float columns are compared with
# np.allclose and an absolute tolerance rather than exact equality
assert (
set(TILES_SCORE_COLUMNS.values()) - set(tiles_df.columns) == set()
), "Some TILES_SCORE_COLUMNS are missing from the tiles dataframe"
Expand Down Expand Up @@ -201,22 +195,16 @@ def test_for_column_fidelitiy_from_score(tiles_df, final_score_df):
# Are all the dtypes and values the same?
comparisons = []
for col_name in final_score_df.columns:
dtype_comparison = DTypeComparison(
final_score_dtype=final_score_df.dtypes.loc[col_name],
tile_dtype=tiles_df.dtypes.loc[col_name],
col_name=col_name,
value_comparison = ColumnValueComparison(
final_score_df[col_name], tiles_df[col_name], col_name
)
comparisons.append(dtype_comparison)
if dtype_comparison:
value_comparison = ColumnValueComparison(
final_score_df[col_name], tiles_df[col_name], col_name
)
comparisons.append(value_comparison)
comparisons.append(value_comparison)
errors = [comp for comp in comparisons if not comp]
error_message = "\n".join(error.error_message for error in errors)
assert not errors, error_message


# For each data point that we visualize, we want to confirm that
# (1) the column is represented in tiles_columns
# (2) the column values are of the TYPE they are supposed to be
def test_for_state_names(tiles_df):
states = tiles_df.SF.value_counts(dropna=False).index
breakpoint()
assert False

0 comments on commit 0387890

Please sign in to comment.