Skip to content

Commit

Permalink
Updated Spark 3.3 dependency
Browse files Browse the repository at this point in the history
This commit updates the Spark 3.3 dependency of Deequ. From a Py4J perspective, there are some breaking changes to the Scala APIs. To work around that, we use the Spark version to switch between the updated API and the old API. This is not sustainable and will be revisited in a future PR, or via a different release mechanism. The underlying issue is that Deequ maintains multiple branches for multiple Spark versions, while PyDeequ has only one branch.

The changes were verified by running the tests in Docker against Spark version 3.3. The Dockerfile was also updated so that it copies over the pyproject.toml file and installs dependencies in a separate layer, before the code is copied. This allows for fast iteration on the code, without needing to reinstall dependencies every time the Docker image is built.
  • Loading branch information
rdsharma26 committed Apr 11, 2024
1 parent 4bb727b commit 0cdb2db
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 19 deletions.
10 changes: 6 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ RUN pip3 --version
RUN java -version
RUN pip install poetry==1.7.1

COPY . /python-deequ
RUN mkdir python-deequ
COPY pyproject.toml /python-deequ
COPY poetry.lock /python-deequ
WORKDIR python-deequ

RUN poetry lock --no-update
RUN poetry install
RUN poetry add pyspark==3.3
RUN poetry install -vvv
RUN poetry add pyspark==3.3 -vvv

ENV SPARK_VERSION=3.3
COPY . /python-deequ
CMD poetry run python -m pytest -s tests
62 changes: 52 additions & 10 deletions pydeequ/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydeequ.repository import MetricsRepository, ResultKey
from enum import Enum
from pydeequ.scala_utils import to_scala_seq

from pydeequ.configs import SPARK_VERSION

class _AnalyzerObject:
"""
Expand Down Expand Up @@ -303,7 +303,19 @@ def _analyzer_jvm(self):
:return self
"""
return self._deequAnalyzers.Compliance(self.instance, self.predicate, self._jvm.scala.Option.apply(self.where))
if SPARK_VERSION == "3.3":
return self._deequAnalyzers.Compliance(
self.instance,
self.predicate,
self._jvm.scala.Option.apply(self.where),
self._jvm.scala.collection.Seq.empty()
)
else:
return self._deequAnalyzers.Compliance(
self.instance,
self.predicate,
self._jvm.scala.Option.apply(self.where)
)


class Correlation(_AnalyzerObject):
Expand Down Expand Up @@ -457,12 +469,22 @@ def _analyzer_jvm(self):
"""
if not self.maxDetailBins:
self.maxDetailBins = getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$3")()
return self._deequAnalyzers.Histogram(
self.column,
self._jvm.scala.Option.apply(self.binningUdf),
self.maxDetailBins,
self._jvm.scala.Option.apply(self.where),
)
if SPARK_VERSION == "3.3":
return self._deequAnalyzers.Histogram(
self.column,
self._jvm.scala.Option.apply(self.binningUdf),
self.maxDetailBins,
self._jvm.scala.Option.apply(self.where),
getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$5")(),
getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$6")()
)
else:
return self._deequAnalyzers.Histogram(
self.column,
self._jvm.scala.Option.apply(self.binningUdf),
self.maxDetailBins,
self._jvm.scala.Option.apply(self.where)
)


class KLLParameters:
Expand Down Expand Up @@ -553,7 +575,17 @@ def _analyzer_jvm(self):
:return self
"""
return self._deequAnalyzers.MaxLength(self.column, self._jvm.scala.Option.apply(self.where))
if SPARK_VERSION == "3.3":
return self._deequAnalyzers.MaxLength(
self.column,
self._jvm.scala.Option.apply(self.where),
self._jvm.scala.Option.apply(None)
)
else:
return self._deequAnalyzers.MaxLength(
self.column,
self._jvm.scala.Option.apply(self.where)
)


class Mean(_AnalyzerObject):
Expand Down Expand Up @@ -619,7 +651,17 @@ def _analyzer_jvm(self):
:return self
"""
return self._deequAnalyzers.MinLength(self.column, self._jvm.scala.Option.apply(self.where))
if SPARK_VERSION == "3.3":
return self._deequAnalyzers.MinLength(
self.column,
self._jvm.scala.Option.apply(self.where),
self._jvm.scala.Option.apply(None)
)
else:
return self._deequAnalyzers.MinLength(
self.column,
self._jvm.scala.Option.apply(self.where)
)


class MutualInformation(_AnalyzerObject):
Expand Down
29 changes: 25 additions & 4 deletions pydeequ/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydeequ.check_functions import is_one
from pydeequ.scala_utils import ScalaFunction1, to_scala_seq

from pydeequ.configs import SPARK_VERSION

# TODO implement custom assertions
# TODO implement all methods without outside class dependencies
Expand Down Expand Up @@ -418,7 +418,11 @@ def hasMinLength(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.hasMinLength(column, assertion_func, hint)
if SPARK_VERSION == "3.3":
self._Check = self._Check.hasMinLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None))
else:
self._Check = self._Check.hasMinLength(column, assertion_func, hint)

return self

def hasMaxLength(self, column, assertion, hint=None):
Expand All @@ -433,7 +437,10 @@ def hasMaxLength(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.hasMaxLength(column, assertion_func, hint)
if SPARK_VERSION == "3.3":
self._Check = self._Check.hasMaxLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None))
else:
self._Check = self._Check.hasMaxLength(column, assertion_func, hint)
return self

def hasMin(self, column, assertion, hint=None):
Expand Down Expand Up @@ -558,7 +565,21 @@ def satisfies(self, columnCondition, constraintName, assertion=None, hint=None):
else getattr(self._Check, "satisfies$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.satisfies(columnCondition, constraintName, assertion_func, hint)
if SPARK_VERSION == "3.3":
self._Check = self._Check.satisfies(
columnCondition,
constraintName,
assertion_func,
hint,
self._jvm.scala.collection.Seq.empty()
)
else:
self._Check = self._Check.satisfies(
columnCondition,
constraintName,
assertion_func,
hint
)
return self

def hasPattern(self, column, pattern, assertion=None, name=None, hint=None):
Expand Down
3 changes: 2 additions & 1 deletion pydeequ/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


SPARK_TO_DEEQU_COORD_MAPPING = {
"3.3": "com.amazon.deequ:deequ:2.0.3-spark-3.3",
"3.3": "com.amazon.deequ:deequ:2.0.4-spark-3.3",
"3.2": "com.amazon.deequ:deequ:2.0.1-spark-3.2",
"3.1": "com.amazon.deequ:deequ:2.0.0-spark-3.1",
"3.0": "com.amazon.deequ:deequ:1.2.2-spark-3.0",
Expand Down Expand Up @@ -40,5 +40,6 @@ def _get_deequ_maven_config():
)


SPARK_VERSION = _get_spark_version()
DEEQU_MAVEN_COORD = _get_deequ_maven_config()
IS_DEEQU_V1 = re.search("com\.amazon\.deequ\:deequ\:1.*", DEEQU_MAVEN_COORD) is not None
2 changes: 2 additions & 0 deletions pydeequ/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def __init__(self, spark_session: SparkSession):
self._profiles = []
self.columnProfileClasses = {
"StandardColumnProfile": StandardColumnProfile,
"StringColumnProfile": StandardColumnProfile,
"NumericColumnProfile": NumericColumnProfile,

}

def _columnProfilesFromColumnRunBuilderRun(self, run):
Expand Down

0 comments on commit 0cdb2db

Please sign in to comment.