From 8dd841405f93164422220782261e607e8da45afa Mon Sep 17 00:00:00 2001 From: Boris Smidt Date: Tue, 9 Aug 2022 02:43:55 +0200 Subject: [PATCH] Add from records to panderas dataframe #850 (#859) * Add a from record that checks the schema for a pandas dataframe * Add a from record that checks the schema for a pandas dataframe * handle nox session.install issue * fix lint * fix noxfile issue * remove unneeded types * update type annotation Co-authored-by: cosmicBboy --- pandera/typing/pandas.py | 43 ++++++++++++++++-- tests/core/test_model.py | 98 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 3 deletions(-) diff --git a/pandera/typing/pandas.py b/pandera/typing/pandas.py index cede3d129..e224b6156 100644 --- a/pandera/typing/pandas.py +++ b/pandera/typing/pandas.py @@ -1,9 +1,19 @@ """Typing definitions and helpers.""" # pylint:disable=abstract-method,disable=too-many-ancestors import io -from typing import _type_check # type: ignore[attr-defined] -from typing import TYPE_CHECKING, Any, Generic, TypeVar - +from typing import ( # type: ignore[attr-defined] + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Tuple, + TypeVar, + Union, + _type_check, +) + +import numpy as np import pandas as pd from ..errors import SchemaError, SchemaInitError @@ -176,3 +186,30 @@ def pydantic_validate(cls, obj: Any, field: ModelField) -> pd.DataFrame: raise ValueError(str(exc)) from exc return cls.to_format(valid_data, schema_model.__config__) + + @staticmethod + def from_records( + schema: T, + data: Union[ + np.ndarray, List[Tuple[Any, ...]], Dict[Any, Any], pd.DataFrame + ], + **kwargs, + ) -> "DataFrame[T]": + """ + Convert structured or record ndarray to pandera-validated DataFrame. + + Creates a DataFrame object from a structured ndarray, sequence of tuples + or dicts, or DataFrame. + + See :doc:`pandas:reference/api/pandas.DataFrame.from_records` for + more details. + """ + schema = schema.to_schema() # type: ignore[attr-defined] + schema_index = schema.index.names if schema.index is not None else None + if "index" not in kwargs: + kwargs["index"] = schema_index + return DataFrame[T]( + pd.DataFrame.from_records(data=data, **kwargs,)[ + schema.columns.keys() + ] # set the column order according to schema + ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index e69d535fb..413dae9ca 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1041,6 +1041,104 @@ class Config: DataFrame[SchemaNoCoerce](raw_data) +def test_from_records_validates_the_schema(): + """Test that DataFrame[Schema] validates the schema""" + + class Schema(pa.SchemaModel): + state: Series[str] + city: Series[str] + price: Series[float] + + raw_data = [ + { + "state": "NY", + "city": "New York", + "price": 8.0, + }, + { + "state": "FL", + "city": "Miami", + "price": 12.0, + }, + ] + pandera_validated_df = DataFrame.from_records(Schema, raw_data) + pandas_df = pd.DataFrame.from_records(raw_data) + assert pandera_validated_df.equals(Schema.validate(pandas_df)) + assert isinstance(pandera_validated_df, DataFrame) + assert isinstance(pandas_df, pd.DataFrame) + + raw_data = [ + { + "state": "NY", + "city": "New York", + }, + { + "state": "FL", + "city": "Miami", + }, + ] + + with pytest.raises( + pa.errors.SchemaError, + match="^column 'price' not in dataframe", + ): + DataFrame[Schema](raw_data) + + +def test_from_records_sets_the_index_from_schema(): + """Test that DataFrame[Schema] validates the schema""" + + class Schema(pa.SchemaModel): + state: Index[str] = pa.Field(check_name=True) + city: Series[str] + price: Series[float] + + raw_data = [ + { + "state": "NY", + "city": "New York", + "price": 8.0, + }, + { + "state": "FL", + "city": "Miami", + "price": 12.0, + }, + ] + pandera_validated_df = DataFrame.from_records(Schema, raw_data) + pandas_df = pd.DataFrame.from_records(raw_data, index=["state"]) + assert pandera_validated_df.equals(Schema.validate(pandas_df)) + assert isinstance(pandera_validated_df, DataFrame) + assert isinstance(pandas_df, pd.DataFrame) + + +def test_from_records_sorts_the_columns(): + """Test that DataFrame[Schema] validates the schema""" + + class Schema(pa.SchemaModel): + state: Series[str] + city: Series[str] + price: Series[float] + + raw_data = [ + { + "city": "New York", + "price": 8.0, + "state": "NY", + }, + { + "price": 12.0, + "state": "FL", + "city": "Miami", + }, + ] + pandera_validated_df = DataFrame.from_records(Schema, raw_data) + pandas_df = pd.DataFrame.from_records(raw_data)[["state", "city", "price"]] + assert pandera_validated_df.equals(Schema.validate(pandas_df)) + assert isinstance(pandera_validated_df, DataFrame) + assert isinstance(pandas_df, pd.DataFrame) + + def test_schema_model_generic_inheritance() -> None: """Test that a schema model subclass can also inherit from typing.Generic"""