Skip to content

Commit

Permalink
Handle timestamp and nans in removing multi index failure cases # 1469
Browse files Browse the repository at this point in the history
Signed-off-by: Rory <rory@rorymcstay.com>
  • Loading branch information
Rory McStay committed Aug 15, 2024
1 parent f6317d6 commit 1a4c817
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 1 deletion.
28 changes: 27 additions & 1 deletion pandera/backends/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from typing import List, Optional, TypeVar, Union

import numpy as np
import pandas as pd

from pandera.api.base.checks import CheckResult
Expand All @@ -28,6 +29,24 @@
SchemaWarning,
)


_MULTIINDEX_HANDLED_TYPES = {
"Timestamp": pd.Timestamp,
"NaT": pd.NaT,
"nan": np.nan,
}


class ColumnInfo(NamedTuple):
"""Column metadata used during validation."""

sorted_column_names: Iterable
expanded_column_names: FrozenSet
destuttered_column_names: List
absent_column_names: List
regex_match_patterns: List


FieldCheckObj = Union[pd.Series, pd.DataFrame]

T = TypeVar(
Expand Down Expand Up @@ -196,7 +215,14 @@ def drop_invalid_rows(self, check_obj, error_handler: ErrorHandler):
if isinstance(check_obj.index, pd.MultiIndex):
# MultiIndex values are saved on the error as strings so need to be cast back
# to their original types
index_tuples = err.failure_cases["index"].apply(eval)
index_tuples = (
err.failure_cases["index"]
.astype(str)
.apply(lambda i: eval(i, _MULTIINDEX_HANDLED_TYPES))
)
# type check on a column of index.
if len(index_tuples) == 1 and index_tuples[0] is None:
continue
index_values = pd.MultiIndex.from_tuples(index_tuples)

mask = ~check_obj.index.isin(index_values)
Expand Down
139 changes: 139 additions & 0 deletions tests/core/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2648,3 +2648,142 @@ def test_schema_column_default_handle_nans(
df = pd.DataFrame({"column1": [input_value]})
schema.validate(df, inplace=True)
assert df.iloc[0]["column1"] == default


@pytest.mark.parametrize(
"schema, obj, expected_obj, check_dtype",
[
(
DataFrameSchema(
columns={
"temperature": Column(float, nullable=False),
},
index=MultiIndex(
[
Index(pd.Timestamp, name="timestamp"),
Index(str, name="city"),
]
),
drop_invalid_rows=True,
),
pd.DataFrame(
{
"temperature": [
3.0,
4.0,
5.0,
5.0,
np.nan,
2.0,
],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2023-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
),
names=["timestamp", "city"],
),
),
pd.DataFrame(
{
"temperature": [3.0, 4.0, 5.0, 5.0, 2.0],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
),
names=["timestamp", "city"],
),
),
True,
),
(
DataFrameSchema(
columns={
"temperature": Column(float, nullable=False),
},
index=MultiIndex(
[
Index(pd.Timestamp, name="timestamp"),
Index(str, name="city"),
]
),
drop_invalid_rows=True,
),
pd.DataFrame(
{
"temperature": [
3.0,
4.0,
5.0,
-1.0,
np.nan,
-2.0,
4.0,
5.0,
2.0,
],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2023-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
(
pd.Timestamp("2024-01-01", tz="Europe/London"),
"London",
),
(pd.Timestamp(pd.NaT), "Frankfurt"),
(pd.Timestamp("2024-01-01"), 6),
),
names=["timestamp", "city"],
),
),
pd.DataFrame(
{
"temperature": [3.0, 4.0, 5.0, -1.0, -2.0, 4],
},
index=pd.MultiIndex.from_tuples(
(
(pd.Timestamp("2022-01-01"), "Paris"),
(pd.Timestamp("2023-01-01"), "Paris"),
(pd.Timestamp("2024-01-01"), "Paris"),
(pd.Timestamp("2022-01-01"), "Oslo"),
(pd.Timestamp("2024-01-01"), "Oslo"),
(
pd.Timestamp("2024-01-01", tz="Europe/London"),
"London",
),
),
names=["timestamp", "city"],
),
),
False,
),
],
)
def test_drop_invalid_for_multi_index_with_datetime(
schema, obj, expected_obj, check_dtype
):
"""Test drop_invalid_rows works as expected on multi-index dataframes"""
actual_obj = schema.validate(obj, lazy=True)

# the datatype of the index is not casted, In this cases its an object
pd.testing.assert_frame_equal(
actual_obj,
expected_obj,
check_dtype=check_dtype,
check_index_type=check_dtype,
)

0 comments on commit 1a4c817

Please sign in to comment.