Commit: Add support for Spark Connect dataframes (#1775)
* Add minimal support for connect_dfs, without changing all type annotations

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* Change pyspark dependency and parameterize unit tests to run against both classic and Spark Connect dataframes

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* Clean requirements and fix small typos

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* Fix pylint warnings

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

* Remove annotations from nox-generated requirements outputs

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>

---------

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>
filipeo2-mck authored Aug 15, 2024
1 parent f6317d6 commit d04bb3a
Showing 37 changed files with 471 additions and 7,820 deletions.
Regenerated requirements lock files (large diffs are not rendered by default):

422 changes: 5 additions & 417 deletions ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt
425 changes: 5 additions & 420 deletions ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt
426 changes: 4 additions & 422 deletions ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt
429 changes: 4 additions & 425 deletions ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt
404 changes: 5 additions & 399 deletions ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt
407 changes: 5 additions & 402 deletions ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt
408 changes: 4 additions & 404 deletions ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt
411 changes: 4 additions & 407 deletions ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt
444 changes: 5 additions & 439 deletions ci/requirements-py3.8-pandas1.5.3-pydantic1.10.11.txt
448 changes: 5 additions & 443 deletions ci/requirements-py3.8-pandas1.5.3-pydantic2.3.0.txt
428 changes: 5 additions & 423 deletions ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt
431 changes: 5 additions & 426 deletions ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt
432 changes: 4 additions & 428 deletions ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt
435 changes: 4 additions & 431 deletions ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt
429 changes: 4 additions & 425 deletions dev/requirements-3.10.txt
411 changes: 4 additions & 407 deletions dev/requirements-3.11.txt
449 changes: 5 additions & 444 deletions dev/requirements-3.8.txt
435 changes: 4 additions & 431 deletions dev/requirements-3.9.txt

2 changes: 1 addition & 1 deletion environment.yml
@@ -25,7 +25,7 @@ dependencies:
- pandas-stubs

# pyspark extra
- pyspark >= 3.2.0
- pyspark[connect] >= 3.2.0

# polars extra
- polars >= 0.20.0
2 changes: 2 additions & 0 deletions noxfile.py
@@ -362,6 +362,7 @@ def ci_requirements(session: Session, pandas: str, pydantic: str) -> None:
_ci_requirement_file_name(session, pandas, pydantic),
"--no-header",
"--upgrade",
"--no-annotate",
)


@@ -379,6 +380,7 @@ def dev_requirements(session: Session) -> None:
output_file,
"--no-header",
"--upgrade",
"--no-annotate",
)


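For context, --no-annotate is the pip-compile flag that drops the "# via ..." source comments from generated lock files, which is why the requirements files listed above each shrink by several hundred lines. A minimal sketch of a nox session producing such a file is shown below; the session name, output path, and argument plumbing are illustrative assumptions, not the verbatim pandera noxfile.

import nox


@nox.session
def compile_dev_requirements(session: nox.Session) -> None:
    """Recompile a lock file without per-package source annotations."""
    session.install("pip-tools")
    session.run(
        "pip-compile",
        "requirements.in",
        "--output-file", "dev/requirements-3.11.txt",
        "--no-header",    # omit the pip-compile banner comment
        "--upgrade",      # refresh pinned versions
        "--no-annotate",  # drop the "# via ..." annotations (the flag added here)
    )
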
20 changes: 19 additions & 1 deletion pandera/accessors/pyspark_sql_accessor.py
@@ -3,7 +3,9 @@

import warnings
from typing import Optional
from packaging import version

import pyspark
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.pyspark.container import DataFrameSchema

@@ -104,7 +106,7 @@ def decorator(accessor):

def register_dataframe_accessor(name):
"""
Register a custom accessor with a DataFrame
Register a custom accessor with a classical Spark DataFrame
:param name: name used when calling the accessor after it's registered
:returns: a class decorator callable.
@@ -115,6 +117,19 @@ def register_dataframe_accessor(name):
return _register_accessor(name, DataFrame)


def register_connect_dataframe_accessor(name):
"""
Register a custom accessor with a Spark Connect DataFrame
:param name: name used when calling the accessor after it's registered
:returns: a class decorator callable.
"""
# pylint: disable=import-outside-toplevel
from pyspark.sql.connect.dataframe import DataFrame as psc_DataFrame

return _register_accessor(name, psc_DataFrame)


class PanderaDataFrameAccessor(PanderaAccessor):
"""Pandera accessor for pyspark DataFrame."""

@@ -127,3 +142,6 @@ def check_schema_type(schema):


register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
# Handle optional Spark Connect imports for pyspark>=3.4 (if available)
if version.parse(pyspark.__version__) >= version.parse("3.4"):
register_connect_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
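
To illustrate what the new registration enables, here is a hedged usage sketch: once the accessor is registered for the Connect DataFrame class, the same .pandera accessor and schema API work against a Spark Connect session. The toy schema and the server address are assumptions (the address mirrors the test fixture added later in this commit), not project code.

import pandera.pyspark as pa
import pyspark.sql.types as T
from pyspark.sql import SparkSession


class ProductSchema(pa.DataFrameModel):
    name: T.StringType() = pa.Field()
    price: T.LongType() = pa.Field(ge=0)


# Spark Connect session; assumes a server is reachable at this address.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df = spark.createDataFrame([("apple", 3)], ["name", "price"])

validated = ProductSchema.validate(df)
print(validated.pandera.errors)  # empty when the Connect DataFrame passes
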
19 changes: 17 additions & 2 deletions pandera/api/pyspark/types.py
@@ -2,13 +2,27 @@

from functools import lru_cache
from typing import List, NamedTuple, Tuple, Type, Union
from numpy import bool_ as np_bool
from packaging import version

import pyspark.sql.types as pst
from pyspark.sql import DataFrame

import pyspark
from pandera.api.checks import Check
from pandera.dtypes import DataType

# pylint: disable=reimported
# Handles optional Spark Connect imports for pyspark>=3.4 (if available)
if version.parse(pyspark.__version__) >= version.parse("3.4"):
from pyspark.sql.connect.dataframe import DataFrame as psc_DataFrame
else:
from pyspark.sql import (
DataFrame as psc_DataFrame,
)

DataFrameTypes = Union[DataFrame, psc_DataFrame]

CheckList = Union[Check, List[Check]]

PysparkDefaultTypes = Union[
@@ -57,7 +71,7 @@
class PysparkDataframeColumnObject(NamedTuple):
"""Pyspark Object which holds dataframe and column value in a named tuble"""

dataframe: DataFrame
dataframe: DataFrameTypes
column_name: str


@@ -69,6 +83,7 @@ def supported_types() -> SupportedTypes:

try:
table_types.append(DataFrame)
table_types.append(psc_DataFrame)

except ImportError: # pragma: no cover
pass
@@ -89,4 +104,4 @@ def is_table(obj):

def is_bool(x):
"""Verifies whether an object is a boolean type."""
return isinstance(x, (bool, type(pst.BooleanType())))
return isinstance(x, (bool, type(pst.BooleanType()), np_bool))
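
A small sketch of what the widened helpers accept (inferred from the diff above rather than documented behaviour): is_table now recognises both DataFrame flavours, and is_bool additionally treats numpy booleans, which can surface when check results are collected from a Connect session, as boolean output.

from numpy import bool_ as np_bool
from pyspark.sql import SparkSession

from pandera.api.pyspark.types import is_bool, is_table

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a")], ["id", "name"])

assert is_table(df)            # classic pyspark.sql.DataFrame
assert is_bool(True)           # plain Python bool, as before
assert is_bool(np_bool(True))  # numpy bool_, newly accepted by this change
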
5 changes: 3 additions & 2 deletions pandera/backends/pyspark/base.py
@@ -22,6 +22,7 @@
scalar_failure_case,
)
from pandera.errors import FailureCaseMetadata, SchemaError, SchemaWarning
from pandera.api.pyspark.types import DataFrameTypes


class ColumnInfo(NamedTuple):
@@ -34,7 +35,7 @@ class ColumnInfo(NamedTuple):
lazy_exclude_column_names: List


FieldCheckObj = Union[col, DataFrame]
FieldCheckObj = Union[col, DataFrameTypes]

T = TypeVar(
"T",
@@ -50,7 +51,7 @@ class PysparkSchemaBackend(BaseSchemaBackend):

def subsample(
self,
check_obj: DataFrame,
check_obj: DataFrameTypes,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[float] = None,
97 changes: 26 additions & 71 deletions pandera/backends/pyspark/checks.py
@@ -1,10 +1,7 @@
"""Check backend for pyspark."""

from functools import partial
from typing import Dict, List, Optional

from multimethod import DispatchError, overload
from pyspark.sql import DataFrame
from typing import Dict, List, Optional, Union

from pandera.api.base.checks import CheckResult, GroupbyObject
from pandera.api.checks import Check
@@ -14,6 +11,7 @@
is_table,
)
from pandera.backends.base import BaseCheckBackend
from pandera.api.pyspark.types import DataFrameTypes


class PySparkCheckBackend(BaseCheckBackend):
@@ -26,7 +24,7 @@ def __init__(self, check: Check):
self.check = check
self.check_fn = partial(check._check_fn, **check._check_kwargs)

def groupby(self, check_obj: DataFrame): # pragma: no cover
def groupby(self, check_obj: DataFrameTypes): # pragma: no cover
"""Implements groupby behavior for check object."""
assert self.check.groupby is not None, "Check.groupby must be set."
if isinstance(self.check.groupby, (str, list)):
@@ -45,61 +43,34 @@ def aggregate(self, check_obj):
def _format_groupby_input(
groupby_obj: GroupbyObject,
groups: Optional[List[str]],
) -> Dict[str, DataFrame]: # pragma: no cover
) -> Dict[str, DataFrameTypes]: # pragma: no cover
raise NotImplementedError

@overload # type: ignore [no-redef]
def preprocess(
self,
check_obj: DataFrame,
check_obj: DataFrameTypes,
key: str, # type: ignore [valid-type]
) -> DataFrame:
) -> DataFrameTypes:
return check_obj

# Workaround for multimethod not supporting Optional arguments
# such as `key: Optional[str]` (fails in multimethod)
# https://github.com/coady/multimethod/issues/90
# FIXME when the multimethod supports Optional args # pylint: disable=fixme
@overload # type: ignore [no-redef]
def preprocess(
def apply(
self,
check_obj: DataFrame, # type: ignore [valid-type]
) -> DataFrame:
return check_obj

@overload
def apply(self, check_obj):
"""Apply the check function to a check object."""
raise NotImplementedError

@overload # type: ignore [no-redef]
def apply(self, check_obj: DataFrame):
return self.check_fn(check_obj) # pragma: no cover

@overload # type: ignore [no-redef]
def apply(self, check_obj: is_table): # type: ignore [valid-type]
return self.check_fn(check_obj) # pragma: no cover

@overload # type: ignore [no-redef]
def apply(self, check_obj: DataFrame, column_name: str, kwargs: dict): # type: ignore [valid-type]
# kwargs['column_name'] = column_name
# return self.check._check_fn(check_obj, *list(kwargs.values()))
check_obj_and_col_name = PysparkDataframeColumnObject(
check_obj, column_name
)
return self.check._check_fn(check_obj_and_col_name, **kwargs)
check_obj: Union[DataFrameTypes, is_table],
column_name: str = None,
kwargs: dict = None,
):
if column_name and kwargs:
check_obj_and_col_name = PysparkDataframeColumnObject(
check_obj, column_name
)
return self.check._check_fn(check_obj_and_col_name, **kwargs)

@overload
def postprocess(self, check_obj, check_output):
"""Postprocesses the result of applying the check function."""
raise TypeError( # pragma: no cover
f"output type of check_fn not recognized: {type(check_output)}"
)
else:
return self.check_fn(check_obj) # pragma: no cover

@overload # type: ignore [no-redef]
def postprocess(
self,
check_obj,
check_obj: DataFrameTypes,
check_output: is_bool, # type: ignore [valid-type]
) -> CheckResult:
"""Postprocesses the result of applying the check function."""
@@ -112,29 +83,13 @@ def postprocess(

def __call__(
self,
check_obj: DataFrame,
check_obj: DataFrameTypes,
key: Optional[str] = None,
) -> CheckResult:
if key is None:
# pylint:disable=no-value-for-parameter
check_obj = self.preprocess(check_obj)
else:
check_obj = self.preprocess(check_obj, key)

try:
if key is None:
check_output = self.apply(check_obj)
else:
check_output = (
self.apply( # pylint:disable=too-many-function-args
check_obj, key, self.check._check_kwargs
)
)

except DispatchError as exc: # pragma: no cover
if exc.__cause__ is not None:
raise exc.__cause__
raise exc
except TypeError as err:
raise err
check_obj = self.preprocess(check_obj, key)

check_output = self.apply( # pylint:disable=too-many-function-args
check_obj, key, self.check._check_kwargs
)

return self.postprocess(check_obj, check_output)
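
As a reminder of what this backend ultimately dispatches to, below is a hedged sketch of a column-level custom check: pandera wraps the dataframe and column name in a PysparkDataframeColumnObject, so the check body reads .dataframe and .column_name. The registration decorator follows pandera's documented pyspark extension pattern, but the check name and logic here are illustrative.

import pyspark.sql.functions as F

from pandera.extensions import register_check_method


@register_check_method
def is_non_negative(pyspark_obj) -> bool:
    """Pass only when every value in the target column is >= 0."""
    cond = F.col(pyspark_obj.column_name) >= 0
    # True when no row violates the condition
    return pyspark_obj.dataframe.filter(~cond).limit(1).count() == 0
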
12 changes: 0 additions & 12 deletions pandera/backends/pyspark/container.py
@@ -553,18 +553,6 @@ def unique(

return check_obj

def _check_uniqueness(
self,
obj: DataFrame,
schema,
) -> DataFrame:
"""Ensure uniqueness in dataframe columns.
:param obj: dataframe to check.
:param schema: schema object.
:returns: dataframe checked.
"""

##########
# Checks #
##########
24 changes: 19 additions & 5 deletions pandera/backends/pyspark/register.py
@@ -1,8 +1,15 @@
"""Register pyspark backends."""

from functools import lru_cache
from packaging import version

import pyspark.sql as pst
import pyspark
import pyspark.sql as ps

# Handles optional Spark Connect imports for pyspark>=3.4 (if available)
CURRENT_PYSPARK_VERSION = version.parse(pyspark.__version__)
if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
from pyspark.sql.connect import dataframe as psc


@lru_cache
@@ -28,7 +35,14 @@ def register_pyspark_backends():
from pandera.backends.pyspark.components import ColumnBackend
from pandera.backends.pyspark.container import DataFrameSchemaBackend

Check.register_backend(pst.DataFrame, PySparkCheckBackend)
ColumnSchema.register_backend(pst.DataFrame, ColumnSchemaBackend)
Column.register_backend(pst.DataFrame, ColumnBackend)
DataFrameSchema.register_backend(pst.DataFrame, DataFrameSchemaBackend)
# Register classical DataFrame
Check.register_backend(ps.DataFrame, PySparkCheckBackend)
ColumnSchema.register_backend(ps.DataFrame, ColumnSchemaBackend)
Column.register_backend(ps.DataFrame, ColumnBackend)
DataFrameSchema.register_backend(ps.DataFrame, DataFrameSchemaBackend)
# Register Spark Connect DataFrame, if available
if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
Check.register_backend(psc.DataFrame, PySparkCheckBackend)
ColumnSchema.register_backend(psc.DataFrame, ColumnSchemaBackend)
Column.register_backend(psc.DataFrame, ColumnBackend)
DataFrameSchema.register_backend(psc.DataFrame, DataFrameSchemaBackend)
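
The practical effect of the double registration is that validation resolves the same pyspark backend regardless of which DataFrame class it receives, so calling code never branches on the session type. A hedged sketch with the object-based API (toy schema and data are assumptions):

import pandera.pyspark as pa
import pyspark.sql.types as T
from pyspark.sql import SparkSession

schema = pa.DataFrameSchema({"price": pa.Column(T.LongType(), pa.Check.ge(0))})

# The same call works with SparkSession.builder.remote("sc://...") on pyspark>=3.4.
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(3,)], ["price"])

out = schema.validate(df)
print(out.pandera.errors)  # empty when validation passes
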
2 changes: 1 addition & 1 deletion requirements.in
@@ -16,7 +16,7 @@ pyarrow
pydantic
multimethod <= 1.10.0
pandas-stubs
pyspark >= 3.2.0
pyspark[connect] >= 3.2.0
polars >= 0.20.0
modin
protobuf
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
"strategies": ["hypothesis >= 6.92.7"],
"hypotheses": ["scipy"],
"io": ["pyyaml >= 5.1", "black", "frictionless <= 4.40.8"],
"pyspark": ["pyspark >= 3.2.0"],
"pyspark": ["pyspark[connect] >= 3.2.0"],
"modin": ["modin", "ray", "dask[dataframe]"],
"modin-ray": ["modin", "ray"],
"modin-dask": ["modin", "dask[dataframe]"],
17 changes: 16 additions & 1 deletion tests/pyspark/conftest.py
@@ -2,6 +2,7 @@

# pylint:disable=redefined-outer-name
import datetime
import os

import pyspark.sql.types as T
import pytest
@@ -15,7 +16,21 @@ def spark() -> SparkSession:
"""
creates spark session
"""
return SparkSession.builder.getOrCreate()
spark: SparkSession = SparkSession.builder.getOrCreate()
yield spark
spark.stop()


@pytest.fixture(scope="session")
def spark_connect() -> SparkSession:
"""
creates spark connection session
"""
# Set location of localhost Spark Connect server
os.environ["SPARK_LOCAL_REMOTE"] = "sc://localhost"
spark: SparkSession = SparkSession.builder.getOrCreate()
yield spark
spark.stop()


@pytest.fixture(scope="session")
Expand Down
(The remaining changed files are not rendered here.)
