Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to_yaml serialization dropping global checks #428

Merged
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/schema_inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ is a convenience method for this functionality.
coerce: false
required: true
regex: false
checks: null
index:
- pandas_dtype: int64
nullable: false
Expand Down
70 changes: 61 additions & 9 deletions pandera/checks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
"""Data validation checks."""
# pylint: disable=fixme

import inspect
import operator
import re
from collections import namedtuple
from collections import ChainMap, namedtuple
from functools import partial, wraps
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from itertools import chain
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
no_type_check,
)

import pandas as pd

Expand Down Expand Up @@ -397,9 +410,11 @@ def __call__(
)

def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented

are_check_fn_objects_equal = (
self.__dict__["_check_fn"].__code__.co_code
== other.__dict__["_check_fn"].__code__.co_code
self._get_check_fn_code() == other._get_check_fn_code()
)

try:
Expand Down Expand Up @@ -427,8 +442,18 @@ def __eq__(self, other):
and are_all_other_check_attributes_equal
)

def _get_check_fn_code(self):
check_fn = self.__dict__["_check_fn"]
try:
code = check_fn.__code__.co_code
except AttributeError:
# try accessing the functools.partial wrapper
code = check_fn.func.__code__.co_code

return code

def __hash__(self):
return hash(self.__dict__["_check_fn"].__code__.co_code)
return hash(self._get_check_fn_code())

def __repr__(self):
return (
Expand All @@ -438,22 +463,49 @@ def __repr__(self):
)


_T = TypeVar("_T", bound=_CheckBase)


class _CheckMeta(type): # pragma: no cover
"""Check metaclass."""

# FIXME: this should probably just be moved to _CheckBase

REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {} # noqa
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate why it'd be better in CheckBase ? REGISTERED_CUSTOM_CHECKS is referenced by __getattr__, it seems to make sense to have them sitting next to each others.

Copy link
Contributor Author

@antonl antonl Mar 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment got moved by black. Essentially, I think we could merge the _CheckMeta with _CheckBase. The metaclass actually confuses mypy, and I don't see the benefit of using this mixin pattern here. Since this is a style choice, I didn't want to just do it though, so I added a comment so that we could discuss.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you move __getattr__ to _CheckBase it's going to act on instances of the class and not on the class itself.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Metaclasses confuse me, had to look up the descriptor protocol docs again. 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'm proposing we just move the metaclass to _CheckBase. It's not usable on anything else anyway because it requires a name property for example. I'll fix it and remove the fixme comment.


def __getattr__(cls, name: str) -> Any:
"""Prevent attribute errors for registered checks."""
attr = cls.__dict__.get(name)
attr = ChainMap(cls.__dict__, cls.REGISTERED_CUSTOM_CHECKS).get(name)
if attr is None:
raise AttributeError(f"'{cls}' object has no attribute '{name}'")
raise AttributeError(
f"'{cls}' object has no attribute '{name}'. "
"Make sure any custom checks have been registered "
"using the extensions api."
)
return attr

def __dir__(cls) -> Iterable[str]:
"""Allow custom checks to show up as attributes when autocompleting."""
return chain(super().__dir__(), cls.REGISTERED_CUSTOM_CHECKS.keys())

# pylint: disable=line-too-long
# mypy has limited metaclass support so this doesn't pass typecheck
# see https://mypy.readthedocs.io/en/stable/metaclasses.html#gotchas-and-limitations-of-metaclass-support
# pylint: enable=line-too-long
@no_type_check
def __contains__(cls: Type[_T], item: Union[_T, str]) -> bool:
"""Allow lookups for registered checks."""
if isinstance(item, cls):
name = item.name
return hasattr(cls, name)

# assume item is str
return hasattr(cls, item)


class Check(_CheckBase, metaclass=_CheckMeta):
"""Check a pandas Series or DataFrame for certain properties."""

REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {} # noqa

@classmethod
@st.register_check_strategy(st.eq_strategy)
@register_check_statistics(["value"])
Expand Down
5 changes: 2 additions & 3 deletions pandera/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ def check_method(cls, *args, **kwargs):
if strategy is not None:
check_method = st.register_check_strategy(strategy)(check_method)

setattr(Check, check_fn.__name__, classmethod(check_method))
Check.REGISTERED_CUSTOM_CHECKS[check_fn.__name__] = getattr(
Check, check_fn.__name__
Check.REGISTERED_CUSTOM_CHECKS[check_fn.__name__] = partial(
check_method, Check
)

return register_check_wrapper(check_fn)
58 changes: 45 additions & 13 deletions pandera/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,60 @@
NOT_JSON_SERIALIZABLE = {PandasDtype.DateTime, PandasDtype.Timedelta}


def _serialize_check_stats(check_stats, pandas_dtype):
def _serialize_check_stats(check_stats, pandas_dtype=None):
"""Serialize check statistics into json/yaml-compatible format."""

def handle_stat_dtype(stat):
def handle_stat_dtype(stat, pandas_dtype):
if pandas_dtype == PandasDtype.DateTime:
return stat.strftime(DATETIME_FORMAT)
elif pandas_dtype == PandasDtype.Timedelta:
# serialize to int in nanoseconds
return stat.delta

return stat

# for unary checks, return a single value instead of a dictionary
if len(check_stats) == 1:
return handle_stat_dtype(list(check_stats.values())[0])
return handle_stat_dtype(list(check_stats.values())[0], pandas_dtype)

# otherwise return a dictionary of keyword args needed to create the Check
serialized_check_stats = {}
for arg, stat in check_stats.items():
serialized_check_stats[arg] = handle_stat_dtype(stat)
serialized_check_stats[arg] = handle_stat_dtype(stat, pandas_dtype)
return serialized_check_stats


def _serialize_dataframe_stats(dataframe_checks):
"""
Serialize global dataframe check statistics into json/yaml-compatible format.
"""
serialized_checks = {}

for check_name, check_stats in dataframe_checks.items():
# The case that `check_name` is not registered is handled in `parse_checks`,
# so we know that `check_name` exists.

# infer dtype of statistics and serialize them
serialized_checks[check_name] = _serialize_check_stats(check_stats)

return serialized_checks


def _serialize_component_stats(component_stats):
"""
Serialize column or index statistics into json/yaml-compatible format.
"""
# pylint: disable=import-outside-toplevel
from pandera.checks import Check

serialized_checks = None
if component_stats["checks"] is not None:
serialized_checks = {}
for check_name, check_stats in component_stats["checks"].items():
if check_stats is None:
if check_name not in Check:
warnings.warn(
f"Check {check_name} cannot be serialized. This check will be "
f"ignored"
"ignored. Did you forget to register it with the extension API?"
)
else:
serialized_checks[check_name] = _serialize_check_stats(
Expand Down Expand Up @@ -93,7 +113,7 @@ def _serialize_schema(dataframe_schema):

statistics = get_dataframe_schema_statistics(dataframe_schema)

columns, index = None, None
columns, index, checks = None, None, None
if statistics["columns"] is not None:
columns = {
col_name: _serialize_component_stats(column_stats)
Expand All @@ -106,18 +126,22 @@ def _serialize_schema(dataframe_schema):
for index_stats in statistics["index"]
]

if statistics["checks"] is not None:
checks = _serialize_dataframe_stats(statistics["checks"])

return {
"schema_type": "dataframe",
"version": __version__,
"columns": columns,
"checks": checks,
"index": index,
"coerce": dataframe_schema.coerce,
"strict": dataframe_schema.strict,
}


def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype):
def handle_stat_dtype(stat):
def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype=None):
def handle_stat_dtype(stat, pandas_dtype):
jeffzi marked this conversation as resolved.
Show resolved Hide resolved
if pandas_dtype == PandasDtype.DateTime:
return pd.to_datetime(stat, format=DATETIME_FORMAT)
elif pandas_dtype == PandasDtype.Timedelta:
Expand All @@ -130,10 +154,10 @@ def handle_stat_dtype(stat):
# dictionary mapping Check arg names to values.
check_stats = {}
for arg, stat in serialized_check_stats.items():
check_stats[arg] = handle_stat_dtype(stat)
check_stats[arg] = handle_stat_dtype(stat, pandas_dtype)
return check(**check_stats)
# otherwise assume unary check function signature
return check(handle_stat_dtype(serialized_check_stats))
return check(handle_stat_dtype(serialized_check_stats, pandas_dtype))


def _deserialize_component_stats(serialized_component_stats):
Expand Down Expand Up @@ -173,9 +197,9 @@ def _deserialize_component_stats(serialized_component_stats):

def _deserialize_schema(serialized_schema):
# pylint: disable=import-outside-toplevel
from pandera import Column, DataFrameSchema, Index, MultiIndex
from pandera import Check, Column, DataFrameSchema, Index, MultiIndex

columns, index = None, None
columns, index, checks = None, None, None
if serialized_schema["columns"] is not None:
columns = {
col_name: Column(**_deserialize_component_stats(column_stats))
Expand All @@ -188,6 +212,13 @@ def _deserialize_schema(serialized_schema):
for index_component in serialized_schema["index"]
]

if serialized_schema["checks"] is not None:
# handles unregistered checks by raising AttributeErrors from getattr
checks = [
_deserialize_check_stats(getattr(Check, check_name), check_stats)
for check_name, check_stats in serialized_schema["checks"].items()
]

if index is None:
pass
elif len(index) == 1:
Expand All @@ -199,6 +230,7 @@ def _deserialize_schema(serialized_schema):

return DataFrameSchema(
columns=columns,
checks=checks,
index=index,
coerce=serialized_schema["coerce"],
strict=serialized_schema["strict"],
Expand Down
8 changes: 8 additions & 0 deletions pandera/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Class-based api"""
import inspect
import os
import re
import sys
import typing
Expand Down Expand Up @@ -170,6 +171,13 @@ def to_schema(cls) -> DataFrameSchema:
MODEL_CACHE[cls] = cls.__schema__
return cls.__schema__

@classmethod
def to_yaml(cls, stream: Optional[os.PathLike] = None):
"""
Convert `Schema` to yaml using `io.to_yaml`.
"""
return cls.to_schema().to_yaml(stream)

@classmethod
@pd.util.Substitution(validate_doc=DataFrameSchema.validate.__doc__)
def validate(
Expand Down
13 changes: 12 additions & 1 deletion pandera/schema_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def get_dataframe_schema_statistics(dataframe_schema):
}
for col_name, column in dataframe_schema.columns.items()
},
"checks": parse_checks(dataframe_schema.checks),
"index": (
None
if dataframe_schema.index is None
Expand Down Expand Up @@ -158,7 +159,17 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]:
check_statistics = {}
_check_memo = {}
for check in checks:
check_statistics[check.name] = check.statistics
if check not in Check:
warnings.warn(
"Only registered checks may be serialized to statistics. "
"Did you forget to register it with the extension API? "
f"Check `{check.name}` will be skipped."
)
continue

check_statistics[check.name] = (
{} if check.statistics is None else check.statistics
)
_check_memo[check.name] = check

# raise ValueError on incompatible checks
Expand Down
6 changes: 3 additions & 3 deletions pandera/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import copy
import itertools
import os
import warnings
from functools import wraps
from pathlib import Path
Expand Down Expand Up @@ -1186,17 +1187,16 @@ def from_yaml(cls, yaml_schema) -> "DataFrameSchema":

return pandera.io.from_yaml(yaml_schema)

def to_yaml(self, fp: Union[str, Path] = None):
def to_yaml(self, stream: Optional[os.PathLike] = None):
"""Write DataFrameSchema to yaml file.

:param dataframe_schema: schema to write to file or dump to string.
:param stream: file stream to write to. If None, dumps to string.
:returns: yaml string if stream is None, otherwise returns None.
"""
# pylint: disable=import-outside-toplevel,cyclic-import
import pandera.io

return pandera.io.to_yaml(self, fp)
return pandera.io.to_yaml(self, stream=stream)

def set_index(
self, keys: List[str], drop: bool = True, append: bool = False
Expand Down
33 changes: 33 additions & 0 deletions tests/core/checks_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Pytest fixtures for testing custom checks."""
import unittest.mock as mock

import pandas as pd
import pytest

import pandera as pa
import pandera.extensions as pa_ext

__all__ = "custom_check_teardown", "extra_registered_checks"


@pytest.fixture(scope="function")
def custom_check_teardown():
"""Remove all custom checks after execution of each pytest function."""
yield
for check_name in list(pa.Check.REGISTERED_CUSTOM_CHECKS):
del pa.Check.REGISTERED_CUSTOM_CHECKS[check_name]


@pytest.fixture(scope="function")
def extra_registered_checks():
"""temporarily registers custom checks onto the Check class"""
# pylint: disable=unused-variable
with mock.patch(
"pandera.Check.REGISTERED_CUSTOM_CHECKS", new_callable=dict
):
# register custom checks here
@pa_ext.register_check_method()
def no_param_check(_: pd.DataFrame) -> bool:
return True

yield
4 changes: 4 additions & 0 deletions tests/core/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Registers fixtures for core"""

# pylint: disable=unused-import
from .checks_fixtures import custom_check_teardown, extra_registered_checks
Loading