handle deprecated methods/arguments in polars v1 (#1746)
* handle deprecated methods/arguments in polars v1

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* update ci

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* update ci

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

---------

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
cosmicBboy authored Jul 17, 2024
1 parent b86978b commit b8604b3
Showing 13 changed files with 143 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yml
@@ -183,7 +183,7 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]
pandas-version: ["2.2.2"]
pydantic-version: ["2.3.0"]
polars-version: ["0.20.31", "1.0.0"]
polars-version: ["0.20.31", "1.2.0"]
extra:
- hypotheses
- io
25 changes: 25 additions & 0 deletions pandera/api/polars/utils.py
@@ -1,15 +1,40 @@
# pylint: disable=cyclic-import
"""Polars validation engine utilities."""

from typing import Dict, List

import polars as pl

from pandera.api.polars.types import PolarsCheckObjects
from pandera.engines.polars_engine import polars_version
from pandera.config import (
ValidationDepth,
get_config_context,
get_config_global,
)


def get_lazyframe_schema(lf: pl.LazyFrame) -> Dict[str, pl.DataType]:
"""Get a dict of column names and dtypes from a polars LazyFrame."""
if polars_version().release >= (1, 0, 0):
return lf.collect_schema()
return lf.schema


def get_lazyframe_column_dtypes(lf: pl.LazyFrame) -> List[pl.DataType]:
"""Get a list of column dtypes from a polars LazyFrame."""
if polars_version().release >= (1, 0, 0):
return lf.collect_schema().dtypes()
return [*lf.schema.values()]


def get_lazyframe_column_names(lf: pl.LazyFrame) -> List[str]:
"""Get a list of column names from a polars LazyFrame."""
if polars_version().release >= (1, 0, 0):
return lf.collect_schema().names()
return lf.columns


def get_validation_depth(check_obj: PolarsCheckObjects) -> ValidationDepth:
"""Get validation depth for a given polars check object."""
is_dataframe = isinstance(check_obj, pl.DataFrame)
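For context, the three helpers above all follow the same version-dispatch pattern: polars >= 1.0 deprecates reading `schema`/`columns`/`dtypes` off a `LazyFrame` in favor of `collect_schema()`. A minimal standalone sketch of that pattern (simplified, not the pandera source verbatim; the `POLARS_VERSION` constant and `dict(...)` wrapping are illustrative):

```python
from typing import Dict

import polars as pl
from packaging import version

POLARS_VERSION = version.parse(pl.__version__)  # illustrative constant


def get_lazyframe_schema(lf: pl.LazyFrame) -> Dict[str, pl.DataType]:
    """Return {column name: dtype} for a LazyFrame on old and new polars."""
    if POLARS_VERSION.release >= (1, 0, 0):
        # polars >= 1.0: resolve names and dtypes via collect_schema()
        return dict(lf.collect_schema())
    # older polars: the .schema property is still the supported API
    return dict(lf.schema)


lf = pl.LazyFrame({"a": [1, 2], "b": ["x", "y"]})
print(get_lazyframe_schema(lf))  # {'a': Int64, 'b': String}
```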
18 changes: 13 additions & 5 deletions pandera/backends/polars/base.py
@@ -8,6 +8,7 @@

from pandera.api.base.error_handler import ErrorHandler
from pandera.api.polars.types import CheckResult
from pandera.api.polars.utils import get_lazyframe_column_dtypes
from pandera.backends.base import BaseSchemaBackend, CoreCheckResult
from pandera.constants import CHECK_OUTPUT_KEY
from pandera.errors import (
@@ -21,8 +22,10 @@
def is_float_dtype(check_obj: pl.LazyFrame, selector):
"""Check if a column/selector is a float."""
return all(
dtype in pl.FLOAT_DTYPES
for dtype in check_obj.select(pl.col(selector)).schema.values()
dtype in {pl.Float32, pl.Float64}
for dtype in get_lazyframe_column_dtypes(
check_obj.select(pl.col(selector))
)
)


@@ -155,9 +158,14 @@ def failure_cases_metadata(
failure_cases_df = err.failure_cases

# get row number of the failure cases
index = err.check_output.with_row_count("index").filter(
pl.col(CHECK_OUTPUT_KEY).eq(False)
)["index"]
if hasattr(err.check_output, "with_row_index"):
_index_lf = err.check_output.with_row_index("index")
else:
_index_lf = err.check_output.with_row_count("index")

index = _index_lf.filter(pl.col(CHECK_OUTPUT_KEY).eq(False))[
"index"
]
if len(err.failure_cases.columns) > 1:
# for boolean dataframe check results, reduce failure cases
# to a struct column
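The row-numbering change above uses feature detection rather than a version check, since `with_row_count` was renamed to `with_row_index` in newer polars. A hedged sketch of the same shim in isolation (frame contents are made up):

```python
import polars as pl

check_output = pl.DataFrame({"check_output": [True, False, True]})

# newer polars (including 1.x) has with_row_index; older versions only
# have with_row_count, which now emits a deprecation warning
if hasattr(check_output, "with_row_index"):
    indexed = check_output.with_row_index("index")
else:
    indexed = check_output.with_row_count("index")

# row numbers of the failing checks, mirroring failure_cases_metadata above
failing = indexed.filter(pl.col("check_output").eq(False))["index"]
print(failing.to_list())  # [1]
```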
18 changes: 13 additions & 5 deletions pandera/backends/polars/checks.py
@@ -10,6 +10,10 @@
from pandera.api.base.checks import CheckResult
from pandera.api.checks import Check
from pandera.api.polars.types import PolarsData
from pandera.api.polars.utils import (
get_lazyframe_schema,
get_lazyframe_column_names,
)
from pandera.backends.base import BaseCheckBackend
from pandera.constants import CHECK_OUTPUT_KEY

@@ -55,7 +59,7 @@ def apply(self, check_obj: PolarsData):
if isinstance(out, bool):
return out

if len(out.columns) > 1:
if len(get_lazyframe_schema(out)) > 1:
# for checks that return a boolean dataframe, reduce to a single
# boolean column.
out = out.select(
@@ -66,7 +70,11 @@
).alias(CHECK_OUTPUT_KEY)
)
else:
out = out.select(pl.col(out.columns[0]).alias(CHECK_OUTPUT_KEY))
out = out.select(
pl.col(get_lazyframe_column_names(out)[0]).alias(
CHECK_OUTPUT_KEY
)
)

return out

@@ -86,9 +94,9 @@ def postprocess(
"""Postprocesses the result of applying the check function."""
results = pl.LazyFrame(check_output.collect())
passed = results.select([pl.col(CHECK_OUTPUT_KEY).all()])
failure_cases = check_obj.lazyframe.with_context(results).filter(
pl.col(CHECK_OUTPUT_KEY).not_()
)
failure_cases = pl.concat(
[check_obj.lazyframe, results], how="horizontal"
).filter(pl.col(CHECK_OUTPUT_KEY).not_())

if check_obj.key is not None:
failure_cases = failure_cases.select(check_obj.key)
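The `postprocess` change replaces the deprecated `LazyFrame.with_context` join of check results with a horizontal concat of the two lazy frames. A small self-contained sketch of the new pattern (column names and data are illustrative):

```python
import polars as pl

data = pl.LazyFrame({"a": [1, -2, 3]})
results = pl.LazyFrame({"check_output": [True, False, True]})

# pre-1.0 equivalent (now deprecated):
#   data.with_context(results).filter(pl.col("check_output").not_())
failure_cases = pl.concat([data, results], how="horizontal").filter(
    pl.col("check_output").not_()
)
print(failure_cases.collect())  # the single row where a == -2
```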
13 changes: 9 additions & 4 deletions pandera/backends/polars/components.py
@@ -8,6 +8,10 @@
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.polars.components import Column
from pandera.api.polars.types import PolarsData
from pandera.api.polars.utils import (
get_lazyframe_schema,
get_lazyframe_column_names,
)
from pandera.backends.base import CoreCheckResult
from pandera.backends.polars.base import PolarsSchemaBackend, is_float_dtype
from pandera.config import ValidationDepth, ValidationScope, get_config_context
@@ -100,7 +104,7 @@ def validate(
return check_obj

def get_regex_columns(self, schema, check_obj) -> Iterable:
return check_obj.select(pl.col(schema.selector)).columns
return get_lazyframe_schema(check_obj.select(pl.col(schema.selector)))

def run_checks_and_handle_errors(
self,
@@ -214,7 +218,7 @@ def check_nullable(
isna = check_obj.select(expr)
passed = isna.select([pl.col("*").all()]).collect()
results = []
for column in isna.columns:
for column in get_lazyframe_column_names(isna):
if passed.select(column).item():
continue
failure_cases = (
@@ -326,8 +330,9 @@ def check_dtype(

results = []
check_obj_subset = check_obj.select(schema.selector)
for column in check_obj_subset.columns:
obj_dtype = check_obj_subset.schema[column]
for column, obj_dtype in get_lazyframe_schema(
check_obj_subset
).items():
results.append(
CoreCheckResult(
passed=schema.dtype.check(
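The `check_dtype` loop above now pulls name/dtype pairs from the schema mapping in one pass instead of reading `LazyFrame.columns` and then `.schema[column]` per column. A minimal sketch of the polars >= 1.0 path (pre-1.0 code falls back to the old properties via the helper):

```python
import polars as pl

check_obj_subset = pl.LazyFrame({"a": [1.0], "b": ["x"]})

# polars >= 1.0: collect_schema() resolves names and dtypes together
for column, obj_dtype in check_obj_subset.collect_schema().items():
    print(column, obj_dtype)  # a Float64, then b String
```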
13 changes: 8 additions & 5 deletions pandera/backends/polars/container.py
@@ -10,6 +10,7 @@
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.polars.container import DataFrameSchema
from pandera.api.polars.types import PolarsData
from pandera.api.polars.utils import get_lazyframe_column_names
from pandera.backends.base import ColumnInfo, CoreCheckResult
from pandera.backends.polars.base import PolarsSchemaBackend
from pandera.config import ValidationDepth, ValidationScope, get_config_context
@@ -211,7 +212,7 @@ def collect_column_info(self, check_obj: pl.LazyFrame, schema):
for col_name, col_schema in schema.columns.items():
if (
not col_schema.regex
and col_name not in check_obj.columns
and col_name not in get_lazyframe_column_names(check_obj)
and col_schema.required
):
absent_column_names.append(col_name)
@@ -226,11 +227,11 @@ def collect_column_info(self, check_obj: pl.LazyFrame, schema):
regex_match_patterns.append(col_schema.selector)
except SchemaError:
pass
elif col_name in check_obj.columns:
elif col_name in get_lazyframe_column_names(check_obj):
column_names.append(col_name)

# drop adjacent duplicated column names
destuttered_column_names = [*check_obj.columns]
destuttered_column_names = [*get_lazyframe_column_names(check_obj)]

return ColumnInfo(
sorted_column_names=dict.fromkeys(column_names),
@@ -256,7 +257,7 @@ def collect_schema_components(
from pandera.api.polars.components import Column

columns = {}
for col in check_obj.columns:
for col in get_lazyframe_column_names(check_obj):
columns[col] = Column(schema.dtype, name=str(col))

schema_components = []
@@ -579,7 +580,9 @@ def check_column_values_are_unique(
)

for lst in temp_unique:
subset = [x for x in lst if x in check_obj.columns]
subset = [
x for x in lst if x in get_lazyframe_column_names(check_obj)
]
duplicates = check_obj.select(subset).collect().is_duplicated()
if duplicates.any():
failure_cases = check_obj.filter(duplicates)
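The container backend's column-presence checks now route through `get_lazyframe_column_names` for the same reason: on polars 1.x, names are taken from `collect_schema().names()` rather than the `LazyFrame.columns` property. A rough sketch of the membership check (the `hasattr` fallback condition here is illustrative; the real helper dispatches on `polars_version()`):

```python
import polars as pl

check_obj = pl.LazyFrame({"a": [1], "b": [2]})

if hasattr(check_obj, "collect_schema"):
    names = check_obj.collect_schema().names()  # polars >= 1.0
else:
    names = check_obj.columns  # older polars

required = ["a", "c"]
absent = [col for col in required if col not in names]
print(absent)  # ['c'] -> reported as a missing required column
```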
30 changes: 23 additions & 7 deletions pandera/engines/polars_engine.py
@@ -7,8 +7,10 @@
import warnings
from typing import (
Any,
Dict,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
@@ -19,7 +21,6 @@
import polars as pl
from packaging import version
from polars.datatypes import DataTypeClass
from polars.type_aliases import SchemaDict

from pandera import dtypes, errors
from pandera.api.polars.types import PolarsData
@@ -37,6 +38,9 @@
)


SchemaDict = Mapping[str, PolarsDataType]


def polars_version() -> version.Version:
"""Return the polars version."""

@@ -167,6 +171,8 @@ def try_coerce(self, data_container: PolarsDataContainer) -> pl.LazyFrame:
raises a :class:`~pandera.errors.ParserError` if the coercion fails
:raises: :class:`~pandera.errors.ParserError`: if coercion fails
"""
from pandera.api.polars.utils import get_lazyframe_schema

if isinstance(data_container, pl.LazyFrame):
data_container = PolarsData(data_container)

@@ -187,7 +193,7 @@ def try_coerce(self, data_container: PolarsDataContainer) -> pl.LazyFrame:
failure_cases = failure_cases.select(data_container.key)
raise errors.ParserError(
f"Could not coerce {_key} LazyFrame with schema "
f"{data_container.lazyframe.schema} "
f"{get_lazyframe_schema(data_container.lazyframe)} "
f"into type {self.type}",
failure_cases=failure_cases,
parser_output=is_coercible,
@@ -561,16 +567,26 @@ class Array(DataType):
def __init__( # pylint:disable=super-init-not-called
self,
inner: Optional[PolarsDataType] = None,
shape: Union[int, Tuple[int, ...], None] = None,
*,
width: Optional[int] = None,
) -> None:
if inner or width:
object.__setattr__(
self, "type", pl.Array(inner=inner, width=width)
)

kwargs: Dict[str, Union[int, Tuple[int, ...]]] = {}
if width is not None:
kwargs["shape"] = width
elif shape is not None:
kwargs["shape"] = shape

if inner or shape or width:
object.__setattr__(self, "type", pl.Array(inner=inner, **kwargs))

@classmethod
def from_parametrized_dtype(cls, polars_dtype: pl.Array):
return cls(inner=polars_dtype.inner, width=polars_dtype.width)
return cls(
inner=polars_dtype.inner,
shape=polars_dtype.shape,
)


@Engine.register_dtype(equivalents=[pl.List])
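Two polars 1.x changes are handled in the engine above: `polars.type_aliases.SchemaDict` is no longer importable (hence the local `Mapping[str, PolarsDataType]` alias), and `pl.Array` takes `shape` instead of `width`. A hedged sketch of the `Array` keyword handling (the version check is illustrative; the diff itself keeps `width` as a keyword-only argument for backwards compatibility):

```python
import polars as pl
from packaging import version

width = 3  # e.g. from a pandera Array(inner=pl.Int64, width=3) annotation

if version.parse(pl.__version__).release >= (1, 0, 0):
    dtype = pl.Array(pl.Int64, shape=width)  # `shape` replaces `width`
else:
    dtype = pl.Array(pl.Int64, width=width)  # pre-1.0 keyword

print(dtype)  # Array(Int64, shape=(3,)) on polars 1.x
```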
3 changes: 2 additions & 1 deletion tests/polars/test_polars_check.py
@@ -4,6 +4,7 @@
import pytest

import pandera.polars as pa
from pandera.api.polars.utils import get_lazyframe_schema
from pandera.constants import CHECK_OUTPUT_KEY


@@ -145,7 +146,7 @@ def test_polars_element_wise_dataframe_check(lf):
validated_data = schema.validate(lf)
assert validated_data.collect().equals(lf.collect())

for col in lf.columns:
for col in get_lazyframe_schema(lf):
invalid_lf = lf.with_columns(**{col: pl.Series([-1, 2, -4, 3])})
try:
schema.validate(invalid_lf)
5 changes: 3 additions & 2 deletions tests/polars/test_polars_components.py
@@ -6,6 +6,7 @@
import pytest

import pandera.polars as pa
from pandera.api.polars.utils import get_lazyframe_schema
from pandera.backends.base import CoreCheckResult
from pandera.backends.polars.components import ColumnBackend
from pandera.dtypes import DataType
@@ -83,7 +84,7 @@ def test_get_regex_columns(kwargs):
backend = ColumnBackend()
data = pl.DataFrame({f"col_{i}": [1, 2, 3] for i in range(10)}).lazy()
matched_columns = backend.get_regex_columns(column_schema, data)
assert matched_columns == data.columns
assert matched_columns == get_lazyframe_schema(data)

no_match_data = data.rename(
lambda c: c.replace(
Expand All @@ -92,7 +93,7 @@ def test_get_regex_columns(kwargs):
)
)
matched_columns = backend.get_regex_columns(column_schema, no_match_data)
assert matched_columns == []
assert len(matched_columns) == 0


@pytest.mark.parametrize(
7 changes: 4 additions & 3 deletions tests/polars/test_polars_config.py
@@ -6,6 +6,7 @@

import pandera.polars as pa
from pandera.engines.polars_engine import polars_version
from pandera.api.polars.utils import get_lazyframe_schema
from pandera.api.base.error_handler import ErrorCategory
from pandera.config import (
CONFIG,
@@ -157,12 +158,12 @@ def test_coerce_validation_depth_none(validation_depth_none, schema):
# simply calling validation shouldn't raise a coercion error, since we're
# casting the types lazily
validated_data = schema.validate(data)
assert validated_data.schema["a"] == pl.Int64
assert get_lazyframe_schema(validated_data)["a"] == pl.Int64

ErrorCls = (
pl.InvalidOperationError
pl.exceptions.InvalidOperationError
if polars_version().release >= (1, 0, 0)
else pl.ComputeError
else pl.exceptions.ComputeError
)
with pytest.raises(ErrorCls):
validated_data.collect()
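The test change above selects the expected exception class by polars version and references both classes through the `polars.exceptions` namespace rather than the top-level aliases. A hedged sketch of the same idea, assuming a strict string-to-int lazy cast as the failing operation:

```python
import polars as pl
import pytest
from packaging import version

IS_POLARS_V1 = version.parse(pl.__version__).release >= (1, 0, 0)

# polars 1.x surfaces the failed cast as InvalidOperationError, where
# older versions raised ComputeError in this context
ErrorCls = (
    pl.exceptions.InvalidOperationError
    if IS_POLARS_V1
    else pl.exceptions.ComputeError
)

# assumed failing operation: strict cast of a non-numeric string column
lf = pl.LazyFrame({"a": ["not-a-number"]}).cast({"a": pl.Int64})

with pytest.raises(ErrorCls):
    lf.collect()  # the lazy cast only fails at collection time
```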