Skip to content

Commit

Permalink
Implement minimal built-in checks for Ibis backend
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
  • Loading branch information
deepyaman committed Dec 26, 2024
1 parent d846e0e commit 3d0efc6
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 4 deletions.
7 changes: 7 additions & 0 deletions pandera/api/ibis/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandera.api.ibis.types import IbisDtypeInputTypes
from pandera.backends.ibis.register import register_ibis_backends
from pandera.engines import ibis_engine
from pandera.utils import is_regex


class Column(ComponentSchema[ir.Table]):
Expand Down Expand Up @@ -109,6 +110,12 @@ def dtype(self):
def dtype(self, value) -> None:
self._dtype = ibis_engine.Engine.dtype(value) if value else None

@property
def selector(self):
if self.name is not None and not is_regex(self.name) and self.regex:
return f"^{self.name}$"
return self.name

def set_name(self, name: str):
"""Set or modify the name of a column object.
Expand Down
26 changes: 26 additions & 0 deletions pandera/backends/ibis/builtin_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Built-in checks for Ibis."""

from typing import Any, TypeVar

import ibis
import ibis.expr.types as ir

from pandera.api.extensions import register_builtin_check
from pandera.api.ibis.types import IbisData

T = TypeVar("T")


@register_builtin_check(
aliases=["eq"],
error="equal_to({value})",
)
def equal_to(data: IbisData, value: Any) -> ir.Table:
"""Ensure all elements of a data container equal a certain value.
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param value: values in this polars data structure must be
equal to this value.
"""
return data.table[data.key] == value
1 change: 1 addition & 0 deletions pandera/backends/ibis/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ def register_ibis_backends():
DataFrameSchema.register_backend(ir.Table, DataFrameSchemaBackend)
Column.register_backend(ir.Table, ColumnBackend)
Check.register_backend(ir.Table, IbisCheckBackend)
Check.register_backend(ir.Column, IbisCheckBackend)
2 changes: 1 addition & 1 deletion pandera/backends/pandas/builtin_checks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Pandas implementation of built-in checks"""
"""Built-in checks for pandas."""

import operator
import re
Expand Down
2 changes: 1 addition & 1 deletion pandera/backends/pyspark/builtin_checks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""PySpark implementation of built-in checks"""
"""Built-in checks for PySpark."""

from typing import Any, Iterable, TypeVar

Expand Down
173 changes: 173 additions & 0 deletions tests/ibis/test_ibis_builtin_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Unit tests for Ibis checks."""

import ibis
import ibis.expr.datatypes as dt
import pytest

import pandera.ibis as pa
from pandera.ibis import Column, DataFrameSchema
from pandera.backends.ibis.register import register_ibis_backends


@pytest.fixture(autouse=True)
def _register_ibis_backends():
register_ibis_backends()


class BaseClass:
"""This is the base class for the all the test cases class"""

def __int__(self, params=None):
pass

@staticmethod
def convert_value(sample_data, conversion_datatype):
"""
Convert the sample data to other formats excluding dates and does not
support complex datatypes such as array and map as of now
"""

data_dict = {}
for key, value in sample_data.items():
if key == "test_expression":
if not isinstance(value, list):
data_dict[key] = conversion_datatype(value)
else:
data_dict[key] = [conversion_datatype(i) for i in value]

else:
if not isinstance(value[0][1], list):
data_dict[key] = [
(i[0], conversion_datatype(i[1])) for i in value
]
else:
final_val = []
for row in value:
data_val = []
for column in row[1]:
data_val.append(conversion_datatype(column))
final_val.append((row[0], data_val))
data_dict[key] = final_val
return data_dict

@staticmethod
def convert_data(sample_data, convert_type):
"""
Convert the numeric data to required format
"""
if convert_type in ("float32", "float64"):
data_dict = BaseClass.convert_value(sample_data, float)

if convert_type == "decimal":
data_dict = BaseClass.convert_value(sample_data, decimal.Decimal)

if convert_type == "date":
data_dict = BaseClass.convert_value(
sample_data, methodcaller("date")
)

if convert_type == "time":
data_dict = BaseClass.convert_value(
sample_data, methodcaller("time")
)

if convert_type == "binary":
data_dict = BaseClass.convert_value(
sample_data, methodcaller("encode")
)

return data_dict

@staticmethod
def check_function(
check_fn,
pass_case_data,
fail_case_data,
data_types,
function_args,
fail_on_init=False,
init_exception_cls=None,
):
"""
This function does performs the actual validation
"""
if fail_on_init:
with pytest.raises(init_exception_cls):
check_fn(*function_args)
return

schema = DataFrameSchema(
{
"product": Column(dt.String),
"code": (
Column(data_types, check_fn(*function_args))
if isinstance(function_args, tuple)
else Column(data_types, check_fn(function_args))
),
}
)

ibis_schema = ibis.schema({"product": "string", "code": data_types})

# check that check on pass case data passes
t = ibis.memtable(pass_case_data, schema=ibis_schema)
schema.validate(t)

with pytest.raises(SchemaError):
t = ibis.memtable(fail_case_data, schema=ibis_schema)
schema.validate(t)


class TestEqualToCheck(BaseClass):
sample_numeric_data = {
"test_pass_data": [("foo", 30), ("bar", 30)],
"test_fail_data": [("foo", 30), ("bar", 31)],
"test_expression": 30,
}

sample_string_data = {
"test_pass_data": [("foo", "a"), ("bar", "a")],
"test_fail_data": [("foo", "a"), ("bar", "b")],
"test_expression": "a",
}

def pytest_generate_tests(self, metafunc):
"""This function passes the parameter for each function based on parameter form get_data_param function"""
# called once per each test function
funcarglist = self.get_data_param()[metafunc.function.__name__]
argnames = sorted(funcarglist[0])
metafunc.parametrize(
argnames,
[
[funcargs[name] for name in argnames]
for funcargs in funcarglist
],
)

def get_data_param(self):
"""Generate the params which will be used to test this function. All the acceptable
data types would be tested"""
return {
"test_equal_to_check": [
{"datatype": dt.Int32, "data": self.sample_numeric_data},
{"datatype": dt.Int64, "data": self.sample_numeric_data},
{"datatype": dt.String, "data": self.sample_string_data},
{
"datatype": dt.Float64,
"data": self.convert_data(
self.sample_numeric_data, "float64"
),
},
]
}

@pytest.mark.parametrize("check_fn", [pa.Check.equal_to, pa.Check.eq])
def test_equal_to_check(self, check_fn, datatype, data) -> None:
"""Test the Check to see if all the values are equal to defined value"""
self.check_function(
check_fn,
data["test_pass_data"],
data["test_fail_data"],
datatype,
data["test_expression"],
)
4 changes: 2 additions & 2 deletions tests/polars/test_polars_builtin_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __int__(self, params=None):
}

def pytest_generate(self, metafunc):
"""This function passes the parameter for each function based on parameter form get_data_param function"""
"""This function passes the parameter for each function based on parameter from get_data_param function"""
raise NotImplementedError

@staticmethod
Expand Down Expand Up @@ -230,7 +230,7 @@ def pytest_generate_tests(self, metafunc):
)

def get_data_param(self):
"""Generate the params which will be used to test this function. All the accpetable
"""Generate the params which will be used to test this function. All the acceptable
data types would be tested"""
return {
"test_equal_to_check": [
Expand Down

0 comments on commit 3d0efc6

Please sign in to comment.