Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Unions in schemas and validation #1227

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Union,
cast,
)
from mypy.types import UnionType

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,7 +61,6 @@
else:
from typing_extensions import TypedDict # noqa


try:
from typing import Literal # type: ignore
except ImportError:
Expand Down Expand Up @@ -1330,3 +1330,39 @@

def __str__(self) -> str:
"""Return ``NamedTuple.__name__`` as a string."""
return str(NamedTuple.__name__)


@Engine.register_dtype(equivalents=[UnionType, Union, "UnionType", "union"])
@dtypes.immutable(init=True)
class PythonUnion(PythonGenericType):
"""A datatype to support python unions."""

type = UnionType

def __init__( # pylint:disable=super-init-not-called
self, generic_type: Optional[Type] = None
) -> None:
"""Store ``generic_type`` on the instance when one is given.

:param generic_type: optional type to record; when it is ``None`` this
method assigns nothing.
"""
if generic_type is not None:
# NOTE(review): object.__setattr__ is used instead of plain attribute
# assignment, presumably because the class is decorated with
# @dtypes.immutable -- TODO confirm.
object.__setattr__(self, "generic_type", generic_type)

def check(
self,
pandera_dtype: dtypes.DataType,
data_container: Optional[PandasObject] = None,
) -> Union[bool, Iterable[bool]]:
"""Check that data container has the expected type."""
pandera_dtype = Engine.dtype(pandera_dtype)

pandas_types = [object]
pandas_types.extend(
self.generic_type.__args__ # pylint: disable=no-member
)

# the underlying pandas dtype must be object or one of the union's member types
if pandera_dtype not in map(Engine.dtype, pandas_types):
return False

if data_container is None:

Check warning on line 1365 in pandera/engines/pandas_engine.py

View check run for this annotation

Codecov / codecov/patch

pandera/engines/pandas_engine.py#L1365

Added line #L1365 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karajan1001 can we add a test case for when data_container is None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, @cosmicBboy. I'm sorry, I do not know how to create a test case in which data_container is None.
I tried
"""
pd.DataFrame([])
"""
But that just passes an empty pd.Series in.

return True
else:
return data_container.map(self._check_type) # type: ignore[operator]
39 changes: 33 additions & 6 deletions tests/core/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
import sys
from decimal import Decimal
from typing import Any, Dict, List, NamedTuple, Tuple
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import hypothesis
import numpy as np
Expand Down Expand Up @@ -733,17 +733,44 @@ class PointTuple(NamedTuple):
"tuple_column": pa.Column(Tuple[int, str, float]),
"typeddict_column": pa.Column(PointDict),
"namedtuple_column": pa.Column(PointTuple),
"column_union_float": pa.Column(Union[str, float]),
"column_union_str": pa.Column(Union[str, float]),
"column_union_obj": pa.Column(Union[str, float]),
},
)

data = pd.DataFrame(
{
"dict_column": [{"foo": 1, "bar": 2}],
"list_column": [[1.0]],
"tuple_column": [(1, "bar", 1.0)],
"typeddict_column": [PointDict(x=2.1, y=4.8)],
"namedtuple_column": [PointTuple(x=9.2, y=1.6)],
"dict_column": [{"foo": 1, "bar": 2}, {"foobar": 3}],
"list_column": [[1.0], [2.0]],
"tuple_column": [(1, "bar", 1.0), (2, "foobar", 2.0)],
"typeddict_column": [
PointDict(x=2.1, y=4.8),
PointDict(x=2.5, y=9.0),
],
"namedtuple_column": [
PointTuple(x=9.2, y=1.6),
PointTuple(x=2.5, y=1.4),
],
"column_union_float": [1.0, 2.0],
"column_union_str": ["foo", "bar"],
"column_union_obj": [12.0, "foo"],
}
)

schema.validate(data)

float_or_str_schema = pa.DataFrameSchema(
{
"column_union": pa.Column(Union[str, float]),
},
)

int_data = pd.DataFrame(
{
"column_union": [1, 2],
}
)

with pytest.raises(pa.errors.SchemaError):
float_or_str_schema.validate(int_data)
1 change: 1 addition & 0 deletions tests/strategies/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
pandas_engine.PythonTuple,
pandas_engine.PythonTypedDict,
pandas_engine.PythonNamedTuple,
pandas_engine.PythonUnion,
]
)
SUPPORTED_DTYPES = set()
Expand Down