feat: add pandera.io.to_pyarrow_schema #1047

Open · wants to merge 4 commits into main
118 changes: 112 additions & 6 deletions pandera/io.py
@@ -7,25 +7,27 @@
from pathlib import Path
from typing import Dict, Optional, Union

+import numpy as np
import pandas as pd

import pandera.errors

from . import dtypes
from .checks import Check
-from .engines import pandas_engine
-from .schema_components import Column
+from .engines import numpy_engine, pandas_engine
+from .schema_components import Column, Index, MultiIndex, SeriesSchemaBase
from .schema_statistics import get_dataframe_schema_statistics
from .schemas import DataFrameSchema

try:
    import black
+    import pyarrow
    import yaml
    from frictionless import Schema as FrictionlessSchema
except ImportError as exc:  # pragma: no cover
    raise ImportError(
-        "IO and formatting requires 'pyyaml', 'black' and 'frictionless'"
-        "to be installed.\n"
+        "IO and formatting requires 'pyyaml', 'black', 'frictionless' and "
+        "'pyarrow' to be installed.\n"
        "You can install pandera together with the IO dependencies with:\n"
        "pip install pandera[io]\n"
    ) from exc
@@ -246,8 +248,6 @@ def deserialize_schema(serialized_schema):
    :returns:
        the schema de-serialized into :class:`~pandera.schemas.DataFrameSchema`
    """
-    # pylint: disable=import-outside-toplevel
-    from pandera import Index, MultiIndex

    # GH#475
    serialized_schema = serialized_schema if serialized_schema else {}
@@ -806,3 +806,109 @@ def from_frictionless_schema(
        ),
    }
    return deserialize_schema(assembled_schema)


def to_pyarrow_field(
    name: str,
    pandera_field: SeriesSchemaBase,
) -> pyarrow.Field:
    """
    Convert a :class:`~pandera.schema_components.SeriesSchemaBase` to a
    ``pyarrow.Field``.

    :param name: name of the resulting field
    :param pandera_field: pandera Index or Column
    :returns: ``pyarrow.Field`` representation of ``pandera_field``
    """

    pandera_dtype = pandera_field.dtype
    pandas_dtype = pandas_engine.Engine.dtype(pandera_dtype).type

    pandas_types = {
        pd.BooleanDtype(): pyarrow.bool_(),
        pd.Int8Dtype(): pyarrow.int8(),
        pd.Int16Dtype(): pyarrow.int16(),
        pd.Int32Dtype(): pyarrow.int32(),
        pd.Int64Dtype(): pyarrow.int64(),
        pd.UInt8Dtype(): pyarrow.uint8(),
        pd.UInt16Dtype(): pyarrow.uint16(),
        pd.UInt32Dtype(): pyarrow.uint32(),
        pd.UInt64Dtype(): pyarrow.uint64(),
        pd.Float32Dtype(): pyarrow.float32(),  # type: ignore[attr-defined]
        pd.Float64Dtype(): pyarrow.float64(),  # type: ignore[attr-defined]
        pd.StringDtype(): pyarrow.string(),
    }

    if pandas_dtype in pandas_types:
        pyarrow_type = pandas_types[pandas_dtype]
    elif isinstance(
        pandera_dtype, (pandas_engine.Date, numpy_engine.DateTime64)
    ):
        pyarrow_type = pyarrow.date64()
    elif isinstance(pandera_dtype, dtypes.Category):
        # Categorical data: dictionary-encode the category values
        pyarrow_type = pyarrow.dictionary(
            pyarrow.int8(),
            pandera_dtype.type.categories.inferred_type,
            ordered=pandera_dtype.ordered,  # type: ignore[attr-defined]
        )
    elif pandas_dtype.type == np.object_:
        # Fall back to string for generic object columns
        pyarrow_type = pyarrow.string()
    else:
        pyarrow_type = pyarrow.from_numpy_dtype(pandas_dtype)

    return pyarrow.field(name, pyarrow_type, pandera_field.nullable)
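
# Example (illustrative names): a numpy-backed integer column takes the
# ``pyarrow.from_numpy_dtype`` fallback above and yields a nullable int64
# field:
#
#   field = to_pyarrow_field("x", Column(dtypes.Int64, nullable=True))
#   assert field.type == pyarrow.int64()
#   assert field.nullable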


def _get_index_name(level: int) -> str:
    """Generate an index name for pyarrow if none is specified."""
    return f"__index_level_{level}__"


def to_pyarrow_schema(
    dataframe_schema: DataFrameSchema,
    preserve_index: Optional[bool] = None,
) -> pyarrow.Schema:
    """
    Convert a :class:`~pandera.schemas.DataFrameSchema` to a
    ``pyarrow.Schema``.

    :param dataframe_schema: schema to convert to ``pyarrow.Schema``
    :param preserve_index: whether to include the index as an additional
        column (or columns, for a MultiIndex) in the resulting schema. The
        default of None includes the index as a column, except for a
        RangeIndex, which is recorded as metadata only. Use
        ``preserve_index=True`` to force the index to be stored as a column.
    :returns: ``pyarrow.Schema`` representation of ``dataframe_schema``
    """

    # Columns that will be present in the pyarrow schema. Copy the dict so
    # that adding index columns below does not mutate the caller's schema.
    columns: Dict[str, SeriesSchemaBase] = dict(dataframe_schema.columns)

    # pyarrow schema metadata
    metadata: Dict[str, bytes] = {}

    index = dataframe_schema.index
    if index is None:
        if preserve_index:
            # Create a column for the default RangeIndex
            name = _get_index_name(0)
            columns[name] = Index(dtypes.Int64, nullable=False, name=name)
        else:
            # Only preserve the RangeIndex as schema metadata
            meta_val = b'[{"kind": "range", "name": null, "step": 1}]'
            metadata["index_columns"] = meta_val
    elif preserve_index is not False:
        # Add column(s) for the index(es) at the beginning of the dictionary
        if isinstance(index, Index):
            name = index.name or _get_index_name(0)
            columns = {**{name: index}, **columns}
        elif isinstance(index, MultiIndex):
            # Prepend levels in reverse so they end up in their original
            # order, while numbering unnamed levels by original position
            for i, value in reversed(list(enumerate(index.indexes))):
                name = value.name or _get_index_name(i)
                columns = {**{name: value}, **columns}

    return pyarrow.schema(
        [to_pyarrow_field(k, v) for k, v in columns.items()],
        metadata=metadata,
    )
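
A minimal usage sketch of the new API (schema and names are illustrative,
assuming this branch is installed): with the default ``preserve_index=None``,
a named index becomes the leading field of the resulting ``pyarrow.Schema``,
while ``preserve_index=False`` drops it.

import pandera as pa
from pandera import io

schema = pa.DataFrameSchema(
    columns={"price": pa.Column(float, nullable=True)},
    index=pa.Index(int, name="id"),
)

# Index column first, then the data columns
assert io.to_pyarrow_schema(schema).names == ["id", "price"]

# Opting out of index preservation drops the index column
assert io.to_pyarrow_schema(schema, preserve_index=False).names == ["price"]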
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
_extras_require = {
    "strategies": ["hypothesis >= 5.41.1"],
    "hypotheses": ["scipy"],
-    "io": ["pyyaml >= 5.1", "black", "frictionless"],
+    "io": ["pyyaml >= 5.1", "black", "frictionless", "pyarrow"],
    "pyspark": ["pyspark >= 3.2.0"],
    "modin": ["modin", "ray <= 1.7.0", "dask"],
    "modin-ray": ["modin", "ray <= 1.7.0"],
208 changes: 207 additions & 1 deletion tests/io/test_io.py
@@ -4,6 +4,7 @@
import tempfile
from io import StringIO
from pathlib import Path
+from typing import Optional, Type
from unittest import mock

import pandas as pd
@@ -13,7 +14,9 @@
import pandera
import pandera.extensions as pa_ext
import pandera.typing as pat
-from pandera.engines import pandas_engine
+from pandera import dtypes
+from pandera.engines import numpy_engine, pandas_engine
+from pandera.schema_components import Column

try:
    from pandera import io
@@ -34,6 +37,14 @@
SKIP_YAML_TESTS = PYYAML_VERSION is None or PYYAML_VERSION.release < (5, 1, 0)  # type: ignore


+try:
+    import pyarrow
+except ImportError:
+    SKIP_PYARROW_TESTS = True
+else:
+    SKIP_PYARROW_TESTS = False


# skip all tests in module if "io" dependencies aren't installed
pytestmark = pytest.mark.skipif(
    not HAS_IO, reason='needs "io" module dependencies'
@@ -1362,3 +1373,198 @@ def test_frictionless_schema_primary_key(frictionless_schema):
    assert schema.unique == frictionless_schema["primaryKey"]
    for key in frictionless_schema["primaryKey"]:
        assert not schema.columns[key].unique


@pytest.mark.skipif(SKIP_PYARROW_TESTS, reason="pyarrow required")
@pytest.mark.parametrize(
    "pandera_dtype, expected_pyarrow_dtype",
    [
        (dtypes.Bool, pyarrow.bool_()),
        (numpy_engine.Bool, pyarrow.bool_()),
        (pandas_engine.BOOL, pyarrow.bool_()),
        (dtypes.Int8, pyarrow.int8()),
        (dtypes.Int16, pyarrow.int16()),
        (dtypes.Int32, pyarrow.int32()),
        (dtypes.Int64, pyarrow.int64()),
        (numpy_engine.Int8, pyarrow.int8()),
        (numpy_engine.Int16, pyarrow.int16()),
        (numpy_engine.Int32, pyarrow.int32()),
        (numpy_engine.Int64, pyarrow.int64()),
        (pandas_engine.INT8, pyarrow.int8()),
        (pandas_engine.INT16, pyarrow.int16()),
        (pandas_engine.INT32, pyarrow.int32()),
        (pandas_engine.INT64, pyarrow.int64()),
        (dtypes.UInt8, pyarrow.uint8()),
        (dtypes.UInt16, pyarrow.uint16()),
        (dtypes.UInt32, pyarrow.uint32()),
        (dtypes.UInt64, pyarrow.uint64()),
        (numpy_engine.UInt8, pyarrow.uint8()),
        (numpy_engine.UInt16, pyarrow.uint16()),
        (numpy_engine.UInt32, pyarrow.uint32()),
        (numpy_engine.UInt64, pyarrow.uint64()),
        (pandas_engine.UINT8, pyarrow.uint8()),
        (pandas_engine.UINT16, pyarrow.uint16()),
        (pandas_engine.UINT32, pyarrow.uint32()),
        (pandas_engine.UINT64, pyarrow.uint64()),
        (dtypes.Float16, pyarrow.float16()),
        (dtypes.Float32, pyarrow.float32()),
        (dtypes.Float64, pyarrow.float64()),
        (numpy_engine.Float16, pyarrow.float16()),
        (numpy_engine.Float32, pyarrow.float32()),
        (numpy_engine.Float64, pyarrow.float64()),
        (pandas_engine.FLOAT32, pyarrow.float32()),
        (pandas_engine.FLOAT64, pyarrow.float64()),
        (dtypes.String, pyarrow.string()),
        (numpy_engine.String, pyarrow.string()),
        (pandas_engine.STRING, pyarrow.string()),
        (pandas_engine.NpString, pyarrow.string()),
        (numpy_engine.Object, pyarrow.string()),
        (numpy_engine.Bytes, pyarrow.binary()),
        (dtypes.Date, pyarrow.date64()),
        (pandas_engine.Date, pyarrow.date64()),
        (dtypes.Timestamp, pyarrow.timestamp("ns")),
        (numpy_engine.DateTime64, pyarrow.date64()),
        (pandas_engine.DateTime, pyarrow.timestamp("ns")),
        (dtypes.Timedelta, pyarrow.duration("ns")),
        (numpy_engine.Timedelta64, pyarrow.duration("ns")),
        (
            dtypes.Category(categories=["foo", "bar", "baz"], ordered=True),
            pyarrow.dictionary(pyarrow.int8(), pyarrow.string(), ordered=True),
        ),
    ],
)
@pytest.mark.parametrize("nullable", [True, False])
def test_to_pyarrow_field(
    pandera_dtype: Type[dtypes.DataType],
    nullable: bool,
    expected_pyarrow_dtype: pyarrow.DataType,
):
    """Test that ``pandera_dtype`` is converted to the expected pyarrow dtype."""
    name = "foo"

    pandera_field = Column(pandera_dtype, nullable=nullable, name=name)
    pyarrow_field = io.to_pyarrow_field(name, pandera_field)

    assert pyarrow_field.type == expected_pyarrow_dtype
    assert pyarrow_field.name == name
    assert pyarrow_field.nullable == nullable


@pytest.mark.skipif(SKIP_PYARROW_TESTS, reason="pyarrow required")
@pytest.mark.parametrize(
    "dataframe_schema, preserve_index, expected",
    [
        (
            _create_schema("single"),
            True,
            pyarrow.schema(
                [
                    pyarrow.field("__index_level_0__", pyarrow.int64(), False),
                    pyarrow.field("int_column", pyarrow.int64(), False),
                    pyarrow.field("float_column", pyarrow.float64(), False),
                    pyarrow.field("str_column", pyarrow.string(), False),
                    pyarrow.field(
                        "datetime_column", pyarrow.timestamp("ns"), False
                    ),
                    pyarrow.field(
                        "timedelta_column", pyarrow.duration("ns"), False
                    ),
                    pyarrow.field(
                        "optional_props_column", pyarrow.string(), True
                    ),
                ]
            ),
        ),
        (
            _create_schema(None),
            None,
            pyarrow.schema(
                [
                    pyarrow.field("int_column", pyarrow.int64(), False),
                    pyarrow.field("float_column", pyarrow.float64(), False),
                    pyarrow.field("str_column", pyarrow.string(), False),
                    pyarrow.field(
                        "datetime_column", pyarrow.timestamp("ns"), False
                    ),
                    pyarrow.field(
                        "timedelta_column", pyarrow.duration("ns"), False
                    ),
                    pyarrow.field(
                        "optional_props_column", pyarrow.string(), True
                    ),
                ]
            ),
        ),
        (
            _create_schema("multi"),
            None,
            pyarrow.schema(
                [
                    pyarrow.field("int_index0", pyarrow.int64(), False),
                    pyarrow.field("int_index1", pyarrow.int64(), False),
                    pyarrow.field("int_index2", pyarrow.int64(), False),
                    pyarrow.field("int_column", pyarrow.int64(), False),
                    pyarrow.field("float_column", pyarrow.float64(), False),
                    pyarrow.field("str_column", pyarrow.string(), False),
                    pyarrow.field(
                        "datetime_column", pyarrow.timestamp("ns"), False
                    ),
                    pyarrow.field(
                        "timedelta_column", pyarrow.duration("ns"), False
                    ),
                    pyarrow.field(
                        "optional_props_column", pyarrow.string(), True
                    ),
                ]
            ),
        ),
        (
            _create_schema("multi"),
            False,
            pyarrow.schema(
                [
                    pyarrow.field("int_column", pyarrow.int64(), False),
                    pyarrow.field("float_column", pyarrow.float64(), False),
                    pyarrow.field("str_column", pyarrow.string(), False),
                    pyarrow.field(
                        "datetime_column", pyarrow.timestamp("ns"), False
                    ),
                    pyarrow.field(
                        "timedelta_column", pyarrow.duration("ns"), False
                    ),
                    pyarrow.field(
                        "optional_props_column", pyarrow.string(), True
                    ),
                ]
            ),
        ),
        (
            _create_schema_python_types(),
            None,
            pyarrow.schema(
                [
                    pyarrow.field("int_column", pyarrow.int64(), False),
                    pyarrow.field("float_column", pyarrow.float64(), False),
                    pyarrow.field("str_column", pyarrow.string(), False),
                    pyarrow.field("object_column", pyarrow.string(), False),
                ]
            ),
        ),
    ],
)
def test_to_pyarrow_schema(
    dataframe_schema: pandera.schemas.DataFrameSchema,
    preserve_index: Optional[bool],
    expected: pyarrow.Schema,
):
    """Test that a pandera schema is correctly converted to a ``pyarrow.Schema``."""

    # Drop the column with no dtype specified
    dataframe_schema.columns = {
        k: v
        for k, v in dataframe_schema.columns.items()
        if k != "notype_column"
    }

    pyarrow_schema = io.to_pyarrow_schema(dataframe_schema, preserve_index)
    assert expected.equals(pyarrow_schema)
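
One way the converted schema could be used downstream (a sketch under the
assumption that this branch is installed; ``pyarrow.Table.from_pandas``
accepts an explicit ``schema``): type an Arrow table by the pandera schema
instead of letting pyarrow infer it from the data.

import pandas as pd
import pyarrow

import pandera as pa
from pandera import io

schema = pa.DataFrameSchema({"x": pa.Column(int), "y": pa.Column(str)})
df = schema.validate(pd.DataFrame({"x": [1, 2], "y": ["a", "b"]}))

table = pyarrow.Table.from_pandas(
    df,
    schema=io.to_pyarrow_schema(schema, preserve_index=False),
    preserve_index=False,
)
assert table.schema.names == ["x", "y"]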