diff --git a/.pylintrc b/.pylintrc index fac105f22..94cf9acca 100644 --- a/.pylintrc +++ b/.pylintrc @@ -15,6 +15,13 @@ good-names= fp, bar, _IS_INFERRED, + eq, + ne, + gt, + ge, + lt, + le, + dt [MESSAGES CONTROL] disable= diff --git a/pandera/__init__.py b/pandera/__init__.py index 0ac7af54d..8ee0464e4 100644 --- a/pandera/__init__.py +++ b/pandera/__init__.py @@ -60,10 +60,13 @@ pandas_version, ) +import pandera.backends + from pandera.schema_inference.pandas import infer_schema from pandera.decorators import check_input, check_io, check_output, check_types from pandera.version import __version__ + if platform.system() != "Windows": # pylint: disable=ungrouped-imports from pandera.dtypes import Complex256, Float128 diff --git a/pandera/backends/__init__.py b/pandera/backends/__init__.py index e69de29bb..5767df682 100644 --- a/pandera/backends/__init__.py +++ b/pandera/backends/__init__.py @@ -0,0 +1,7 @@ +"""Pandera backends.""" + +# ensure that base builtin checks and hypothesis are registered +import pandera.backends.base.builtin_checks +import pandera.backends.base.builtin_hypotheses + +import pandera.backends.pandas diff --git a/pandera/backends/base.py b/pandera/backends/base/__init__.py similarity index 91% rename from pandera/backends/base.py rename to pandera/backends/base/__init__.py index 4452157a1..edfc3661d 100644 --- a/pandera/backends/base.py +++ b/pandera/backends/base/__init__.py @@ -1,12 +1,12 @@ -"""Base functions for Parsing, Validation, and Error Reporting Backends. +"""Base classes for parsing, validation, and error Reporting Backends. -This class should implement a common interface of operations needed for +These classes implement a common interface of operations needed for data validation. These operations are exposed as methods that are composed together to implement the pandera schema specification. """ from abc import ABC -from typing import Optional +from typing import Any, Dict, List, Optional class BaseSchemaBackend(ABC): @@ -100,6 +100,12 @@ def check_dtype(self, check_obj, schema): """Core check that checks the data type of a check object.""" raise NotImplementedError + def failure_cases_metadata( + self, schema_name: str, schema_errors: List[Dict[str, Any]] + ): + """Get failure cases metadata for lazy validation.""" + raise NotImplementedError + class BaseCheckBackend(ABC): """Abstract base class for a check backend implementation.""" diff --git a/pandera/backends/base/builtin_checks.py b/pandera/backends/base/builtin_checks.py new file mode 100644 index 000000000..65ec40002 --- /dev/null +++ b/pandera/backends/base/builtin_checks.py @@ -0,0 +1,98 @@ +# pylint: disable=missing-function-docstring +"""Built-in check functions base implementation. + +This module contains check function abstract definitions that correspond to +the pandera.core.base.checks.Check methods. These functions do not actually +implement any validation logic and serve as the entrypoint for dispatching +specific implementations based on the data object type, e.g. +`pandas.DataFrame`s. 
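To make the dispatch pattern described here concrete, the sketch below mimics it with `functools.singledispatch` standing in for the `multimethod`-based dispatcher this refactor actually uses; the `equal_to` stub and pandas overload are illustrative, not the pandera internals.

```python
from functools import singledispatch

import pandas as pd


@singledispatch
def equal_to(data, value):
    # base stub: no behavior until a backend registers an implementation
    raise NotImplementedError


@equal_to.register
def _(data: pd.Series, value) -> pd.Series:
    # pandas-specific implementation, selected by the type of `data`
    return data == value


print(equal_to(pd.Series([1, 2, 1]), 1))  # dispatches to the pandas overload
```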
+""" + +import re +from typing import Any, Iterable, TypeVar, Union + +from pandera.core.checks import Check + + +T = TypeVar("T") + + +@Check.register_builtin_check_fn +def equal_to(data: Any, value: Any) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def not_equal_to(data: Any, value: Any) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def greater_than(data: Any, min_value: Any) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def greater_than_or_equal_to(data: Any, min_value: Any) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def less_than(data: Any, max_value: Any) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def less_than_or_equal_to(data: Any, max_value: Any) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def in_range( + data: Any, + min_value: T, + max_value: T, + include_min: bool = True, + include_max: bool = True, +) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def isin(data: Any, allowed_values: Iterable) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def notin(data: Any, forbidden_values: Iterable) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def str_matches(data: Any, pattern: Union[str, re.Pattern]) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def str_contains(data: Any, pattern: Union[str, re.Pattern]) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def str_startswith(data: Any, string: str) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def str_endswith(data: Any, string: str) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def str_length(data: Any, min_value: int = None, max_value: int = None) -> Any: + raise NotImplementedError + + +@Check.register_builtin_check_fn +def unique_values_eq(data: Any, values: Iterable) -> Any: + raise NotImplementedError diff --git a/pandera/backends/base/builtin_hypotheses.py b/pandera/backends/base/builtin_hypotheses.py new file mode 100644 index 000000000..488488585 --- /dev/null +++ b/pandera/backends/base/builtin_hypotheses.py @@ -0,0 +1,31 @@ +# pylint: disable=missing-function-docstring +"""Built-in hypothesis functions base implementation. + +This module contains hypothesis function abstract definitions that +correspond to the pandera.core.base.checks.Check methods. These functions do not +actually implement any validation logic and serve as the entrypoint for +dispatching specific implementations based on the data object type, e.g. +`pandas.DataFrame`s. 
+""" + +from typing import Any, Tuple + +from pandera.core.hypotheses import Hypothesis + + +@Hypothesis.register_builtin_check_fn +def two_sample_ttest( + *samples: Tuple[Any, ...], + equal_var: bool = True, + nan_policy: str = "propagate", +): + raise NotImplementedError + + +@Hypothesis.register_builtin_check_fn +def one_sample_ttest( + *samples: Tuple[Any, ...], + popmean: float, + nan_policy: str = "propagate", +): + raise NotImplementedError diff --git a/pandera/backends/pandas/__init__.py b/pandera/backends/pandas/__init__.py index e69de29bb..b77c0f56e 100644 --- a/pandera/backends/pandas/__init__.py +++ b/pandera/backends/pandas/__init__.py @@ -0,0 +1,38 @@ +"""Pandas backend implementation for schemas and checks.""" + +import pandas as pd + +import pandera.typing +from pandera.core.checks import Check +from pandera.core.hypotheses import Hypothesis + +from pandera.backends.pandas.checks import PandasCheckBackend +from pandera.backends.pandas.hypotheses import PandasHypothesisBackend +from pandera.backends.pandas import builtin_checks, builtin_hypotheses + + +data_types = [pd.DataFrame, pd.Series] + +if pandera.typing.dask.DASK_INSTALLED: + import dask.dataframe as dd + + data_types.extend([dd.DataFrame, dd.Series]) + +if pandera.typing.modin.MODIN_INSTALLED: + import modin.pandas as mpd + + data_types.extend([mpd.DataFrame, mpd.Series]) + +if pandera.typing.pyspark.PYSPARK_INSTALLED: + import pyspark.pandas as ps + + data_types.extend([ps.DataFrame, ps.Series]) + +if pandera.typing.geopandas.GEOPANDAS_INSTALLED: + import geopandas as gpd + + data_types.extend([gpd.GeoDataFrame, gpd.GeoSeries]) + +for t in data_types: + Check.register_backend(t, PandasCheckBackend) + Hypothesis.register_backend(t, PandasHypothesisBackend) diff --git a/pandera/backends/pandas/array.py b/pandera/backends/pandas/array.py index 3af9188f0..58869bc53 100644 --- a/pandera/backends/pandas/array.py +++ b/pandera/backends/pandas/array.py @@ -102,7 +102,9 @@ def validate( if lazy and error_handler.collected_errors: raise SchemaErrors( - schema, error_handler.collected_errors, check_obj + schema=schema, + schema_errors=error_handler.collected_errors, + data=check_obj, ) return check_obj diff --git a/pandera/backends/pandas/base.py b/pandera/backends/pandas/base.py index 2aa6df99e..fc57cb48f 100644 --- a/pandera/backends/pandas/base.py +++ b/pandera/backends/pandas/base.py @@ -2,6 +2,8 @@ import warnings from typing import ( + Any, + Dict, FrozenSet, Iterable, List, @@ -17,10 +19,12 @@ from pandera.backends.pandas.error_formatters import ( format_generic_error_message, format_vectorized_error_message, + consolidate_failure_cases, + summarize_failure_cases, reshape_failure_cases, scalar_failure_case, ) -from pandera.errors import SchemaError +from pandera.errors import SchemaError, FailureCaseMetadata class ColumnInfo(NamedTuple): @@ -118,3 +122,19 @@ def run_check( check_output=check_result.check_output, ) return check_result.check_passed + + def failure_cases_metadata( + self, + schema_name: str, + schema_errors: List[Dict[str, Any]], + ) -> FailureCaseMetadata: + """Create failure cases metadata required for SchemaErrors exception.""" + failure_cases = consolidate_failure_cases(schema_errors) + message, error_counts = summarize_failure_cases( + schema_name, schema_errors, failure_cases + ) + return FailureCaseMetadata( + failure_cases=failure_cases, + message=message, + error_counts=error_counts, + ) diff --git a/pandera/core/pandas/checks.py b/pandera/backends/pandas/builtin_checks.py similarity index 69% 
rename from pandera/core/pandas/checks.py rename to pandera/backends/pandas/builtin_checks.py index e9ab11dc3..3477f4a5b 100644 --- a/pandera/core/pandas/checks.py +++ b/pandera/backends/pandas/builtin_checks.py @@ -7,21 +7,21 @@ import pandas as pd import pandera.strategies as st -from pandera.core.extensions import register_check +from pandera.core.extensions import register_builtin_check from pandera.typing.modin import MODIN_INSTALLED from pandera.typing.pyspark import PYSPARK_INSTALLED -if MODIN_INSTALLED and not PYSPARK_INSTALLED: +if MODIN_INSTALLED and not PYSPARK_INSTALLED: # pragma: no cover import modin.pandas as mpd PandasData = Union[pd.Series, pd.DataFrame, mpd.Series, mpd.DataFrame] -elif not MODIN_INSTALLED and PYSPARK_INSTALLED: +elif not MODIN_INSTALLED and PYSPARK_INSTALLED: # pragma: no cover import pyspark.pandas as ppd PandasData = Union[pd.Series, pd.DataFrame, ppd.Series, ppd.DataFrame] # type: ignore[misc] -elif MODIN_INSTALLED and PYSPARK_INSTALLED: +elif MODIN_INSTALLED and PYSPARK_INSTALLED: # pragma: no cover import modin.pandas as mpd import pyspark.pandas as ppd @@ -33,14 +33,14 @@ ppd.Series, ppd.DataFrame, ] -else: +else: # pragma: no cover PandasData = Union[pd.Series, pd.DataFrame] # type: ignore[misc] T = TypeVar("T") -@register_check( +@register_builtin_check( aliases=["eq"], strategy=st.eq_strategy, error="equal_to({value})", @@ -54,7 +54,7 @@ def equal_to(data: PandasData, value: Any) -> PandasData: return data == value -@register_check( +@register_builtin_check( aliases=["ne"], strategy=st.ne_strategy, error="not_equal_to({value})", @@ -68,15 +68,7 @@ def not_equal_to(data: PandasData, value: Any) -> PandasData: return data != value -def gt_ge_pre_init_hook(statistics_kwargs): - """Pre-init hook for greater than/greater or equal to check.""" - if statistics_kwargs["min_value"] is None: - raise ValueError("min_value must not be None") - return statistics_kwargs - - -@register_check( - pre_init_hook=gt_ge_pre_init_hook, +@register_builtin_check( aliases=["gt"], strategy=st.gt_strategy, error="greater_than({min_value})", @@ -93,8 +85,7 @@ def greater_than(data: PandasData, min_value: Any) -> PandasData: return data > min_value -@register_check( - pre_init_hook=gt_ge_pre_init_hook, +@register_builtin_check( aliases=["ge"], strategy=st.ge_strategy, error="greater_than_or_equal_to({min_value})", @@ -109,15 +100,7 @@ def greater_than_or_equal_to(data: PandasData, min_value: Any) -> PandasData: return data >= min_value -def lt_le_pre_init_hook(statistics_kwargs): - """Pre-init hook for less than/less than or equal to check.""" - if statistics_kwargs["max_value"] is None: - raise ValueError("max_value must not be None") - return statistics_kwargs - - -@register_check( - pre_init_hook=lt_le_pre_init_hook, +@register_builtin_check( aliases=["lt"], strategy=st.lt_strategy, error="less_than({max_value})", @@ -134,8 +117,7 @@ def less_than(data: PandasData, max_value: Any) -> PandasData: return data < max_value -@register_check( - pre_init_hook=lt_le_pre_init_hook, +@register_builtin_check( aliases=["le"], strategy=st.le_strategy, error="less_than_or_equal_to({max_value})", @@ -152,29 +134,7 @@ def less_than_or_equal_to(data: PandasData, max_value: Any) -> PandasData: return data <= max_value -def in_range_pre_init_hook(statistics_kwargs): - """Pre-init hook for ``in_range`` check.""" - min_value = statistics_kwargs["min_value"] - max_value = statistics_kwargs["max_value"] - include_min = statistics_kwargs["include_min"] - include_max = 
statistics_kwargs["include_max"] - - if min_value is None: - raise ValueError("min_value must not be None") - if max_value is None: - raise ValueError("max_value must not be None") - if max_value < min_value or ( # type: ignore - min_value == max_value and (not include_min or not include_max) - ): - raise ValueError( - f"The combination of min_value = {min_value} and " - f"max_value = {max_value} defines an empty interval!" - ) - return statistics_kwargs - - -@register_check( - pre_init_hook=in_range_pre_init_hook, +@register_builtin_check( aliases=["between"], strategy=st.in_range_strategy, error="in_range({min_value}, {max_value})", @@ -208,20 +168,7 @@ def in_range( return left_op(min_value, data) & right_op(max_value, data) # type: ignore -def isin_pre_init_hook(statistics_kwargs): - """Pre-init hook for ``isin`` check.""" - allowed_values = statistics_kwargs["allowed_values"] - try: - allowed_values = frozenset(allowed_values) - except TypeError as exc: - raise ValueError( - f"Argument allowed_values must be iterable. Got {allowed_values}" - ) from exc - return {"allowed_values": allowed_values} - - -@register_check( - pre_init_hook=isin_pre_init_hook, +@register_builtin_check( strategy=st.isin_strategy, error="isin({allowed_values})", ) @@ -241,20 +188,7 @@ def isin(data: PandasData, allowed_values: Iterable) -> PandasData: return data.isin(allowed_values) -def notin_pre_init_hook(statistics_kwargs): - """Pre-init hook for ``notin`` check.""" - forbidden_values = statistics_kwargs["forbidden_values"] - try: - forbidden_values = frozenset(forbidden_values) - except TypeError as exc: - raise ValueError( - f"Argument forbidden_values must be iterable. Got {forbidden_values}" - ) from exc - return {"forbidden_values": forbidden_values} - - -@register_check( - pre_init_hook=notin_pre_init_hook, +@register_builtin_check( strategy=st.notin_strategy, error="notin({forbidden_values})", ) @@ -274,20 +208,7 @@ def notin(data: PandasData, forbidden_values: Iterable) -> PandasData: return ~data.isin(forbidden_values) -def str_regex_pre_init_hook(statistics_kwargs): - """Pre-init hook for string regex checks.""" - pattern = statistics_kwargs["pattern"] - try: - regex = re.compile(pattern) - except TypeError as exc: - raise ValueError( - f'pattern="{pattern}" cannot be compiled as regular expression' - ) from exc - return {"pattern": regex} - - -@register_check( - pre_init_hook=str_regex_pre_init_hook, +@register_builtin_check( strategy=st.str_matches_strategy, error="str_matches('{pattern}')", ) @@ -303,8 +224,7 @@ def str_matches( return data.str.match(cast(str, pattern), na=False) -@register_check( - pre_init_hook=str_regex_pre_init_hook, +@register_builtin_check( strategy=st.str_contains_strategy, error="str_contains('{pattern}')", ) @@ -320,7 +240,7 @@ def str_contains( return data.str.contains(cast(str, pattern), na=False) -@register_check( +@register_builtin_check( strategy=st.str_startswith_strategy, error="str_startswith('{string}')", ) @@ -333,7 +253,7 @@ def str_startswith(data: PandasData, string: str) -> PandasData: return data.str.startswith(string, na=False) -@register_check( +@register_builtin_check( strategy=st.str_endswith_strategy, error="str_endswith('{string}')" ) def str_endswith(data: PandasData, string: str) -> PandasData: @@ -345,20 +265,7 @@ def str_endswith(data: PandasData, string: str) -> PandasData: return data.str.endswith(string, na=False) -def str_length_pre_init_hook(statistics_kwargs): - """Pre-init hook for ``str_length`` check.""" - min_value = 
statistics_kwargs["min_value"] - max_value = statistics_kwargs["max_value"] - if min_value is None and max_value is None: - raise ValueError( - "At least a minimum or a maximum need to be specified. Got " - "None." - ) - return statistics_kwargs - - -@register_check( - pre_init_hook=str_length_pre_init_hook, +@register_builtin_check( strategy=st.str_length_strategy, error="str_length({min_value}, {max_value})", ) @@ -385,20 +292,7 @@ def str_length( return (str_len <= max_value) & (str_len >= min_value) -def unique_values_eq_init_hook(statistics_kwargs): - """Pre-init hook for ``unique_values`` check.""" - values = statistics_kwargs["values"] - try: - values = frozenset(values) - except TypeError as exc: - raise ValueError( - f"Argument values must be iterable. Got {values}" - ) from exc - return {"values": values} - - -@register_check( - pre_init_hook=unique_values_eq_init_hook, +@register_builtin_check( error="unique_values_eq({values})", ) def unique_values_eq(data: PandasData, values: Iterable): diff --git a/pandera/backends/pandas/builtin_hypotheses.py b/pandera/backends/pandas/builtin_hypotheses.py new file mode 100644 index 000000000..e023af163 --- /dev/null +++ b/pandera/backends/pandas/builtin_hypotheses.py @@ -0,0 +1,49 @@ +# pylint: disable=missing-function-docstring +"""Pandas implementation of built-in hypotheses.""" + +from typing import Tuple + +from pandera.backends.pandas.builtin_checks import PandasData +from pandera.backends.pandas.hypotheses import HAS_SCIPY +from pandera.core.extensions import register_builtin_hypothesis + + +if HAS_SCIPY: + from scipy import stats + + +@register_builtin_hypothesis( + error="failed two sample ttest between '{sample1}' and '{sample2}'", + samples_kwtypes={"sample1": str, "sample2": str}, +) +def two_sample_ttest( + *samples: Tuple[PandasData, ...], + equal_var: bool = True, + nan_policy: str = "propagate", +) -> Tuple[float, float]: + assert ( + len(samples) == 2 + ), "Expected two sample ttest data to contain exactly two samples" + return stats.ttest_ind( + samples[0], + samples[1], + equal_var=equal_var, + nan_policy=nan_policy, + ) + + +@register_builtin_hypothesis( + error="failed one sample ttest for column '{sample}'", + samples_kwtypes={"sample": str}, +) +def one_sample_ttest( + *samples: Tuple[PandasData, ...], + popmean: float, + nan_policy: str = "propagate", +) -> Tuple[float, float]: + assert ( + len(samples) == 1 + ), "Expected one sample ttest data to contain only one sample" + return stats.ttest_1samp( + samples[0], popmean=popmean, nan_policy=nan_policy + ) diff --git a/pandera/backends/pandas/components.py b/pandera/backends/pandas/components.py index 478cb0119..e5ec94df9 100644 --- a/pandera/backends/pandas/components.py +++ b/pandera/backends/pandas/components.py @@ -95,7 +95,9 @@ def validate_column(check_obj, column_name): if lazy and error_handler.collected_errors: raise SchemaErrors( - schema, error_handler.collected_errors, check_obj + schema=schema, + schema_errors=error_handler.collected_errors, + data=check_obj, ) return check_obj @@ -318,7 +320,9 @@ def coerce_dtype( # type: ignore[override] if error_handler.collected_errors: raise SchemaErrors( - schema, error_handler.collected_errors, check_obj + schema=schema, + schema_errors=error_handler.collected_errors, + data=check_obj, ) multiindex_cls = pd.MultiIndex @@ -468,7 +472,11 @@ def to_dataframe(multiindex): schema_error_dict["error"] = error schema_error_dicts.append(schema_error_dict) - raise SchemaErrors(schema, schema_error_dicts, check_obj) from err + 
raise SchemaErrors( + schema=schema, + schema_errors=schema_error_dicts, + data=check_obj, + ) from err assert is_table(validation_result) return check_obj diff --git a/pandera/backends/pandas/container.py b/pandera/backends/pandas/container.py index 2d2cffad0..ba4629447 100644 --- a/pandera/backends/pandas/container.py +++ b/pandera/backends/pandas/container.py @@ -111,9 +111,9 @@ def validate( if error_handler.collected_errors: raise SchemaErrors( - schema, - error_handler.collected_errors, - check_obj, + schema=schema, + schema_errors=error_handler.collected_errors, + data=check_obj, ) return check_obj @@ -350,7 +350,9 @@ def coerce_dtype( # raise SchemaErrors if this method is called without an # error_handler raise SchemaErrors( - schema, _error_handler.collected_errors, check_obj + schema=schema, + schema_errors=_error_handler.collected_errors, + data=check_obj, ) return check_obj @@ -434,7 +436,11 @@ def _try_coercion(coerce_fn, obj): obj.index = coerced_index if error_handler.collected_errors: - raise SchemaErrors(schema, error_handler.collected_errors, obj) + raise SchemaErrors( + schema=schema, + schema_errors=error_handler.collected_errors, + data=obj, + ) return obj diff --git a/pandera/backends/pandas/error_formatters.py b/pandera/backends/pandas/error_formatters.py index d004d84b3..cc0c5a2ea 100644 --- a/pandera/backends/pandas/error_formatters.py +++ b/pandera/backends/pandas/error_formatters.py @@ -1,6 +1,7 @@ """Make schema error messages human-friendly.""" -from typing import Union +from collections import defaultdict +from typing import Any, Dict, List, Tuple, Union import pandas as pd @@ -137,3 +138,168 @@ def _multiindex_to_frame(df): if pandas_version().release >= (1, 5, 0): return df.index.to_frame(allow_duplicates=True) return df.index.to_frame().drop_duplicates() + + +def consolidate_failure_cases(schema_errors: List[Dict[str, Any]]): + """Consolidate schema error dicts to produce data for error message.""" + check_failure_cases = [] + + column_order = [ + "schema_context", + "column", + "check", + "check_number", + "failure_case", + "index", + ] + + for schema_error_dict in schema_errors: + reason_code = schema_error_dict["reason_code"] + err = schema_error_dict["error"] + + check_identifier = ( + None + if err.check is None + else err.check + if isinstance(err.check, str) + else err.check.error + if err.check.error is not None + else err.check.name + if err.check.name is not None + else str(err.check) + ) + + if err.failure_cases is not None: + if "column" in err.failure_cases: + column = err.failure_cases["column"] + else: + column = ( + err.schema.name + if reason_code == "schema_component_check" + else None + ) + + failure_cases = err.failure_cases.assign( + schema_context=err.schema.__class__.__name__, + check=check_identifier, + check_number=err.check_index, + # if the column key is a tuple (for MultiIndex column + # names), explicitly wrap `column` in a list of the + # same length as the number of failure cases. 
+ column=( + [column] * err.failure_cases.shape[0] + if isinstance(column, tuple) + else column + ), + ) + check_failure_cases.append(failure_cases[column_order]) + + # NOTE: this is a hack to support pyspark.pandas and modin + concat_fn = pd.concat # type: ignore + if any( + type(x).__module__.startswith("pyspark.pandas") + for x in check_failure_cases + ): + # pylint: disable=import-outside-toplevel + import pyspark.pandas as ps + + concat_fn = ps.concat # type: ignore + check_failure_cases = [ + x if isinstance(x, ps.DataFrame) else ps.DataFrame(x) + for x in check_failure_cases + ] + elif any( + type(x).__module__.startswith("modin.pandas") + for x in check_failure_cases + ): + # pylint: disable=import-outside-toplevel + import modin.pandas as mpd + + concat_fn = mpd.concat # type: ignore + check_failure_cases = [ + x if isinstance(x, mpd.DataFrame) else mpd.DataFrame(x) + for x in check_failure_cases + ] + + return ( + concat_fn(check_failure_cases) + .reset_index(drop=True) + .sort_values("schema_context", ascending=False) + ) + + +SCHEMA_ERRORS_SUFFIX = """ + +Usage Tip +--------- + +Directly inspect all errors by catching the exception: + +``` +try: + schema.validate(dataframe, lazy=True) +except SchemaErrors as err: + err.failure_cases # dataframe of schema errors + err.data # invalid dataframe +``` +""" + + +def summarize_failure_cases( + schema_name: str, + schema_errors: List[Dict[str, Any]], + failure_cases: pd.DataFrame, +) -> Tuple[str, Dict[str, int]]: + """Format error message.""" + + error_counts = defaultdict(int) # type: ignore + for schema_error_dict in schema_errors: + reason_code = schema_error_dict["reason_code"] + error_counts[reason_code] += 1 + + msg = ( + f"Schema {schema_name}: A total of " + f"{sum(error_counts.values())} schema errors were found.\n" + ) + + msg += "\nError Counts" + msg += "\n------------\n" + for k, v in error_counts.items(): + msg += f"- {k}: {v}\n" + + def agg_failure_cases(df): + # Note: hack to support unhashable types, proper solution that only transforms + # when requires https://github.com/unionai-oss/pandera/issues/260 + df.failure_case = df.failure_case.astype(str) + # NOTE: this is a hack to add modin support + if type(df).__module__.startswith("modin.pandas"): + return ( + df.groupby(["schema_context", "column", "check"]) + .agg({"failure_case": "unique"}) + .failure_case + ) + return df.groupby( + ["schema_context", "column", "check"] + ).failure_case.unique() + + summarized_failure_cases = ( + failure_cases.fillna({"column": ""}) + .pipe(agg_failure_cases) + .rename("failure_cases") + .to_frame() + .assign(n_failure_cases=lambda df: df.failure_cases.map(len)) + ) + index_labels = [ + summarized_failure_cases.index.names.index(name) + for name in ["schema_context", "column"] + ] + summarized_failure_cases = summarized_failure_cases.sort_index( + level=index_labels, + ascending=[False, True], + ) + msg += "\nSchema Error Summary" + msg += "\n--------------------\n" + with pd.option_context("display.max_colwidth", 100): + msg += summarized_failure_cases.to_string() + msg += SCHEMA_ERRORS_SUFFIX + return msg, error_counts diff --git a/pandera/backends/pandas/hypotheses.py b/pandera/backends/pandas/hypotheses.py index 3108ea9b6..6030195a3 100644 --- a/pandera/backends/pandas/hypotheses.py +++ b/pandera/backends/pandas/hypotheses.py @@ -92,10 +92,14 @@ def _hypothesis_check(self, check_obj): :param check_obj: object to validate. 
""" if is_field(check_obj): - return self.relationship(*self.check._check_fn(check_obj)) + return self.relationship( + *self.check._check_fn(check_obj, **self.check._check_kwargs) + ) _check_obj = [check_obj.get(s) for s in self.check.samples] - return self.relationship(*self.check._check_fn(*_check_obj)) + return self.relationship( + *self.check._check_fn(*_check_obj, **self.check._check_kwargs) + ) @property def is_one_sample_test(self): diff --git a/pandera/core/__init__.py b/pandera/core/__init__.py index fdd99fad0..4a0454729 100644 --- a/pandera/core/__init__.py +++ b/pandera/core/__init__.py @@ -2,42 +2,3 @@ This module contains the schema specifications for all supported data objects. """ - -import pandas as pd - -import pandera.typing -from pandera.backends.pandas.checks import PandasCheckBackend -from pandera.backends.pandas.hypotheses import PandasHypothesisBackend - -from pandera.core.checks import Check -from pandera.core.hypotheses import Hypothesis -from pandera.core.pandas import checks, hypotheses -from pandera.core.pandas.array import SeriesSchema -from pandera.core.pandas.components import Column, Index, MultiIndex -from pandera.core.pandas.container import DataFrameSchema - -data_types = [pd.DataFrame, pd.Series] - -if pandera.typing.dask.DASK_INSTALLED: - import dask.dataframe as dd - - data_types.extend([dd.DataFrame, dd.Series]) - -if pandera.typing.modin.MODIN_INSTALLED: - import modin.pandas as mpd - - data_types.extend([mpd.DataFrame, mpd.Series]) - -if pandera.typing.pyspark.PYSPARK_INSTALLED: - import pyspark.pandas as ps - - data_types.extend([ps.DataFrame, ps.Series]) - -if pandera.typing.geopandas.GEOPANDAS_INSTALLED: - import geopandas as gpd - - data_types.extend([gpd.GeoDataFrame, gpd.GeoSeries]) - -for t in data_types: - Check.register_backend(t, PandasCheckBackend) - Hypothesis.register_backend(t, PandasHypothesisBackend) diff --git a/pandera/core/base/checks.py b/pandera/core/base/checks.py index b3e0c5078..f534cb42c 100644 --- a/pandera/core/base/checks.py +++ b/pandera/core/base/checks.py @@ -1,8 +1,6 @@ """Data validation base check.""" -import inspect from collections import namedtuple -from functools import wraps from itertools import chain from typing import ( Any, @@ -16,8 +14,8 @@ Union, no_type_check, ) - import pandas as pd +from multimethod import multidispatch as _multidispatch from pandera.backends.base import BaseCheckBackend @@ -36,30 +34,44 @@ DataFrameCheckObj = Union[pd.DataFrame, Dict[str, pd.DataFrame]] -def register_check_statistics(statistics_args): - """Decorator to set statistics based on Check method.""" - - def register_check_statistics_decorator(class_method): - @wraps(class_method) - def _wrapper(cls, *args, **kwargs): - args = list(args) - arg_names = inspect.getfullargspec(class_method).args[1:] - if not arg_names: - arg_names = statistics_args - args_dict = {**dict(zip(arg_names, args)), **kwargs} - check = class_method(cls, *args, **kwargs) - check.statistics = { - stat: args_dict.get(stat) for stat in statistics_args - } - check.statistics_args = statistics_args - return check - - return _wrapper +_T = TypeVar("_T", bound="BaseCheck") - return register_check_statistics_decorator +# pylint: disable=invalid-name +class multidispatch(_multidispatch): + """ + Custom multidispatch class to handle copy, deepcopy, and code retrieval. 
+ """ + + @property + def __code__(self): + """Retrieves the 'base' function of the multidispatch object.""" + assert ( + len(self) > 0 + ), f"multidispatch object {self} has no functions registered" + fn, *_ = [*self.values()] # type: ignore[misc] + return fn.__code__ + + def __reduce__(self): + """ + Handle custom pickling reduction method by initializing a new + multidispatch object, wrapped with the base function. + """ + state = self.__dict__ + # make sure all registered functions at time of pickling are captured + state["__registered_functions__"] = [*self.values()] + return ( + multidispatch, # object creation function + (state["__wrapped__"],), # arguments to said function + state, # arguments to `__setstate__` after creation + ) -_T = TypeVar("_T", bound="BaseCheck") + def __setstate__(self, state): + """Custom unpickling logic.""" + self.__dict__ = state + # rehydrate the multidispatch object with unpickled registered functions + for fn in state["__registered_functions__"]: + self.register(fn) class MetaCheck(type): # pragma: no cover @@ -68,15 +80,19 @@ class MetaCheck(type): # pragma: no cover BACKEND_REGISTRY: Dict[ Tuple[Type, Type], Type[BaseCheckBackend] ] = {} # noqa + """Registry of check backends implemented for specific data objects.""" + CHECK_FUNCTION_REGISTRY: Dict[str, Callable] = {} # noqa - CHECK_REGISTRY: Dict[str, Callable] = {} # noqa + """Built-in check function registry.""" + REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {} # noqa + """User-defined custom checks.""" def __getattr__(cls, name: str) -> Any: """Prevent attribute errors for registered checks.""" attr = { **cls.__dict__, - **cls.CHECK_REGISTRY, + **cls.CHECK_FUNCTION_REGISTRY, **cls.REGISTERED_CUSTOM_CHECKS, }.get(name) if attr is None: @@ -91,7 +107,7 @@ def __dir__(cls) -> Iterable[str]: """Allow custom checks to show up as attributes when autocompleting.""" return chain( super().__dir__(), - cls.CHECK_REGISTRY.keys(), + cls.CHECK_FUNCTION_REGISTRY.keys(), cls.REGISTERED_CUSTOM_CHECKS.keys(), ) @@ -117,9 +133,47 @@ def __init__( self, name: Optional[str] = None, error: Optional[str] = None, + statistics: Optional[Dict[str, Any]] = None, ): self.name = name self.error = error + self.statistics = statistics + + @classmethod + def register_builtin_check_fn(cls, fn: Callable): + """Registers a built-in check function""" + cls.CHECK_FUNCTION_REGISTRY[fn.__name__] = multidispatch(fn) + return fn + + @classmethod + def get_builtin_check_fn(cls, name: str): + """Gets a built-in check function""" + return cls.CHECK_FUNCTION_REGISTRY[name] + + @classmethod + def from_builtin_check_name( + cls, + name: str, + init_kwargs, + error: Union[str, Callable], + statistics: Dict[str, Any] = None, + **check_kwargs, + ): + """Create a Check object from a built-in check's name.""" + kws = {**init_kwargs, **check_kwargs} + if "error" not in kws: + kws["error"] = error + + # statistics are the raw check constraint values that are untransformed + # by the check object + if statistics is None: + statistics = check_kwargs + + return cls( + cls.get_builtin_check_fn(name), + statistics=statistics, + **kws, + ) @classmethod def register_backend(cls, type_: Type, backend: Type[BaseCheckBackend]): diff --git a/pandera/core/checks.py b/pandera/core/checks.py index 586574ef0..932fa87a1 100644 --- a/pandera/core/checks.py +++ b/pandera/core/checks.py @@ -1,6 +1,16 @@ """Data validation check definition.""" -from typing import Any, Callable, Dict, List, Optional, Union +import re +from typing import ( + Any, + Callable, + Dict, 
+ Iterable, + List, + Optional, + TypeVar, + Union, +) import pandas as pd @@ -9,8 +19,12 @@ from pandera.strategies import SearchStrategy +T = TypeVar("T") + + +# pylint: disable=too-many-public-methods class Check(BaseCheck): - """Check a pandas Series or DataFrame for certain properties.""" + """Check a data object for certain properties.""" def __init__( self, @@ -29,7 +43,7 @@ def __init__( strategy: Optional[SearchStrategy] = None, **check_kwargs, ) -> None: - """Apply a validation function to each element, Series, or DataFrame. + """Apply a validation function to a data object. :param check_fn: A function to check pandas data structure. For Column or SeriesSchema checks, if element_wise is True, this function @@ -214,3 +228,377 @@ def __call__( """ backend = self.get_backend(check_obj)(self) return backend(check_obj, column) + + @classmethod + def equal_to(cls, value: Any, **kwargs) -> "Check": + """Ensure all elements of a data container equal a certain value. + + :param value: values in this pandas data structure must be + equal to this value. + """ + return cls.from_builtin_check_name( + "equal_to", + kwargs, + error=f"equal_to({value})", + value=value, + ) + + @classmethod + def not_equal_to(cls, value: Any, **kwargs) -> "Check": + """Ensure no elements of a data container equals a certain value. + + :param value: This value must not occur in the checked + :class:`pandas.Series`. + """ + return cls.from_builtin_check_name( + "not_equal_to", + kwargs, + error=f"not_equal_to({value})", + value=value, + ) + + @classmethod + def greater_than(cls, min_value: Any, **kwargs) -> "Check": + """ + Ensure values of a data container are strictly greater than a minimum + value. + + :param min_value: Lower bound to be exceeded. Must be a type comparable + to the dtype of the :class:`pandas.Series` to be validated (e.g. a + numerical type for float or int and a datetime for datetime). + """ + if min_value is None: + raise ValueError("min_value must not be None") + return cls.from_builtin_check_name( + "greater_than", + kwargs, + error=f"greater_than({min_value})", + min_value=min_value, + ) + + @classmethod + def greater_than_or_equal_to(cls, min_value: Any, **kwargs) -> "Check": + """Ensure all values are greater or equal a certain value. + + :param min_value: Allowed minimum value for values of a series. Must be + a type comparable to the dtype of the :class:`pandas.Series` to be + validated. + """ + if min_value is None: + raise ValueError("min_value must not be None") + return cls.from_builtin_check_name( + "greater_than_or_equal_to", + kwargs, + error=f"greater_than_or_equal_to({min_value})", + min_value=min_value, + ) + + @classmethod + def less_than(cls, max_value: Any, **kwargs) -> "Check": + """Ensure values of a series are strictly below a maximum value. + + :param max_value: All elements of a series must be strictly smaller + than this. Must be a type comparable to the dtype of the + :class:`pandas.Series` to be validated. + """ + if max_value is None: + raise ValueError("max_value must not be None") + return cls.from_builtin_check_name( + "less_than", + kwargs, + error=f"less_than({max_value})", + max_value=max_value, + ) + + @classmethod + def less_than_or_equal_to(cls, max_value: Any, **kwargs) -> "Check": + """Ensure values of a series are strictly below a maximum value. + + :param max_value: Upper bound not to be exceeded. Must be a type + comparable to the dtype of the :class:`pandas.Series` to be + validated. 
+ """ + if max_value is None: + raise ValueError("max_value must not be None") + return cls.from_builtin_check_name( + "less_than_or_equal_to", + kwargs, + error=f"less_than_or_equal_to({max_value})", + max_value=max_value, + ) + + @classmethod + def in_range( + cls, + min_value: T, + max_value: T, + include_min: bool = True, + include_max: bool = True, + **kwargs, + ) -> "Check": + """Ensure all values of a series are within an interval. + + Both endpoints must be a type comparable to the dtype of the + data object to be validated. + + :param min_value: Left / lower endpoint of the interval. + :param max_value: Right / upper endpoint of the interval. Must not be + smaller than min_value. + :param include_min: Defines whether min_value is also an allowed value + (the default) or whether all values must be strictly greater than + min_value. + :param include_max: Defines whether min_value is also an allowed value + (the default) or whether all values must be strictly smaller than + max_value. + """ + if min_value is None: + raise ValueError("min_value must not be None") + if max_value is None: + raise ValueError("max_value must not be None") + if max_value < min_value or ( # type: ignore + min_value == max_value and (not include_min or not include_max) + ): + raise ValueError( + f"The combination of min_value = {min_value} and " + f"max_value = {max_value} defines an empty interval!" + ) + return cls.from_builtin_check_name( + "in_range", + kwargs, + error=f"in_range({min_value}, {max_value})", + min_value=min_value, + max_value=max_value, + include_min=include_min, + include_max=include_max, + ) + + @classmethod + def isin(cls, allowed_values: Iterable, **kwargs) -> "Check": + """Ensure only allowed values occur within a series. + + This checks whether all elements of a data object + are part of the set of elements of allowed values. If allowed + values is a string, the set of elements consists of all distinct + characters of the string. Thus only single characters which occur + in allowed_values at least once can meet this condition. If you + want to check for substrings use :meth:`Check.str_contains`. + + :param allowed_values: The set of allowed values. May be any iterable. + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + try: + allowed_values_mod = frozenset(allowed_values) + except TypeError as exc: + raise ValueError( + f"Argument allowed_values must be iterable. Got {allowed_values}" + ) from exc + return cls.from_builtin_check_name( + "isin", + kwargs, + error=f"isin({allowed_values})", + statistics={"allowed_values": allowed_values}, + allowed_values=allowed_values_mod, + ) + + @classmethod + def notin(cls, forbidden_values: Iterable, **kwargs) -> "Check": + """Ensure some defined values don't occur within a series. + + Like :meth:`Check.isin` this check operates on single characters if + it is applied on strings. If forbidden_values is a string, it is + understood as set of prohibited characters. Any string of length > 1 + can't be in it by design. + + :param forbidden_values: The set of values which should not occur. May + be any iterable. + :param raise_warning: if True, check raises UserWarning instead of + SchemaError on validation. + """ + try: + forbidden_values_mod = frozenset(forbidden_values) + except TypeError as exc: + raise ValueError( + "Argument forbidden_values must be iterable. 
" + f"Got {forbidden_values}" + ) from exc + return cls.from_builtin_check_name( + "notin", + kwargs, + error=f"notin({forbidden_values})", + statistics={"forbidden_values": forbidden_values}, + forbidden_values=forbidden_values_mod, + ) + + @classmethod + def str_matches(cls, pattern: Union[str, re.Pattern], **kwargs) -> "Check": + """Ensure that string values match a regular expression. + + :param pattern: Regular expression pattern to use for matching + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + try: + pattern_mod = re.compile(pattern) + except TypeError as exc: + raise ValueError( + f'pattern="{pattern}" cannot be compiled as regular expression' + ) from exc + return cls.from_builtin_check_name( + "str_matches", + kwargs, + error=f"str_matches('{pattern}')", + statistics={"pattern": pattern}, + pattern=pattern_mod, + ) + + @classmethod + def str_contains( + cls, pattern: Union[str, re.Pattern], **kwargs + ) -> "Check": + """Ensure that a pattern can be found within each row. + + :param pattern: Regular expression pattern to use for searching + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + try: + pattern_mod = re.compile(pattern) + except TypeError as exc: + raise ValueError( + f'pattern="{pattern}" cannot be compiled as regular expression' + ) from exc + return cls.from_builtin_check_name( + "str_contains", + kwargs, + error=f"str_contains('{pattern}')", + statistics={"pattern": pattern}, + pattern=pattern_mod, + ) + + @classmethod + def str_startswith(cls, string: str, **kwargs) -> "Check": + """Ensure that all values start with a certain string. + + :param string: String all values should start with + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + return cls.from_builtin_check_name( + "str_startswith", + kwargs, + error=f"str_startswith('{string}')", + string=string, + ) + + @classmethod + def str_endswith(cls, string: str, **kwargs) -> "Check": + """Ensure that all values end with a certain string. + + :param string: String all values should end with + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + return cls.from_builtin_check_name( + "str_endswith", + kwargs, + error=f"str_endswith('{string}')", + string=string, + ) + + @classmethod + def str_length( + cls, + min_value: int = None, + max_value: int = None, + **kwargs, + ) -> "Check": + """Ensure that the length of strings is within a specified range. + + :param min_value: Minimum length of strings (default: no minimum) + :param max_value: Maximum length of strings (default: no maximum) + """ + if min_value is None and max_value is None: + raise ValueError( + "At least a minimum or a maximum need to be specified. Got " + "None." + ) + return cls.from_builtin_check_name( + "str_length", + kwargs, + error=f"str_length({min_value}, {max_value})", + min_value=min_value, + max_value=max_value, + ) + + @classmethod + def unique_values_eq(cls, values: str, **kwargs) -> "Check": + """Ensure that unique values in the data object contain all values. + + .. note:: + In constrast with :func:`isin`, this check makes sure that all the + items in the ``values`` iterable are contained within the series. + + :param values: The set of values that must be present. Maybe any iterable. + """ + try: + values_mod = frozenset(values) + except TypeError as exc: + raise ValueError( + f"Argument values must be iterable. 
Got {values}" + ) from exc + return cls.from_builtin_check_name( + "unique_values_eq", + kwargs, + error=f"unique_values_eq({values})", + statistics={"values": values_mod}, + values=values_mod, + ) + + # Aliases + # ------- + + @classmethod + def eq(cls, value: Any, **kwargs) -> "Check": + """Alias of :meth:`~pandera.core.checks.Check.equal_to`""" + return cls.equal_to(value, **kwargs) + + @classmethod + def ne(cls, value: Any, **kwargs) -> "Check": + """Alias of :meth:`~pandera.core.checks.Check.not_equal_to`""" + return cls.not_equal_to(value, **kwargs) + + @classmethod + def gt(cls, min_value: Any, **kwargs) -> "Check": + """Alias of :meth:`~pandera.core.checks.Check.greater_than`""" + return cls.greater_than(min_value, **kwargs) + + @classmethod + def ge(cls, min_value: Any, **kwargs) -> "Check": + """ + Alias of :meth:`~pandera.core.checks.Check.greater_than_or_equal_to` + """ + return cls.greater_than_or_equal_to(min_value, **kwargs) + + @classmethod + def lt(cls, max_value: Any, **kwargs) -> "Check": + """Alias of :meth:`~pandera.core.checks.Check.less_than`""" + return cls.less_than(max_value, **kwargs) + + @classmethod + def le(cls, max_value: Any, **kwargs) -> "Check": + """Alias of :meth:`~pandera.core.checks.Check.less_than_or_equal_to`""" + return cls.less_than_or_equal_to(max_value, **kwargs) + + @classmethod + def between( + cls, + min_value: T, + max_value: T, + include_min: bool = True, + include_max: bool = True, + **kwargs, + ) -> "Check": + """Alias of :meth:`~pandera.core.checks.Check.in_range`""" + return cls.in_range( + min_value, + max_value, + include_min, + include_max, + **kwargs, + ) diff --git a/pandera/core/extensions.py b/pandera/core/extensions.py index f440971b3..0d18d55e2 100644 --- a/pandera/core/extensions.py +++ b/pandera/core/extensions.py @@ -1,28 +1,33 @@ """Extensions module.""" +import inspect import warnings from enum import Enum from functools import partial, wraps -from inspect import signature, Parameter, Signature, _empty -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from inspect import signature +from typing import Callable, List, Optional, Tuple, Type, Union import pandas as pd -from multimethod import multidispatch +import typing_inspect -from pandera.core.base.checks import register_check_statistics from pandera.core.checks import Check from pandera.core.hypotheses import Hypothesis +from pandera.strategies.base_strategies import STRATEGY_DISPATCHER + + +class BuiltinCheckRegistrationError(Exception): + """ + Exception raised when registering a built-in check implementation but the + default check function implementation hasn't been registered with + :py:meth:`~flytekit.core.base.BaseCheck.register_builtin_check_fn`. + """ # pylint: disable=too-many-locals -def register_check( +def register_builtin_check( fn=None, - pre_init_hook: Optional[Callable] = None, - aliases: Optional[List[str]] = None, strategy: Optional[Callable] = None, - error: Optional[Union[str, Callable]] = None, - check_cls: Type = Check, - samples_kwtypes: Optional[Dict[str, Type]] = None, + _check_cls: Type = Check, **outer_kwargs, ): """Register a check method to the Check namespace. 
@@ -33,249 +38,62 @@ def register_check( if fn is None: return partial( - register_check, - pre_init_hook=pre_init_hook, - aliases=aliases, + register_builtin_check, strategy=strategy, - error=error, - check_cls=check_cls, - samples_kwtypes=samples_kwtypes, + _check_cls=_check_cls, **outer_kwargs, ) name = fn.__name__ # see if the check function is already registered - check_fn = check_cls.CHECK_FUNCTION_REGISTRY.get(name) - + check_fn = _check_cls.CHECK_FUNCTION_REGISTRY.get(name) fn_sig = signature(fn) - # this is a special case for handling hypotheses, since the sample keys - # need to be treated like statistics and is used during preprocessing, not - # in the check function itself. - if samples_kwtypes is None: - samples_args = [] - samples_params = [] - else: - samples_args = [*samples_kwtypes] - samples_params = [ - Parameter( - name, - Parameter.POSITIONAL_OR_KEYWORD, - annotation=samples_kwtypes[name], - ) - for name in samples_kwtypes - ] - - # derive statistics from function arguments after the 0th positional arg - statistics = [*samples_args, *[*fn_sig.parameters.keys()][1:]] - statistics_params = [ - *samples_params, - *[*fn_sig.parameters.values()][1:], - ] - statistics_defaults = { - p.name: p.default - for p in fn_sig.parameters.values() - if p.default is not _empty - } - - if check_fn is None: + # register the check strategy for this particular check, identified + # by the check `name`, and the data type of the check function. This + # supports Union types. Also assume that the data type of the data + # object to validate is the first argument. + data_type = [*fn_sig.parameters.values()][0].annotation - dispatch_check_fn = multidispatch(fn) + if typing_inspect.get_origin(data_type) is Tuple: + data_type, *_ = typing_inspect.get_args(data_type) - # create proxy function so we can modify the signature and docstring - # of the method to reflect correctly in the documentation - # pylint: disable=unused-argument - def check_function_proxy(cls, *args, **kws): - return dispatch_check_fn(*args, **kws) - - update_check_fn_proxy( - check_cls, check_function_proxy, fn, fn_sig, statistics_params + if typing_inspect.get_origin(data_type) is Union: + data_types = typing_inspect.get_args(data_type) + else: + data_types = (data_type,) + + if strategy is not None: + for dt in data_types: + STRATEGY_DISPATCHER[(name, dt)] = strategy + + if check_fn is None: # pragma: no cover + raise BuiltinCheckRegistrationError( + f"Check '{name}' doesn't have a base check implementation. " + f"You need to create a stub method in the {_check_cls} class and " + "then register a base check function implementation with the " + f"{_check_cls}.register_builtin_check_fn method.\n" + "See the `pandera.core.base.builtin_checks` and " + "`pandera.backends.pandas.builtin_checks` modules as an example." 
) - @wraps(check_function_proxy) - def check_method(cls, *args, **check_kwargs): - args = list(args) - - statistics_kwargs = dict(zip(statistics, args)) - for stat in statistics: - if stat in check_kwargs: - statistics_kwargs[stat] = check_kwargs.pop(stat) - elif stat not in statistics_kwargs: - statistics_kwargs[stat] = statistics_defaults.get( - stat, None - ) - - _check_kwargs = { - "error": ( - error.format(**statistics_kwargs) - if isinstance(error, str) - else error(**statistics_kwargs) - ) - } - _check_kwargs.update(outer_kwargs) - _check_kwargs.update(check_kwargs) - - # this is a special case for handling hypotheses, since the sample - # keys need to be treated like statistics and is used during - # preprocessing, not in the check function itself. - if samples_kwtypes is not None: - samples = [] - for sample_arg in samples_kwtypes: - samples.append(statistics_kwargs.pop(sample_arg)) - _check_kwargs["samples"] = samples - - # This is kind of ugly... basically we're creating another - # stats kwargs variable that's actually used when invoking the check - # function (which may or may not be processed by pre_init_hook) - # This is so that the original value is not modified by - # pre_init_hook when, for e.g. the check is serialized with the io - # module. Figure out a better way to do this! - check_fn_stats_kwargs = ( - pre_init_hook(statistics_kwargs) - if pre_init_hook is not None - else statistics_kwargs - ) - - # internal wrapper is needed here to make sure the inner check_fn - # produced by this method is consistent with the registered check - # function - if check_cls is Check: - - @wraps(fn) - def _check_fn(check_obj, **inner_kwargs): - """ - inner_kwargs will be based in via Check.__init__ kwargs. - """ - # Raise more informative error when this fails to dispatch - return check_function_proxy( - cls, - check_obj, - *check_fn_stats_kwargs.values(), - **inner_kwargs, - ) - - elif check_cls is Hypothesis: - - @wraps(fn) - def _check_fn(*samples, **inner_kwargs): - """ - inner_kwargs will be based in via Check.__init__ kwargs. 
- """ - # Raise more informative error when this fails to dispatch - return check_function_proxy( - cls, - *samples, - **{ - **check_fn_stats_kwargs, - **inner_kwargs, - }, - ) - - else: - raise TypeError(f"check_cls {check_cls} not recognized") - - return cls( - _check_fn, - statistics=statistics_kwargs, - strategy=( - None - if strategy is None - else partial(strategy, **statistics_kwargs) - ), - **_check_kwargs, - ) - - check_cls.CHECK_FUNCTION_REGISTRY[name] = dispatch_check_fn - setattr(check_cls, name, classmethod(check_method)) - - class_check_method = getattr(check_cls, name) + check_fn.register(fn) # type: ignore - for _name in [] if aliases is None else aliases: - setattr(check_cls, _name, class_check_method) - else: - check_fn.register(fn) # type: ignore - - return getattr(check_cls, name) + return fn -def register_hypothesis(samples_kwtypes=None, **kwargs): +def register_builtin_hypothesis(**kwargs): """Register a new hypothesis.""" return partial( - register_check, - check_cls=Hypothesis, - samples_kwtypes=samples_kwtypes, + register_builtin_check, + _check_cls=Hypothesis, **kwargs, ) -def generate_check_signature( - check_cls: Type, - sig: Signature, - statistics_params: List[Parameter], -) -> Signature: - """Generates a check signature from check statistics.""" - # assume the first argument is the check object - return sig.replace( - parameters=[ - # This first parameter will be ignored since it's the check object - Parameter("_", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("cls", Parameter.POSITIONAL_OR_KEYWORD), - *statistics_params, - Parameter( - "kwargs", Parameter.VAR_KEYWORD, annotation=Dict[str, Any] - ), - ], - return_annotation=check_cls, - ) - - -def generate_check_annotations( - check_cls: Type, - statistics_params: List[Parameter], -) -> Dict[str, Type]: - """Generates a check type annotations from check statistics.""" - return { - **{p.name: p.annotation for p in statistics_params}, - "kwargs": Dict[ - str, - Any, - ], - "return": check_cls, - } - - -def modify_check_fn_doc(doc: str) -> str: - """Adds""" - return ( - f"{doc}\n{' ' * 4}:param kwargs: arguments forwarded to the " - ":py:class:`~pandera.core.checks.Check` constructor." - ) - - -def update_check_fn_proxy( - check_cls: Type, check_function_proxy, fn, fn_sig, statistics_params -): - """ - Manually update the signature of `check_function` so that docstring matches - original function's signature, but includes ``**kwargs``, etc. - """ - check_function_proxy.__name__ = fn.__name__ - check_function_proxy.__module__ = fn.__module__ - check_function_proxy.__qualname__ = fn.__qualname__ - check_function_proxy.__signature__ = generate_check_signature( - check_cls, - fn_sig, - statistics_params, - ) - check_function_proxy.__doc__ = modify_check_fn_doc(fn.__doc__) - check_function_proxy.__annotations__ = generate_check_annotations( - check_cls, statistics_params - ) - - # -------------------------------- -# LEGACY CHECK REGISTRATION METHOD +# CUSTOM CHECK REGISTRATION METHOD # -------------------------------- # # The `register_check_method` decorator is the legacy method for registering @@ -291,6 +109,29 @@ class CheckType(Enum): GROUPBY = 3 #: Check applied to dictionary of Series or DataFrames. 
+def register_check_statistics(statistics_args): + """Decorator to set statistics based on Check method.""" + + def register_check_statistics_decorator(class_method): + @wraps(class_method) + def _wrapper(cls, *args, **kwargs): + args = list(args) + arg_names = inspect.getfullargspec(class_method).args[1:] + if not arg_names: + arg_names = statistics_args + args_dict = {**dict(zip(arg_names, args)), **kwargs} + check = class_method(cls, *args, **kwargs) + check.statistics = { + stat: args_dict.get(stat) for stat in statistics_args + } + check.statistics_args = statistics_args + return check + + return _wrapper + + return register_check_statistics_decorator + + def register_check_method( check_fn=None, *, diff --git a/pandera/core/hypotheses.py b/pandera/core/hypotheses.py index 5ad8a620b..a2e83de12 100644 --- a/pandera/core/hypotheses.py +++ b/pandera/core/hypotheses.py @@ -1,6 +1,5 @@ """Data validation checks for hypothesis testing.""" -from functools import partial, update_wrapper from typing import Any, Callable, Dict, List, Optional, Union from pandera import errors @@ -8,6 +7,9 @@ from pandera.strategies import SearchStrategy +DEFAULT_ALPHA = 0.01 + + class Hypothesis(Check): """Special type of :class:`Check` that defines hypothesis tests on data.""" @@ -146,8 +148,8 @@ def __init__( f"The relationship {relationship} isn't a built in method" ) - self.test = partial(test, **{} if test_kwargs is None else test_kwargs) - update_wrapper(self.test, test) + check_kwargs = test_kwargs if test_kwargs is not None else check_kwargs + self.test = test self.relationship = relationship relationship_kwargs = relationship_kwargs or {} @@ -174,3 +176,205 @@ def __init__( strategy=strategy, **check_kwargs, ) + + @classmethod + def two_sample_ttest( + cls, + sample1: str, + sample2: str, + groupby: Optional[Union[str, List[str], Callable]] = None, + relationship: str = "equal", + alpha: float = DEFAULT_ALPHA, + equal_var: bool = True, + nan_policy: str = "propagate", + **kwargs, + ) -> "Hypothesis": + """Calculate a t-test for the means of two samples. + + Perform a two-sided test for the null hypothesis that 2 independent + samples have identical average (expected) values. This test assumes + that the populations have identical variances by default. + + :param sample1: The first sample group to test. For `Column` and + `SeriesSchema` hypotheses, refers to the level in the `groupby` + column. For `DataFrameSchema` hypotheses, refers to column in + the `DataFrame`. + :param sample2: The second sample group to test. For `Column` and + `SeriesSchema` hypotheses, refers to the level in the `groupby` + column. For `DataFrameSchema` hypotheses, refers to column in + the `DataFrame`. + :param groupby: If a string or list of strings is provided, then + these columns are used to group the Column Series by `groupby`. + If a callable is passed, the expected signature is + DataFrame -> DataFrameGroupby. The function has access to the + entire dataframe, but the Column.name is selected from this + DataFrameGroupby object so that a SeriesGroupBy object is passed + into `fn`. + + Specifying this argument changes the `fn` signature to: + dict[str|tuple[str], Series] -> bool|pd.Series[bool] + + Where specific groups can be obtained from the input dict. + :param relationship: Represents what relationship conditions are + imposed on the hypothesis test. Available relationships + are: "greater_than", "less_than", "not_equal", and "equal". 
+ For example, `group1 greater_than group2` specifies an alternative + hypothesis that the mean of group1 is greater than group 2 relative + to a null hypothesis that they are equal. + :param alpha: (Default value = 0.01) The significance level; the + probability of rejecting the null hypothesis when it is true. For + example, a significance level of 0.01 indicates a 1% risk of + concluding that a difference exists when there is no actual + difference. + :param equal_var: (Default value = True) If True (default), perform a + standard independent 2 sample test that assumes equal population + variances. If False, perform Welch's t-test, which does not + assume equal population variance + :param nan_policy: Defines how to handle when input returns nan, one of + {'propagate', 'raise', 'omit'}, (Default value = 'propagate'). + For more details see: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_ind.html + + :example: + + The the built-in class method to do a two-sample t-test. + + >>> import pandas as pd + >>> import pandera as pa + >>> + >>> + >>> schema = pa.DataFrameSchema({ + ... "height_in_feet": pa.Column( + ... float, [ + ... pa.Hypothesis.two_sample_ttest( + ... sample1="A", + ... sample2="B", + ... groupby="group", + ... relationship="greater_than", + ... alpha=0.05, + ... equal_var=True), + ... ]), + ... "group": pa.Column(str) + ... }) + >>> df = ( + ... pd.DataFrame({ + ... "height_in_feet": [8.1, 7, 5.2, 5.1, 4], + ... "group": ["A", "A", "B", "B", "B"] + ... }) + ... ) + >>> schema.validate(df)[["height_in_feet", "group"]] + height_in_feet group + 0 8.1 A + 1 7.0 A + 2 5.2 B + 3 5.1 B + 4 4.0 B + + """ + init_kwargs = { + "samples": [sample1, sample2], + "groupby": groupby, + "relationship": relationship, + "alpha": alpha, + } + init_kwargs.update(kwargs) + return cls.from_builtin_check_name( + "two_sample_ttest", + init_kwargs, + error=( + f"failed two sample ttest between '{sample1}' and '{sample2}'" + ), + equal_var=equal_var, + nan_policy=nan_policy, + ) + + @classmethod + def one_sample_ttest( + cls, + popmean: float, + sample: Optional[str] = None, + groupby: Optional[Union[str, List[str], Callable]] = None, + relationship: str = "equal", + alpha: float = DEFAULT_ALPHA, + nan_policy="propagate", + **kwargs, + ) -> "Hypothesis": + """Calculate a t-test for the mean of one sample. + + :param sample: The sample group to test. For `Column` and + `SeriesSchema` hypotheses, this refers to the `groupby` level that + is used to subset the `Column` being checked. For `DataFrameSchema` + hypotheses, refers to column in the `DataFrame`. + :param groupby: If a string or list of strings is provided, then these + columns are used to group the Column Series by `groupby`. If a + callable is passed, the expected signature is + DataFrame -> DataFrameGroupby. The function has access to the + entire dataframe, but the Column.name is selected from this + DataFrameGroupby object so that a SeriesGroupBy object is passed + into `fn`. + + Specifying this argument changes the `fn` signature to: + dict[str|tuple[str], Series] -> bool|pd.Series[bool] + + Where specific groups can be obtained from the input dict. + :param popmean: population mean to compare `sample` to. + :param relationship: Represents what relationship conditions are + imposed on the hypothesis test. Available relationships + are: "greater_than", "less_than", "not_equal" and "equal". 
For + example, `group1 greater_than group2` specifies an alternative + hypothesis that the mean of group1 is greater than group 2 relative + to a null hypothesis that they are equal. + :param alpha: (Default value = 0.01) The significance level; the + probability of rejecting the null hypothesis when it is true. For + example, a significance level of 0.01 indicates a 1% risk of + concluding that a difference exists when there is no actual + difference. + :param raise_warning: if True, check raises UserWarning instead of + SchemaError on validation. + + :example: + + If you want to compare one sample with a pre-defined mean: + + >>> import pandas as pd + >>> import pandera as pa + >>> + >>> + >>> schema = pa.DataFrameSchema({ + ... "height_in_feet": pa.Column( + ... float, [ + ... pa.Hypothesis.one_sample_ttest( + ... popmean=5, + ... relationship="greater_than", + ... alpha=0.1), + ... ]), + ... }) + >>> df = ( + ... pd.DataFrame({ + ... "height_in_feet": [8.1, 7, 6.5, 6.7, 5.1], + ... }) + ... ) + >>> schema.validate(df) + height_in_feet + 0 8.1 + 1 7.0 + 2 6.5 + 3 6.7 + 4 5.1 + + + """ + init_kwargs = { + "samples": sample, + "groupby": groupby, + "relationship": relationship, + "alpha": alpha, + } + init_kwargs.update(kwargs) + return cls.from_builtin_check_name( + "one_sample_ttest", + init_kwargs, + error=f"failed one sample ttest for column '{sample}'", + nan_policy=nan_policy, + popmean=popmean, + ) diff --git a/pandera/core/pandas/__init__.py b/pandera/core/pandas/__init__.py index 7a4f08705..3f4643a30 100644 --- a/pandera/core/pandas/__init__.py +++ b/pandera/core/pandas/__init__.py @@ -1,6 +1,5 @@ """Pandas core.""" -from pandera.core.pandas import checks from pandera.core.pandas.array import SeriesSchema from pandera.core.pandas.components import Column, Index, MultiIndex from pandera.core.pandas.container import DataFrameSchema diff --git a/pandera/core/pandas/hypotheses.py b/pandera/core/pandas/hypotheses.py deleted file mode 100644 index 78609fc22..000000000 --- a/pandera/core/pandas/hypotheses.py +++ /dev/null @@ -1,192 +0,0 @@ -"""Pandas implementation of built-in hypotheses.""" - -from typing import Tuple - -from pandera.backends.pandas.hypotheses import HAS_SCIPY -from pandera.core.extensions import register_hypothesis -from pandera.core.pandas.checks import PandasData - - -if HAS_SCIPY: - from scipy import stats - - -@register_hypothesis( - error="failed two sample ttest between '{sample1}' and '{sample2}'", - samples_kwtypes={"sample1": str, "sample2": str}, -) -def two_sample_ttest( - *samples: Tuple[PandasData, ...], - equal_var: bool = True, - nan_policy: str = "propagate", -) -> Tuple[float, float]: - """Calculate a t-test for the means of two samples. - - Perform a two-sided test for the null hypothesis that 2 independent - samples have identical average (expected) values. This test assumes - that the populations have identical variances by default. - - :param sample1: The first sample group to test. For `Column` and - `SeriesSchema` hypotheses, refers to the level in the `groupby` - column. For `DataFrameSchema` hypotheses, refers to column in - the `DataFrame`. - :param sample2: The second sample group to test. For `Column` and - `SeriesSchema` hypotheses, refers to the level in the `groupby` - column. For `DataFrameSchema` hypotheses, refers to column in - the `DataFrame`. - :param groupby: If a string or list of strings is provided, then these - columns are used to group the Column Series by `groupby`. 
If a - callable is passed, the expected signature is - DataFrame -> DataFrameGroupby. The function has access to the - entire dataframe, but the Column.name is selected from this - DataFrameGroupby object so that a SeriesGroupBy object is passed - into `fn`. - - Specifying this argument changes the `fn` signature to: - dict[str|tuple[str], Series] -> bool|pd.Series[bool] - - Where specific groups can be obtained from the input dict. - :param relationship: Represents what relationship conditions are - imposed on the hypothesis test. Available relationships - are: "greater_than", "less_than", "not_equal", and "equal". - For example, `group1 greater_than group2` specifies an alternative - hypothesis that the mean of group1 is greater than group 2 relative - to a null hypothesis that they are equal. - :param alpha: (Default value = 0.01) The significance level; the - probability of rejecting the null hypothesis when it is true. For - example, a significance level of 0.01 indicates a 1% risk of - concluding that a difference exists when there is no actual - difference. - :param equal_var: (Default value = True) If True (default), perform a - standard independent 2 sample test that assumes equal population - variances. If False, perform Welch's t-test, which does not - assume equal population variance - :param nan_policy: Defines how to handle when input returns nan, one of - {'propagate', 'raise', 'omit'}, (Default value = 'propagate'). - For more details see: - https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_ind.html - - :example: - - The the built-in class method to do a two-sample t-test. - - >>> import pandas as pd - >>> import pandera as pa - >>> - >>> - >>> schema = pa.DataFrameSchema({ - ... "height_in_feet": pa.Column( - ... float, [ - ... pa.Hypothesis.two_sample_ttest( - ... sample1="A", - ... sample2="B", - ... groupby="group", - ... relationship="greater_than", - ... alpha=0.05, - ... equal_var=True), - ... ]), - ... "group": pa.Column(str) - ... }) - >>> df = ( - ... pd.DataFrame({ - ... "height_in_feet": [8.1, 7, 5.2, 5.1, 4], - ... "group": ["A", "A", "B", "B", "B"] - ... }) - ... ) - >>> schema.validate(df)[["height_in_feet", "group"]] - height_in_feet group - 0 8.1 A - 1 7.0 A - 2 5.2 B - 3 5.1 B - 4 4.0 B - - """ - assert ( - len(samples) == 2 - ), "Expected two sample ttest data to contain exactly two samples" - return stats.ttest_ind( - samples[0], - samples[1], - equal_var=equal_var, - nan_policy=nan_policy, - ) - - -@register_hypothesis( - error="failed one sample ttest for column '{sample}'", - samples_kwtypes={"sample": str}, -) -def one_sample_ttest( - *samples: Tuple[PandasData, ...], - popmean: float, -) -> Tuple[float, float]: - """Calculate a t-test for the mean of one sample. - - :param sample: The sample group to test. For `Column` and - `SeriesSchema` hypotheses, this refers to the `groupby` level that - is used to subset the `Column` being checked. For `DataFrameSchema` - hypotheses, refers to column in the `DataFrame`. - :param groupby: If a string or list of strings is provided, then these - columns are used to group the Column Series by `groupby`. If a - callable is passed, the expected signature is - DataFrame -> DataFrameGroupby. The function has access to the - entire dataframe, but the Column.name is selected from this - DataFrameGroupby object so that a SeriesGroupBy object is passed - into `fn`. 
- - Specifying this argument changes the `fn` signature to: - dict[str|tuple[str], Series] -> bool|pd.Series[bool] - - Where specific groups can be obtained from the input dict. - :param popmean: population mean to compare `sample` to. - :param relationship: Represents what relationship conditions are - imposed on the hypothesis test. Available relationships - are: "greater_than", "less_than", "not_equal" and "equal". For - example, `group1 greater_than group2` specifies an alternative - hypothesis that the mean of group1 is greater than group 2 relative - to a null hypothesis that they are equal. - :param alpha: (Default value = 0.01) The significance level; the - probability of rejecting the null hypothesis when it is true. For - example, a significance level of 0.01 indicates a 1% risk of - concluding that a difference exists when there is no actual - difference. - :param raise_warning: if True, check raises UserWarning instead of - SchemaError on validation. - - :example: - - If you want to compare one sample with a pre-defined mean: - - >>> import pandas as pd - >>> import pandera as pa - >>> - >>> - >>> schema = pa.DataFrameSchema({ - ... "height_in_feet": pa.Column( - ... float, [ - ... pa.Hypothesis.one_sample_ttest( - ... popmean=5, - ... relationship="greater_than", - ... alpha=0.1), - ... ]), - ... }) - >>> df = ( - ... pd.DataFrame({ - ... "height_in_feet": [8.1, 7, 6.5, 6.7, 5.1], - ... }) - ... ) - >>> schema.validate(df) - height_in_feet - 0 8.1 - 1 7.0 - 2 6.5 - 3 6.7 - 4 5.1 - - - """ - assert ( - len(samples) == 1 - ), "Expected one sample ttest data to contain only one sample" - return stats.ttest_1samp(samples[0], popmean=popmean) diff --git a/pandera/decorators.py b/pandera/decorators.py index 869286dd0..904f0e101 100644 --- a/pandera/decorators.py +++ b/pandera/decorators.py @@ -20,7 +20,6 @@ overload, ) -import pandas as pd import wrapt from pydantic import validate_arguments @@ -82,7 +81,7 @@ def _handle_schema_error( decorator_name, fn: Callable, schema: Union[DataFrameSchema, SeriesSchema], - arg_df: pd.DataFrame, + data_obj: Any, schema_error: errors.SchemaError, ) -> NoReturn: """Reraise schema validation error with decorator context. @@ -95,7 +94,7 @@ def _handle_schema_error( checks. """ raise _parse_schema_error( - decorator_name, fn, schema, arg_df, schema_error + decorator_name, fn, schema, data_obj, schema_error ) from schema_error @@ -103,7 +102,7 @@ def _parse_schema_error( decorator_name, fn: Callable, schema: Union[DataFrameSchema, SeriesSchema], - arg_df: pd.DataFrame, + data_obj: Any, schema_error: errors.SchemaError, ) -> NoReturn: """Parse schema validation error with decorator context. 
@@ -121,7 +120,7 @@ def _parse_schema_error( msg = f"error in {decorator_name} decorator of function '{func_name}': {schema_error}" return errors.SchemaError( # type: ignore[misc] schema, - arg_df, + data_obj, msg, failure_cases=schema_error.failure_cases, check=schema_error.check, @@ -685,7 +684,9 @@ def _check_arg(arg_name: str, arg_value: Any) -> Any: if len(error_handler.collected_errors) == 1: raise error_handler.collected_errors[0]["error"] # type: ignore[misc] raise errors.SchemaErrors( - schema, error_handler.collected_errors, arg_value + schema=schema, + schema_errors=error_handler.collected_errors, + data=arg_value, ) sig = inspect.signature(wrapped) diff --git a/pandera/errors.py b/pandera/errors.py index 035ed50bd..3f9a21a94 100644 --- a/pandera/errors.py +++ b/pandera/errors.py @@ -1,21 +1,7 @@ """pandera-specific errors.""" import warnings -from collections import defaultdict, namedtuple -from typing import Any, Dict, List, Union - -import pandas as pd - -ErrorData = namedtuple( - "ErrorData", - [ - "data", - "error_counts", - "column_errors", - "type_errors", - "check_errors", - ], -) +from typing import Any, Dict, List, NamedTuple class ReducedPickleExceptionBase(Exception): @@ -25,7 +11,7 @@ class ReducedPickleExceptionBase(Exception): string via `TO_STRING_KEYS`. """ - TO_STRING_KEYS: List[str] + TO_STRING_KEYS: List[str] = [] def __reduce__(self): """Exception.__reduce__ is incompatible. Override with custom layout. @@ -133,6 +119,14 @@ class BaseStrategyOnlyError(Exception): """ +class FailureCaseMetadata(NamedTuple): + """Consolidated failure cases, summary message, and error counts.""" + + failure_cases: Any + message: str + error_counts: Dict[str, int] + + class SchemaErrors(ReducedPickleExceptionBase): """Raised when multiple schema are lazily collected into one error.""" @@ -146,152 +140,15 @@ def __init__( self, schema, schema_errors: List[Dict[str, Any]], - data: Union[pd.Series, pd.DataFrame, pd.Index, pd.MultiIndex], + data: Any, ): - error_counts, failure_cases = self._parse_schema_errors(schema_errors) self.schema = schema - super().__init__(self._message(error_counts, failure_cases)) self.schema_errors = schema_errors - self.error_counts = error_counts - self.failure_cases = failure_cases self.data = data - def _message(self, error_counts, schema_errors): - """Format error message.""" - msg = ( - f"Schema {self.schema.name}: A total of " - f"{sum(error_counts.values())} schema errors were found.\n" - ) - - msg += "\nError Counts" - msg += "\n------------\n" - for k, v in error_counts.items(): - msg += f"- {k}: {v}\n" - - def agg_failure_cases(df): - # Note: hack to support unhashable types, proper solution that only transforms - # when requires https://github.com/unionai-oss/pandera/issues/260 - df.failure_case = df.failure_case.astype(str) - # NOTE: this is a hack to add modin support - if type(df).__module__.startswith("modin.pandas"): - return ( - df.groupby(["schema_context", "column", "check"]) - .agg({"failure_case": "unique"}) - .failure_case - ) - return df.groupby( - ["schema_context", "column", "check"] - ).failure_case.unique() - - agg_schema_errors = ( - schema_errors.fillna({"column": ""}) - .pipe(agg_failure_cases) - .rename("failure_cases") - .to_frame() - .assign(n_failure_cases=lambda df: df.failure_cases.map(len)) - ) - index_labels = [ - agg_schema_errors.index.names.index(name) - for name in ["schema_context", "column"] - ] - agg_schema_errors = agg_schema_errors.sort_index( - level=index_labels, - ascending=[False, True], - ) - msg += 
"\nSchema Error Summary" - msg += "\n--------------------\n" - with pd.option_context("display.max_colwidth", 100): - msg += agg_schema_errors.to_string() - msg += SCHEMA_ERRORS_SUFFIX - return msg - - @staticmethod - def _parse_schema_errors(schema_errors: List[Dict[str, Any]]): - """Parse schema error dicts to produce data for error message.""" - error_counts = defaultdict(int) # type: ignore - check_failure_cases = [] - - column_order = [ - "schema_context", - "column", - "check", - "check_number", - "failure_case", - "index", - ] - - for schema_error_dict in schema_errors: - reason_code = schema_error_dict["reason_code"] - err = schema_error_dict["error"] - - error_counts[reason_code] += 1 - check_identifier = ( - None - if err.check is None - else err.check - if isinstance(err.check, str) - else err.check.error - if err.check.error is not None - else err.check.name - if err.check.name is not None - else str(err.check) - ) - - if err.failure_cases is not None: - if "column" in err.failure_cases: - column = err.failure_cases["column"] - else: - column = ( - err.schema.name - if reason_code == "schema_component_check" - else None - ) - - failure_cases = err.failure_cases.assign( - schema_context=err.schema.__class__.__name__, - check=check_identifier, - check_number=err.check_index, - # if the column key is a tuple (for MultiIndex column - # names), explicitly wrap `column` in a list of the - # same length as the number of failure cases. - column=( - [column] * err.failure_cases.shape[0] - if isinstance(column, tuple) - else column - ), - ) - check_failure_cases.append(failure_cases[column_order]) - - # NOTE: this is a hack to support pyspark.pandas and modin - concat_fn = pd.concat # type: ignore - if any( - type(x).__module__.startswith("pyspark.pandas") - for x in check_failure_cases - ): - # pylint: disable=import-outside-toplevel - import pyspark.pandas as ps - - concat_fn = ps.concat # type: ignore - check_failure_cases = [ - x if isinstance(x, ps.DataFrame) else ps.DataFrame(x) - for x in check_failure_cases - ] - elif any( - type(x).__module__.startswith("modin.pandas") - for x in check_failure_cases - ): - # pylint: disable=import-outside-toplevel - import modin.pandas as mpd - - concat_fn = mpd.concat # type: ignore - check_failure_cases = [ - x if isinstance(x, mpd.DataFrame) else mpd.DataFrame(x) - for x in check_failure_cases - ] - - failure_cases = ( - concat_fn(check_failure_cases) - .reset_index(drop=True) - .sort_values("schema_context", ascending=False) + failure_cases_metadata = schema.BACKEND.failure_cases_metadata( + schema.name, schema_errors ) - return error_counts, failure_cases + self.error_counts = failure_cases_metadata.error_counts + self.failure_cases = failure_cases_metadata.failure_cases + super().__init__(failure_cases_metadata.message) diff --git a/pandera/extensions.py b/pandera/extensions.py index a025920e0..b7802400e 100644 --- a/pandera/extensions.py +++ b/pandera/extensions.py @@ -2,12 +2,9 @@ # pylint: disable=unused-import from pandera.core.extensions import ( - register_check, - register_hypothesis, - generate_check_signature, - generate_check_annotations, - modify_check_fn_doc, - update_check_fn_proxy, + register_builtin_check, + register_builtin_hypothesis, CheckType, register_check_method, + register_check_statistics, ) diff --git a/pandera/strategies/base_strategies.py b/pandera/strategies/base_strategies.py new file mode 100644 index 000000000..9004beb94 --- /dev/null +++ b/pandera/strategies/base_strategies.py @@ -0,0 +1,8 @@ +"""Base module 
for `hypothesis`-based strategies for data synthesis.""" + +from typing import Callable, Dict, Tuple, Type + + +# This strategy registry maps (check_name, data_type) -> strategy_function +# For example: ("greater_than", pd.DataFrame) -> () +STRATEGY_DISPATCHER: Dict[Tuple[str, Type], Callable] = {} diff --git a/pandera/strategies/pandas_strategies.py b/pandera/strategies/pandas_strategies.py index 6a34241d2..205b861d6 100644 --- a/pandera/strategies/pandas_strategies.py +++ b/pandera/strategies/pandas_strategies.py @@ -41,6 +41,7 @@ ) from pandera.engines import numpy_engine, pandas_engine from pandera.errors import BaseStrategyOnlyError, SchemaDefinitionError +from pandera.strategies.base_strategies import STRATEGY_DISPATCHER try: import hypothesis @@ -775,8 +776,11 @@ def undefined_check_strategy(elements, check): ).filter(check._check_fn) for check in checks: - if check.strategy is not None: - elements = check.strategy(pandera_dtype, elements) + check_strategy = STRATEGY_DISPATCHER.get((check.name, pd.Series), None) + if check_strategy is not None: + elements = check_strategy( + pandera_dtype, elements, **check.statistics + ) elif check.element_wise: elements = undefined_check_strategy(elements, check) # NOTE: vectorized checks with undefined strategies should be handled @@ -847,7 +851,8 @@ def _check_fn(series): return strategy.filter(_check_fn) for check in checks if checks is not None else []: - if check.strategy is None and not check.element_wise: + check_strategy = STRATEGY_DISPATCHER.get((check.name, pd.Series), None) + if check_strategy is None and not check.element_wise: strategy = undefined_check_strategy(strategy, check) return strategy @@ -1004,8 +1009,13 @@ def _dataframe_check_fn(dataframe): def make_row_strategy(col, checks): strategy = None for check in checks: - if check.strategy is not None: - strategy = check.strategy(col.dtype, strategy) + check_strategy = STRATEGY_DISPATCHER.get( + (check.name, pd.DataFrame), None + ) + if check_strategy is not None: + strategy = check_strategy( + col.dtype, strategy, **check.statistics + ) else: strategy = undefined_check_strategy( strategy=( @@ -1024,7 +1034,10 @@ def _dataframe_strategy(draw): row_strategy_checks = [] undefined_strat_df_checks = [] for check in checks: - if check.strategy is not None or check.element_wise: + check_strategy = STRATEGY_DISPATCHER.get( + (check.name, pd.DataFrame) + ) + if check_strategy is not None or check.element_wise: # we can apply element-wise checks defined at the dataframe # level to the row strategy row_strategy_checks.append(check) @@ -1066,7 +1079,8 @@ def _dataframe_strategy(draw): undefined_strat_column_checks[col_name].extend( check for check in column.checks - if check.strategy is None and not check.element_wise + if STRATEGY_DISPATCHER.get((check.name, pd.DataFrame)) is None + and not check.element_wise ) # override the column datatype with dataframe-level datatype if diff --git a/tests/core/test_checks_builtin.py b/tests/core/test_checks_builtin.py index e97ece250..84e74ec8c 100644 --- a/tests/core/test_checks_builtin.py +++ b/tests/core/test_checks_builtin.py @@ -1,6 +1,7 @@ """Tests for builtin checks in pandera.core.checks.Check """ +import pickle from typing import Iterable import pandas as pd @@ -1034,3 +1035,23 @@ def test_unique_values_eq_raise_error(values): """Test that unique_values_eq raises an error arg is not iterable.""" with pytest.raises((TypeError, ValueError)): Check.unique_values_eq(values) + + +def test_check_pickling(tmp_path): + """Test that built-in checks can 
be pickled/unpickled correctly.""" + check = Check.gt(0) + valid_data = pd.Series([1, 2, 3]) + invalid_data = valid_data * -1 + + fp = tmp_path / "check.pickle" + with fp.open("wb") as f: + pickle.dump(check, f) + + with fp.open("rb") as f: + loaded_check = pickle.load(f) + + assert check == loaded_check + assert check(valid_data).check_passed + assert loaded_check(valid_data).check_passed + assert not check(invalid_data).check_passed + assert not loaded_check(invalid_data).check_passed diff --git a/tests/core/test_engine.py b/tests/core/test_engine.py index 6e38121c0..ddc450cab 100644 --- a/tests/core/test_engine.py +++ b/tests/core/test_engine.py @@ -32,7 +32,7 @@ def equivalents() -> List[Any]: @pytest.fixture def engine() -> Generator[Engine, None, None]: class FakeEngine( # pylint:disable=too-few-public-methods - metaclass=Engine, base_pandera_dtypes=BaseDataType + metaclass=Engine, base_pandera_dtypes=BaseDataType # type: ignore[call-arg] ): pass diff --git a/tests/core/test_extension_modules.py b/tests/core/test_extension_modules.py index 2241ff842..3b55063f5 100644 --- a/tests/core/test_extension_modules.py +++ b/tests/core/test_extension_modules.py @@ -10,7 +10,7 @@ def test_hypotheses_module_import() -> None: """Test that Hypothesis built-in methods raise import error.""" if not HAS_SCIPY: for fn in [ - lambda: Hypothesis.two_sample_ttest("sample1", "sample2"), + lambda: Hypothesis.two_sample_ttest("sample1", "sample2"), # type: ignore[arg-type] lambda: Hypothesis.one_sample_ttest(popmean=10), ]: with pytest.raises(ImportError): diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 10ecb892a..609a1b2dc 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -130,7 +130,7 @@ class SchemaWithAliasDtype(pa.DataFrameModel): with pytest.raises( pa.errors.SchemaInitError, match="Index 'idx' cannot be Optional." ): - model.to_schema() + model.to_schema() # type: ignore[attr-defined] def test_empty_dtype() -> None: diff --git a/tests/strategies/test_strategies.py b/tests/strategies/test_strategies.py index 8de322c38..f0e4f5731 100644 --- a/tests/strategies/test_strategies.py +++ b/tests/strategies/test_strategies.py @@ -13,7 +13,7 @@ import pandera as pa from pandera import strategies from pandera.core.checks import Check -from pandera.core.base.checks import register_check_statistics +from pandera.core.extensions import register_check_statistics from pandera.dtypes import is_category, is_complex, is_float from pandera.engines import pandas_engine
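
A minimal sketch (not part of the patch above) of how the new STRATEGY_DISPATCHER registry in pandera/strategies/base_strategies.py is populated and consumed. The function name gt_series_strategy is hypothetical; only the (check_name, data_type) key convention and the call shape strategy_fn(dtype, elements, **check.statistics) are taken from the pandas_strategies.py changes, and the sketch assumes the hypothesis package is installed.

    """Sketch: registering and dispatching a per-check strategy (illustrative only)."""

    import pandas as pd
    from hypothesis import strategies as st

    from pandera.strategies.base_strategies import STRATEGY_DISPATCHER


    def gt_series_strategy(pandera_dtype, elements, *, min_value):
        # Hypothetical strategy for the built-in "greater_than" check on Series:
        # narrow an existing hypothesis elements strategy so generated values
        # satisfy x > min_value. `pandera_dtype` is unused in this sketch.
        if elements is None:
            elements = st.floats(allow_nan=False)
        return elements.filter(lambda x: x > min_value)


    # Register under the (check_name, data_type) key that pandas_strategies
    # looks up when building element strategies for a Series schema component.
    STRATEGY_DISPATCHER[("greater_than", pd.Series)] = gt_series_strategy

    # Dispatch mirrors the lookup added in pandas_strategies.py: the strategy
    # receives the pandera dtype, the current elements strategy, and the
    # check statistics (here min_value=0, as stored by register_check_statistics).
    strategy_fn = STRATEGY_DISPATCHER.get(("greater_than", pd.Series))
    elements = strategy_fn(None, st.floats(allow_nan=False), min_value=0)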
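
The reworked SchemaErrors in pandera/errors.py now delegates message construction to schema.BACKEND.failure_cases_metadata, which returns the FailureCaseMetadata named tuple added in this patch. Below is a hedged sketch of what a minimal implementation of that hook could look like; it is written as a standalone function for brevity (in the patch it is a backend method), and the pandas backend shipped here does considerably more aggregation and formatting.

    """Sketch: a minimal failure_cases_metadata implementation (illustrative only)."""

    from typing import Any, Dict, List

    import pandas as pd

    from pandera.errors import FailureCaseMetadata


    def failure_cases_metadata(
        schema_name: str, schema_errors: List[Dict[str, Any]]
    ) -> FailureCaseMetadata:
        # Consolidate lazily collected schema errors for SchemaErrors reporting.
        # Each schema_errors entry carries a "reason_code" and the SchemaError
        # object under "error", as in the dicts collected by the error handler.
        error_counts: Dict[str, int] = {}
        failure_case_frames = []
        for err_dict in schema_errors:
            reason_code = err_dict["reason_code"]
            error_counts[reason_code] = error_counts.get(reason_code, 0) + 1
            err = err_dict["error"]
            if err.failure_cases is not None:
                failure_case_frames.append(err.failure_cases)

        failure_cases = (
            pd.concat(failure_case_frames).reset_index(drop=True)
            if failure_case_frames
            else pd.DataFrame()
        )
        return FailureCaseMetadata(
            failure_cases=failure_cases,
            message=(
                f"Schema {schema_name}: a total of "
                f"{sum(error_counts.values())} schema errors were found."
            ),
            error_counts=error_counts,
        )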