From 830d77cbf06266a32ed4f30ff9028b5adac92cfa Mon Sep 17 00:00:00 2001 From: Matt Morris Date: Thu, 8 Dec 2022 15:22:38 -0600 Subject: [PATCH 1/4] test: add unit test for `pandera.io.to_pyarrow_field` --- pandera/io.py | 105 +++++++++++++++++++++++++++++++++++++++++--- setup.py | 2 +- tests/io/test_io.py | 86 +++++++++++++++++++++++++++++++++++- 3 files changed, 184 insertions(+), 9 deletions(-) diff --git a/pandera/io.py b/pandera/io.py index 10fb8989a..0ccf873e7 100644 --- a/pandera/io.py +++ b/pandera/io.py @@ -5,7 +5,7 @@ from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import pandas as pd @@ -13,19 +13,20 @@ from . import dtypes from .checks import Check -from .engines import pandas_engine -from .schema_components import Column +from .engines import numpy_engine, pandas_engine +from .schema_components import Column, Index, MultiIndex, SeriesSchemaBase from .schema_statistics import get_dataframe_schema_statistics from .schemas import DataFrameSchema try: import black + import pyarrow import yaml from frictionless import Schema as FrictionlessSchema except ImportError as exc: # pragma: no cover raise ImportError( - "IO and formatting requires 'pyyaml', 'black' and 'frictionless'" - "to be installed.\n" + "IO and formatting requires 'pyyaml', 'black', 'frictionless' and " + "`pyarrow` to be installed.\n" "You can install pandera together with the IO dependencies with:\n" "pip install pandera[io]\n" ) from exc @@ -246,8 +247,6 @@ def deserialize_schema(serialized_schema): :returns: the schema de-serialized into :class:`~pandera.schemas.DataFrameSchema` """ - # pylint: disable=import-outside-toplevel - from pandera import Index, MultiIndex # GH#475 serialized_schema = serialized_schema if serialized_schema else {} @@ -806,3 +805,95 @@ def from_frictionless_schema( ), } return deserialize_schema(assembled_schema) + + +def to_pyarrow_field(pandera_field: SeriesSchemaBase) -> pyarrow.Field: + """ + Convert a :class:`~pandera.schema_components.SeriesSchemaBase` to + ``pyarrow.Field`` + + :param pandera_field: pandera Index or Column + :returns: ``pyarrow.Field`` representation of ``pandera_field`` + """ + + pandera_dtype = pandera_field.dtype + pandas_dtype = pandas_engine.Engine.dtype(pandera_dtype).type + + pandas_types = { + pd.BooleanDtype(): pyarrow.bool_(), + pd.Int8Dtype(): pyarrow.int8(), + pd.Int16Dtype(): pyarrow.int16(), + pd.Int32Dtype(): pyarrow.int32(), + pd.Int64Dtype(): pyarrow.int64(), + pd.UInt8Dtype(): pyarrow.uint8(), + pd.UInt16Dtype(): pyarrow.uint16(), + pd.UInt32Dtype(): pyarrow.uint32(), + pd.UInt64Dtype(): pyarrow.uint64(), + pd.Float32Dtype(): pyarrow.float32(), # type: ignore[attr-defined] + pd.Float64Dtype(): pyarrow.float64(), # type: ignore[attr-defined] + pd.StringDtype(): pyarrow.string(), + } + + if pandas_dtype in pandas_types: + pyarrow_type = pandas_types[pandera_field.dtype.type] + elif isinstance( + pandera_dtype, (pandas_engine.Date, numpy_engine.DateTime64) + ): + pyarrow_type = pyarrow.date64() + elif isinstance(pandera_field.dtype, dtypes.Category): + # Categorical data types + pyarrow_type = pyarrow.dictionary( + pyarrow.int8(), + pandera_dtype.type.categories.inferred_type, + ordered=pandera_dtype.ordered, # type: ignore[attr-defined] + ) + else: + pyarrow_type = pyarrow.from_numpy_dtype(pandas_dtype) + + return pyarrow.field( + pandera_field.name, pyarrow_type, pandera_field.nullable + ) + + +def to_pyarrow_schema( + dataframe_schema: DataFrameSchema, + preserve_index: Optional[bool] = None, +) -> pyarrow.Schema: + """ + Convert a :class:`~pandera.schemas.DataFrameSchema` to ``pyarrow.Schema``. + + :param dataframe_schema: schema to convert to ``pyarrow.Schema`` + :param preserve_index: whether to store the index as an additional column + (or columns, for MultiIndex) in the resulting Table. The default of + None will store the index as a column, except for RangeIndex which is + stored as metadata only. Use ``preserve_index=True`` to force it to be + stored as a column. + :returns: ``pyarrow.Schema`` representation of DataFrameSchema + """ + + # List of columns that will be present in the pyarrow schema + columns: List[SeriesSchemaBase] = list(dataframe_schema.columns.values()) + + # pyarrow schema metadata + metadata: Dict[str, Any] = {} + + index = dataframe_schema.index + if index is None: + if preserve_index: + # Create column for RangeIndex + columns.append( + Index(dtypes.Int64, nullable=False, name="__index_level_0__") + ) + else: + # Only preserve metadata of index + metadata["index_columns"] = [ + {"kind": "range", "name": pyarrow.null, "step": 1} + ] + elif preserve_index is not False: + # Add column(s) for index(es) + if isinstance(index, Index): + columns.append(index) + elif isinstance(index, MultiIndex): + columns += index.indexes + + return pyarrow.Schema([to_pyarrow_field(c) for c in columns]) diff --git a/setup.py b/setup.py index 80fc0811f..be9d7f361 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ _extras_require = { "strategies": ["hypothesis >= 5.41.1"], "hypotheses": ["scipy"], - "io": ["pyyaml >= 5.1", "black", "frictionless"], + "io": ["pyyaml >= 5.1", "black", "frictionless", "pyarrow"], "pyspark": ["pyspark >= 3.2.0"], "modin": ["modin", "ray <= 1.7.0", "dask"], "modin-ray": ["modin", "ray <= 1.7.0"], diff --git a/tests/io/test_io.py b/tests/io/test_io.py index adadc482c..630957867 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -13,7 +13,9 @@ import pandera import pandera.extensions as pa_ext import pandera.typing as pat -from pandera.engines import pandas_engine +from pandera import dtypes +from pandera.engines import numpy_engine, pandas_engine +from pandera.schema_components import Column try: from pandera import io @@ -34,6 +36,14 @@ SKIP_YAML_TESTS = PYYAML_VERSION is None or PYYAML_VERSION.release < (5, 1, 0) # type: ignore +try: + import pyarrow +except ImportError: + SKIP_PYARROW_TESTS = True +else: + SKIP_PYARROW_TESTS = False + + # skip all tests in module if "io" depends aren't installed pytestmark = pytest.mark.skipif( not HAS_IO, reason='needs "io" module dependencies' @@ -1362,3 +1372,77 @@ def test_frictionless_schema_primary_key(frictionless_schema): assert schema.unique == frictionless_schema["primaryKey"] for key in frictionless_schema["primaryKey"]: assert not schema.columns[key].unique + + +@pytest.mark.skipif(SKIP_PYARROW_TESTS, reason="pyarrow required") +@pytest.mark.parametrize( + "pandera_dtype, expected_pyarrow_dtype", + [ + (dtypes.Bool, pyarrow.bool_()), + (numpy_engine.Bool, pyarrow.bool_()), + (pandas_engine.BOOL, pyarrow.bool_()), + (dtypes.Int8, pyarrow.int8()), + (dtypes.Int16, pyarrow.int16()), + (dtypes.Int32, pyarrow.int32()), + (dtypes.Int64, pyarrow.int64()), + (numpy_engine.Int8, pyarrow.int8()), + (numpy_engine.Int16, pyarrow.int16()), + (numpy_engine.Int32, pyarrow.int32()), + (numpy_engine.Int64, pyarrow.int64()), + (pandas_engine.INT8, pyarrow.int8()), + (pandas_engine.INT16, pyarrow.int16()), + (pandas_engine.INT32, pyarrow.int32()), + (pandas_engine.INT64, pyarrow.int64()), + (dtypes.UInt8, pyarrow.uint8()), + (dtypes.UInt16, pyarrow.uint16()), + (dtypes.UInt32, pyarrow.uint32()), + (dtypes.UInt64, pyarrow.uint64()), + (numpy_engine.UInt8, pyarrow.uint8()), + (numpy_engine.UInt16, pyarrow.uint16()), + (numpy_engine.UInt32, pyarrow.uint32()), + (numpy_engine.UInt64, pyarrow.uint64()), + (pandas_engine.UINT8, pyarrow.uint8()), + (pandas_engine.UINT16, pyarrow.uint16()), + (pandas_engine.UINT32, pyarrow.uint32()), + (pandas_engine.UINT64, pyarrow.uint64()), + (dtypes.Float16, pyarrow.float16()), + (dtypes.Float32, pyarrow.float32()), + (dtypes.Float64, pyarrow.float64()), + (numpy_engine.Float16, pyarrow.float16()), + (numpy_engine.Float32, pyarrow.float32()), + (numpy_engine.Float64, pyarrow.float64()), + (pandas_engine.FLOAT32, pyarrow.float32()), + (pandas_engine.FLOAT64, pyarrow.float64()), + (dtypes.String, pyarrow.string()), + (numpy_engine.String, pyarrow.string()), + (pandas_engine.STRING, pyarrow.string()), + (pandas_engine.NpString, pyarrow.string()), + (numpy_engine.Bytes, pyarrow.binary()), + (dtypes.Date, pyarrow.date64()), + (pandas_engine.Date, pyarrow.date64()), + (dtypes.Timestamp, pyarrow.timestamp("ns")), + (numpy_engine.DateTime64, pyarrow.date64()), # unbound + (pandas_engine.DateTime, pyarrow.timestamp("ns")), + (dtypes.Timedelta, pyarrow.duration("ns")), + (numpy_engine.Timedelta64, pyarrow.duration("ns")), + ( + dtypes.Category(categories=["foo", "bar", "baz"], ordered=True), + pyarrow.dictionary( + pyarrow.int8(), + pyarrow.string(), + ordered=True, + ), + ), + ], +) +@pytest.mark.parametrize("nullable", [True, False]) +def test_to_pyarrow_field(pandera_dtype, nullable, expected_pyarrow_dtype): + """Test if pandera_dtype is correctly converted to pyarrow dtype""" + name = "foo" + + pandera_field = Column(pandera_dtype, nullable=nullable, name=name) + pyarrow_dtype = io.to_pyarrow_field(pandera_field) + + assert pyarrow_dtype.type == expected_pyarrow_dtype + assert pyarrow_dtype.name == name + assert pyarrow_dtype.nullable == nullable From 6a7d9fb1208de06e472cb352cf97b97f3c38c89f Mon Sep 17 00:00:00 2001 From: Matt Morris Date: Thu, 8 Dec 2022 16:19:09 -0600 Subject: [PATCH 2/4] test: add test for `pandera.io.to_pyarrow_schema` --- pandera/io.py | 50 +++++++++++------ tests/io/test_io.py | 130 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 18 deletions(-) diff --git a/pandera/io.py b/pandera/io.py index 0ccf873e7..d2b108156 100644 --- a/pandera/io.py +++ b/pandera/io.py @@ -5,8 +5,9 @@ from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Dict, Optional, Union +import numpy as np import pandas as pd import pandera.errors @@ -807,9 +808,12 @@ def from_frictionless_schema( return deserialize_schema(assembled_schema) -def to_pyarrow_field(pandera_field: SeriesSchemaBase) -> pyarrow.Field: +def to_pyarrow_field( + name: str, + pandera_field: SeriesSchemaBase, +) -> pyarrow.Field: """ - Convert a :class:`~pandera.schema_components.SeriesSchemaBase` to + Convert a :class:`~pandera.schema_components.SeriesSchemaBase` to a ``pyarrow.Field`` :param pandera_field: pandera Index or Column @@ -847,12 +851,17 @@ def to_pyarrow_field(pandera_field: SeriesSchemaBase) -> pyarrow.Field: pandera_dtype.type.categories.inferred_type, ordered=pandera_dtype.ordered, # type: ignore[attr-defined] ) + elif pandas_dtype.type == np.object_: + pyarrow_type = pyarrow.string() else: pyarrow_type = pyarrow.from_numpy_dtype(pandas_dtype) - return pyarrow.field( - pandera_field.name, pyarrow_type, pandera_field.nullable - ) + return pyarrow.field(name, pyarrow_type, pandera_field.nullable) + + +def _get_index_name(level: int) -> str: + """Generate an index name for pyarrow if none is specified""" + return f"__index_level_{level}__" def to_pyarrow_schema( @@ -872,28 +881,37 @@ def to_pyarrow_schema( """ # List of columns that will be present in the pyarrow schema - columns: List[SeriesSchemaBase] = list(dataframe_schema.columns.values()) + columns: Dict[str, SeriesSchemaBase] = dataframe_schema.columns # type: ignore[assignment] # pyarrow schema metadata - metadata: Dict[str, Any] = {} + metadata: Dict[str, bytes] = {} index = dataframe_schema.index if index is None: if preserve_index: # Create column for RangeIndex - columns.append( - Index(dtypes.Int64, nullable=False, name="__index_level_0__") - ) + name = _get_index_name(0) + columns[name] = Index(dtypes.Int64, nullable=False, name=name) else: # Only preserve metadata of index - metadata["index_columns"] = [ + metadata[ + "index_columns" + ] = b"""[ {"kind": "range", "name": pyarrow.null, "step": 1} - ] + ]""" elif preserve_index is not False: # Add column(s) for index(es) if isinstance(index, Index): - columns.append(index) + name = index.name or _get_index_name(0) + # Ensure index is added at dictionary beginning + columns = {**{name: index}, **columns} + elif isinstance(index, MultiIndex): - columns += index.indexes + for i, value in enumerate(reversed(index.indexes)): + name = value.name or _get_index_name(i) + columns = {**{name: value}, **columns} - return pyarrow.Schema([to_pyarrow_field(c) for c in columns]) + return pyarrow.schema( + [to_pyarrow_field(k, v) for k, v in columns.items()], + metadata=metadata, + ) diff --git a/tests/io/test_io.py b/tests/io/test_io.py index 630957867..7c8f3449c 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -4,6 +4,7 @@ import tempfile from io import StringIO from pathlib import Path +from typing import Type from unittest import mock import pandas as pd @@ -1417,6 +1418,7 @@ def test_frictionless_schema_primary_key(frictionless_schema): (numpy_engine.String, pyarrow.string()), (pandas_engine.STRING, pyarrow.string()), (pandas_engine.NpString, pyarrow.string()), + (numpy_engine.Object, pyarrow.string()), (numpy_engine.Bytes, pyarrow.binary()), (dtypes.Date, pyarrow.date64()), (pandas_engine.Date, pyarrow.date64()), @@ -1436,13 +1438,137 @@ def test_frictionless_schema_primary_key(frictionless_schema): ], ) @pytest.mark.parametrize("nullable", [True, False]) -def test_to_pyarrow_field(pandera_dtype, nullable, expected_pyarrow_dtype): +def test_to_pyarrow_field( + pandera_dtype: Type[dtypes.DataType], + nullable: bool, + expected_pyarrow_dtype: pyarrow.DataType, +): """Test if pandera_dtype is correctly converted to pyarrow dtype""" name = "foo" pandera_field = Column(pandera_dtype, nullable=nullable, name=name) - pyarrow_dtype = io.to_pyarrow_field(pandera_field) + pyarrow_dtype = io.to_pyarrow_field(name, pandera_field) assert pyarrow_dtype.type == expected_pyarrow_dtype assert pyarrow_dtype.name == name assert pyarrow_dtype.nullable == nullable + + +@pytest.mark.skipif(SKIP_PYARROW_TESTS, reason="pyarrow required") +@pytest.mark.parametrize( + "dataframe_schema, preserve_index, expected", + [ + ( + _create_schema("single"), + True, + pyarrow.schema( + [ + pyarrow.field("__index_level_0__", pyarrow.int64(), False), + pyarrow.field("int_column", pyarrow.int64(), False), + pyarrow.field("float_column", pyarrow.float64(), False), + pyarrow.field("str_column", pyarrow.string(), False), + pyarrow.field( + "datetime_column", pyarrow.timestamp("ns"), False + ), + pyarrow.field( + "timedelta_column", pyarrow.duration("ns"), False + ), + pyarrow.field( + "optional_props_column", pyarrow.string(), True + ), + ] + ), + ), + ( + _create_schema(None), + None, + pyarrow.schema( + [ + pyarrow.field("int_column", pyarrow.int64(), False), + pyarrow.field("float_column", pyarrow.float64(), False), + pyarrow.field("str_column", pyarrow.string(), False), + pyarrow.field( + "datetime_column", pyarrow.timestamp("ns"), False + ), + pyarrow.field( + "timedelta_column", pyarrow.duration("ns"), False + ), + pyarrow.field( + "optional_props_column", pyarrow.string(), True + ), + ] + ), + ), + ( + _create_schema("multi"), + None, + pyarrow.schema( + [ + pyarrow.field("int_index0", pyarrow.int64(), False), + pyarrow.field("int_index1", pyarrow.int64(), False), + pyarrow.field("int_index2", pyarrow.int64(), False), + pyarrow.field("int_column", pyarrow.int64(), False), + pyarrow.field("float_column", pyarrow.float64(), False), + pyarrow.field("str_column", pyarrow.string(), False), + pyarrow.field( + "datetime_column", pyarrow.timestamp("ns"), False + ), + pyarrow.field( + "timedelta_column", pyarrow.duration("ns"), False + ), + pyarrow.field( + "optional_props_column", pyarrow.string(), True + ), + ] + ), + ), + ( + _create_schema("multi"), + False, + pyarrow.schema( + [ + pyarrow.field("int_column", pyarrow.int64(), False), + pyarrow.field("float_column", pyarrow.float64(), False), + pyarrow.field("str_column", pyarrow.string(), False), + pyarrow.field( + "datetime_column", pyarrow.timestamp("ns"), False + ), + pyarrow.field( + "timedelta_column", pyarrow.duration("ns"), False + ), + pyarrow.field( + "optional_props_column", pyarrow.string(), True + ), + ] + ), + ), + ( + _create_schema_python_types(), + None, + pyarrow.schema( + [ + pyarrow.field("int_column", pyarrow.int64(), False), + pyarrow.field("float_column", pyarrow.float64(), False), + pyarrow.field("str_column", pyarrow.string(), False), + pyarrow.field("object_column", pyarrow.string(), False), + ] + ), + ), + ], +) +def test_to_pyarrow_schema( + dataframe_schema: pandera.schemas.DataFrameSchema, + preserve_index: bool, + expected: pyarrow.Schema, +): + """Test if pandera schema is correctly converted to pyarrow.Schema""" + + # Drop column with no dtype specified + dataframe_schema.columns = { + k: v + for k, v in dataframe_schema.columns.items() + if k != "notype_column" + } + + pyarrow_schema = io.to_pyarrow_schema(dataframe_schema, preserve_index) + assert expected.equals(pyarrow_schema) From 559cdde4a5792d05b189bf5e79d5da1113f8f802 Mon Sep 17 00:00:00 2001 From: Matt Morris Date: Thu, 8 Dec 2022 16:20:46 -0600 Subject: [PATCH 3/4] style: make line more readable --- pandera/io.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pandera/io.py b/pandera/io.py index d2b108156..503567316 100644 --- a/pandera/io.py +++ b/pandera/io.py @@ -894,11 +894,8 @@ def to_pyarrow_schema( columns[name] = Index(dtypes.Int64, nullable=False, name=name) else: # Only preserve metadata of index - metadata[ - "index_columns" - ] = b"""[ - {"kind": "range", "name": pyarrow.null, "step": 1} - ]""" + meta_val = b'[{"kind": "range", "name": pyarrow.null, "step": 1}]' + metadata["index_columns"] = meta_val elif preserve_index is not False: # Add column(s) for index(es) if isinstance(index, Index): From a5bb4e4b33be51d5bff17bb66e575d7126b09be8 Mon Sep 17 00:00:00 2001 From: Matt Morris Date: Thu, 8 Dec 2022 16:30:51 -0600 Subject: [PATCH 4/4] style: make lines more readable --- tests/io/test_io.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/io/test_io.py b/tests/io/test_io.py index 7c8f3449c..24e92d1ba 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -1423,17 +1423,13 @@ def test_frictionless_schema_primary_key(frictionless_schema): (dtypes.Date, pyarrow.date64()), (pandas_engine.Date, pyarrow.date64()), (dtypes.Timestamp, pyarrow.timestamp("ns")), - (numpy_engine.DateTime64, pyarrow.date64()), # unbound + (numpy_engine.DateTime64, pyarrow.date64()), (pandas_engine.DateTime, pyarrow.timestamp("ns")), (dtypes.Timedelta, pyarrow.duration("ns")), (numpy_engine.Timedelta64, pyarrow.duration("ns")), ( dtypes.Category(categories=["foo", "bar", "baz"], ordered=True), - pyarrow.dictionary( - pyarrow.int8(), - pyarrow.string(), - ordered=True, - ), + pyarrow.dictionary(pyarrow.int8(), pyarrow.string(), ordered=True), ), ], )