Skip to content

Commit

Permalink
fix: preserve Check options in schema statistics roundtrip
Browse files Browse the repository at this point in the history
Signed-off-by: alexismanuel <alexis.manuelpro@gmail.com>
  • Loading branch information
alexismanuel committed Nov 16, 2024
1 parent ea4538d commit d09f3d4
Show file tree
Hide file tree
Showing 4 changed files with 385 additions and 83 deletions.
84 changes: 71 additions & 13 deletions pandera/io/pandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,28 @@ def handle_stat_dtype(stat):

return stat

# for unary checks, return a single value instead of a dictionary
if len(check_stats) == 1:
return handle_stat_dtype(list(check_stats.values())[0])
# Extract check options if they exist
check_options = (
check_stats.pop("options", {}) if isinstance(check_stats, dict) else {}
)

# Handle unary checks
if isinstance(check_stats, dict) and len(check_stats) == 1:
value = handle_stat_dtype(list(check_stats.values())[0])
if check_options:
return {"value": value, "options": check_options}
return value

# otherwise return a dictionary of keyword args needed to create the Check
serialized_check_stats = {}
for arg, stat in check_stats.items():
serialized_check_stats[arg] = handle_stat_dtype(stat)
return serialized_check_stats
# Handle dictionary case
if isinstance(check_stats, dict):
serialized_check_stats = {}
for arg, stat in check_stats.items():
serialized_check_stats[arg] = handle_stat_dtype(stat)
if check_options:
serialized_check_stats["options"] = check_options
return serialized_check_stats

return handle_stat_dtype(check_stats)


def _serialize_dataframe_stats(dataframe_checks):
Expand Down Expand Up @@ -178,6 +191,8 @@ def serialize_schema(dataframe_schema):


def _deserialize_check_stats(check, serialized_check_stats, dtype=None):
"""Deserialize check statistics and reconstruct check with options."""

def handle_stat_dtype(stat):
try:
if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype):
Expand All @@ -189,15 +204,35 @@ def handle_stat_dtype(stat):
return stat
return stat

# Extract options if they exist
options = {}
if isinstance(serialized_check_stats, dict):
# handle case where serialized check stats are in the form of a
# dictionary mapping Check arg names to values.
options = serialized_check_stats.pop("options", {})
# Handle special case for unary checks with options
if (
"value" in serialized_check_stats
and len(serialized_check_stats) == 1
):
serialized_check_stats = serialized_check_stats["value"]

# Create check with original logic
if isinstance(serialized_check_stats, dict):
check_stats = {}
for arg, stat in serialized_check_stats.items():
check_stats[arg] = handle_stat_dtype(stat)
return check(**check_stats)
# otherwise assume unary check function signature
return check(handle_stat_dtype(serialized_check_stats))
check_instance = check(**check_stats)
else:
# otherwise assume unary check function signature
check_instance = check(handle_stat_dtype(serialized_check_stats))

# Apply options if they exist
if options:
for option_name, option_value in options.items():
setattr(check_instance, option_name, option_value)

return check_instance


def _deserialize_component_stats(serialized_component_stats):
Expand Down Expand Up @@ -447,6 +482,7 @@ def to_json(dataframe_schema, target=None, **kwargs):


def _format_checks(checks_dict):
"""Format checks into string representation including options."""
if checks_dict is None:
return "None"

Expand All @@ -457,11 +493,33 @@ def _format_checks(checks_dict):
f"Check {check_name} cannot be serialized. "
"This check will be ignored"
)
else:
continue

# Handle options separately
options = (
check_kwargs.pop("options", {})
if isinstance(check_kwargs, dict)
else {}
)

# Format main check arguments
if isinstance(check_kwargs, dict):
args = ", ".join(
f"{k}={v.__repr__()}" for k, v in check_kwargs.items()
)
checks.append(f"Check.{check_name}({args})")
else:
args = check_kwargs.__repr__()

# Add options to arguments if they exist
if options:
if args:
args += ", "
args += ", ".join(
f"{k}={v.__repr__()}" for k, v in options.items()
)

checks.append(f"Check.{check_name}({args})")

return f"[{', '.join(checks)}]"


Expand Down
52 changes: 43 additions & 9 deletions pandera/schema_statistics/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,29 @@ def _index_stats(index_level):


def parse_check_statistics(check_stats: Union[Dict[str, Any], None]):
"""Convert check statistics to a list of Check objects."""
"""Convert check statistics to a list of Check objects, including their options."""
if check_stats is None:
return None
checks = []
for check_name, stats in check_stats.items():
check = getattr(Check, check_name)
try:
checks.append(check(**stats))
# Extract options if present
if isinstance(stats, dict):
options = (
stats.pop("options", {}) if "options" in stats else {}
)
if stats: # If there are remaining stats
check_instance = check(**stats)
else: # Handle case where all stats were in options
check_instance = check()
# Apply options to the check instance
for option_name, option_value in options.items():
setattr(check_instance, option_name, option_value)
checks.append(check_instance)
else:
# Handle unary check case
checks.append(check(stats))
except TypeError:
# if stats cannot be unpacked as key-word args, assume unary check.
checks.append(check(stats))
Expand Down Expand Up @@ -142,9 +157,10 @@ def get_series_schema_statistics(series_schema):


def parse_checks(checks) -> Union[Dict[str, Any], None]:
"""Convert Check object to check statistics."""
"""Convert Check object to check statistics including options."""
check_statistics = {}
_check_memo = {}

for check in checks:
if check not in Check:
warnings.warn(
Expand All @@ -154,28 +170,46 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]:
)
continue

check_statistics[check.name] = (
{} if check.statistics is None else check.statistics
)
# Get base statistics
base_stats = {} if check.statistics is None else check.statistics

# Collect check options
check_options = {
"raise_warning": check.raise_warning,
"n_failure_cases": check.n_failure_cases,
"ignore_na": check.ignore_na,
}

# Filter out None values from options
check_options = {
k: v for k, v in check_options.items() if v is not None
}

# Combine statistics with options
check_statistics[check.name] = base_stats
if check_options:
check_statistics[check.name]["options"] = check_options

_check_memo[check.name] = check

# raise ValueError on incompatible checks
# Check for incompatible checks
if (
"greater_than_or_equal_to" in check_statistics
and "less_than_or_equal_to" in check_statistics
):
min_value = check_statistics.get(
"greater_than_or_equal_to", float("-inf")
)["min_value"]
).get("min_value", float("-inf"))
max_value = check_statistics.get(
"less_than_or_equal_to", float("inf")
)["max_value"]
).get("max_value", float("inf"))
if min_value > max_value:
raise ValueError(
f"checks {_check_memo['greater_than_or_equal_to']} "
f"and {_check_memo['less_than_or_equal_to']} are incompatible, reason: "
f"min value {min_value} > max value {max_value}"
)

return check_statistics if check_statistics else None


Expand Down
Loading

0 comments on commit d09f3d4

Please sign in to comment.