Skip to content

Commit

Permalink
Add from_records to pandera DataFrame #850 (#859)
Browse files Browse the repository at this point in the history
* Add a from_records method that checks the schema for a pandas DataFrame

* Add a from_records method that checks the schema for a pandas DataFrame

* handle nox session.install issue

* fix lint

* fix noxfile issue

* remove unneeded types

* update type annotation

Co-authored-by: cosmicBboy <niels.bantilan@gmail.com>
  • Loading branch information
borissmidt and cosmicBboy committed Aug 10, 2022
1 parent bc6d8f6 commit 8dd8414
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 3 deletions.
43 changes: 40 additions & 3 deletions pandera/typing/pandas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
"""Typing definitions and helpers."""
# pylint:disable=abstract-method,disable=too-many-ancestors
import io
from typing import _type_check # type: ignore[attr-defined]
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from typing import ( # type: ignore[attr-defined]
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Tuple,
TypeVar,
Union,
_type_check,
)

import numpy as np
import pandas as pd

from ..errors import SchemaError, SchemaInitError
Expand Down Expand Up @@ -176,3 +186,30 @@ def pydantic_validate(cls, obj: Any, field: ModelField) -> pd.DataFrame:
raise ValueError(str(exc)) from exc

return cls.to_format(valid_data, schema_model.__config__)

@staticmethod
def from_records(
    schema: T,
    data: Union[
        np.ndarray, List[Tuple[Any, ...]], Dict[Any, Any], pd.DataFrame
    ],
    **kwargs,
) -> "DataFrame[T]":
    """
    Build a ``DataFrame[T]`` from record-like data, guided by a schema model.

    Delegates to :meth:`pandas.DataFrame.from_records`, then selects the
    columns in the order given by the schema's ``columns`` mapping and wraps
    the result in ``DataFrame[T]``.
    See :doc:`pandas:reference/api/pandas.DataFrame.from_records` for
    more details.

    NOTE(review): validation is presumably triggered by the
    ``DataFrame[T]`` constructor; nothing in this method calls ``validate``
    explicitly -- confirm against the class definition.

    :param schema: schema model exposing ``to_schema()``.
    :param data: structured ndarray, sequence of tuples or dicts, or
        DataFrame, passed through to ``pandas.DataFrame.from_records``.
    :param kwargs: extra keyword arguments forwarded to
        ``pandas.DataFrame.from_records``. When ``index`` is not supplied it
        defaults to the schema index names (or ``None`` if the schema has
        no index).
    :returns: the column-ordered frame wrapped in ``DataFrame[T]``.
    """
    model_schema = schema.to_schema()  # type: ignore[attr-defined]

    # An explicit ``index`` keyword from the caller always wins; otherwise
    # fall back to the index names declared on the schema.
    if model_schema.index is None:
        default_index = None
    else:
        default_index = model_schema.index.names
    kwargs.setdefault("index", default_index)

    records_df = pd.DataFrame.from_records(data=data, **kwargs)

    # Selecting by the schema's column keys sets the column order according
    # to the schema.
    ordered_df = records_df[model_schema.columns.keys()]
    return DataFrame[T](ordered_df)
98 changes: 98 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,104 @@ class Config:
DataFrame[SchemaNoCoerce](raw_data)


def test_from_records_validates_the_schema():
    """Test that from_records agrees with validating a plain DataFrame."""

    class Schema(pa.SchemaModel):
        state: Series[str]
        city: Series[str]
        price: Series[float]

    complete_records = [
        {"state": "NY", "city": "New York", "price": 8.0},
        {"state": "FL", "city": "Miami", "price": 12.0},
    ]
    pandas_df = pd.DataFrame.from_records(complete_records)
    pandera_validated_df = DataFrame.from_records(Schema, complete_records)

    assert pandera_validated_df.equals(Schema.validate(pandas_df))
    assert isinstance(pandera_validated_df, DataFrame)
    assert isinstance(pandas_df, pd.DataFrame)

    # Same records with the required "price" field left out.
    records_without_price = [
        {"state": "NY", "city": "New York"},
        {"state": "FL", "city": "Miami"},
    ]

    # NOTE(review): this negative case goes through ``DataFrame[Schema](...)``
    # rather than ``DataFrame.from_records`` -- confirm that is intended.
    expected_message = "^column 'price' not in dataframe"
    with pytest.raises(pa.errors.SchemaError, match=expected_message):
        DataFrame[Schema](records_without_price)


def test_from_records_sets_the_index_from_schema():
    """Test that from_records uses the schema's named index as the index."""

    class Schema(pa.SchemaModel):
        state: Index[str] = pa.Field(check_name=True)
        city: Series[str]
        price: Series[float]

    records = [
        {"state": "NY", "city": "New York", "price": 8.0},
        {"state": "FL", "city": "Miami", "price": 12.0},
    ]

    # Reference frame: plain pandas with "state" set as the index explicitly.
    pandas_df = pd.DataFrame.from_records(records, index=["state"])
    expected_df = Schema.validate(pandas_df)

    # No ``index`` keyword is passed here, so it must come from the schema.
    pandera_validated_df = DataFrame.from_records(Schema, records)

    assert pandera_validated_df.equals(expected_df)
    assert isinstance(pandera_validated_df, DataFrame)
    assert isinstance(pandas_df, pd.DataFrame)


def test_from_records_sorts_the_columns():
    """Test that from_records orders columns as declared in the schema."""

    class Schema(pa.SchemaModel):
        state: Series[str]
        city: Series[str]
        price: Series[float]

    # Keys are deliberately listed out of schema order, and differently in
    # each record.
    shuffled_records = [
        {"city": "New York", "price": 8.0, "state": "NY"},
        {"price": 12.0, "state": "FL", "city": "Miami"},
    ]
    schema_order = ["state", "city", "price"]

    # Reference frame: plain pandas, reordered by hand to the schema order.
    unordered_df = pd.DataFrame.from_records(shuffled_records)
    pandas_df = unordered_df[schema_order]

    pandera_validated_df = DataFrame.from_records(Schema, shuffled_records)

    assert pandera_validated_df.equals(Schema.validate(pandas_df))
    assert isinstance(pandera_validated_df, DataFrame)
    assert isinstance(pandas_df, pd.DataFrame)


def test_schema_model_generic_inheritance() -> None:
"""Test that a schema model subclass can also inherit from typing.Generic"""

Expand Down

0 comments on commit 8dd8414

Please sign in to comment.