diff --git a/Dockerfile b/Dockerfile index a7a236a..bdd9099 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,12 +16,14 @@ RUN pip3 --version RUN java -version RUN pip install poetry==1.7.1 -COPY . /python-deequ +RUN mkdir python-deequ +COPY pyproject.toml /python-deequ +COPY poetry.lock /python-deequ WORKDIR python-deequ -RUN poetry lock --no-update -RUN poetry install -RUN poetry add pyspark==3.3 +RUN poetry install -vvv +RUN poetry add pyspark==3.3 -vvv ENV SPARK_VERSION=3.3 +COPY . /python-deequ CMD poetry run python -m pytest -s tests diff --git a/pydeequ/analyzers.py b/pydeequ/analyzers.py index efd1361..4289094 100644 --- a/pydeequ/analyzers.py +++ b/pydeequ/analyzers.py @@ -10,7 +10,7 @@ from pydeequ.repository import MetricsRepository, ResultKey from enum import Enum from pydeequ.scala_utils import to_scala_seq - +from pydeequ.configs import SPARK_VERSION class _AnalyzerObject: """ @@ -303,7 +303,19 @@ def _analyzer_jvm(self): :return self """ - return self._deequAnalyzers.Compliance(self.instance, self.predicate, self._jvm.scala.Option.apply(self.where)) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.Compliance( + self.instance, + self.predicate, + self._jvm.scala.Option.apply(self.where), + self._jvm.scala.collection.Seq.empty() + ) + else: + return self._deequAnalyzers.Compliance( + self.instance, + self.predicate, + self._jvm.scala.Option.apply(self.where) + ) class Correlation(_AnalyzerObject): @@ -457,12 +469,22 @@ def _analyzer_jvm(self): """ if not self.maxDetailBins: self.maxDetailBins = getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$3")() - return self._deequAnalyzers.Histogram( - self.column, - self._jvm.scala.Option.apply(self.binningUdf), - self.maxDetailBins, - self._jvm.scala.Option.apply(self.where), - ) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.Histogram( + self.column, + self._jvm.scala.Option.apply(self.binningUdf), + self.maxDetailBins, + self._jvm.scala.Option.apply(self.where), + getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$5")(), + getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$6")() + ) + else: + return self._deequAnalyzers.Histogram( + self.column, + self._jvm.scala.Option.apply(self.binningUdf), + self.maxDetailBins, + self._jvm.scala.Option.apply(self.where) + ) class KLLParameters: @@ -553,7 +575,17 @@ def _analyzer_jvm(self): :return self """ - return self._deequAnalyzers.MaxLength(self.column, self._jvm.scala.Option.apply(self.where)) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.MaxLength( + self.column, + self._jvm.scala.Option.apply(self.where), + self._jvm.scala.Option.apply(None) + ) + else: + return self._deequAnalyzers.MaxLength( + self.column, + self._jvm.scala.Option.apply(self.where) + ) class Mean(_AnalyzerObject): @@ -619,7 +651,17 @@ def _analyzer_jvm(self): :return self """ - return self._deequAnalyzers.MinLength(self.column, self._jvm.scala.Option.apply(self.where)) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.MinLength( + self.column, + self._jvm.scala.Option.apply(self.where), + self._jvm.scala.Option.apply(None) + ) + else: + return self._deequAnalyzers.MinLength( + self.column, + self._jvm.scala.Option.apply(self.where) + ) class MutualInformation(_AnalyzerObject): diff --git a/pydeequ/checks.py b/pydeequ/checks.py index abf94d0..a95c178 100644 --- a/pydeequ/checks.py +++ b/pydeequ/checks.py @@ -6,7 +6,7 @@ from pydeequ.check_functions import is_one from pydeequ.scala_utils import ScalaFunction1, to_scala_seq - +from pydeequ.configs import SPARK_VERSION # TODO implement custom assertions # TODO implement all methods without outside class dependencies @@ -418,7 +418,11 @@ def hasMinLength(self, column, assertion, hint=None): """ assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) hint = self._jvm.scala.Option.apply(hint) - self._Check = self._Check.hasMinLength(column, assertion_func, hint) + if SPARK_VERSION == "3.3": + self._Check = self._Check.hasMinLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None)) + else: + self._Check = self._Check.hasMinLength(column, assertion_func, hint) + return self def hasMaxLength(self, column, assertion, hint=None): @@ -433,7 +437,10 @@ def hasMaxLength(self, column, assertion, hint=None): """ assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) hint = self._jvm.scala.Option.apply(hint) - self._Check = self._Check.hasMaxLength(column, assertion_func, hint) + if SPARK_VERSION == "3.3": + self._Check = self._Check.hasMaxLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None)) + else: + self._Check = self._Check.hasMaxLength(column, assertion_func, hint) return self def hasMin(self, column, assertion, hint=None): @@ -558,7 +565,21 @@ def satisfies(self, columnCondition, constraintName, assertion=None, hint=None): else getattr(self._Check, "satisfies$default$3")() ) hint = self._jvm.scala.Option.apply(hint) - self._Check = self._Check.satisfies(columnCondition, constraintName, assertion_func, hint) + if SPARK_VERSION == "3.3": + self._Check = self._Check.satisfies( + columnCondition, + constraintName, + assertion_func, + hint, + self._jvm.scala.collection.Seq.empty() + ) + else: + self._Check = self._Check.satisfies( + columnCondition, + constraintName, + assertion_func, + hint + ) return self def hasPattern(self, column, pattern, assertion=None, name=None, hint=None): diff --git a/pydeequ/configs.py b/pydeequ/configs.py index c3c885d..49cb277 100644 --- a/pydeequ/configs.py +++ b/pydeequ/configs.py @@ -5,7 +5,7 @@ SPARK_TO_DEEQU_COORD_MAPPING = { - "3.3": "com.amazon.deequ:deequ:2.0.3-spark-3.3", + "3.3": "com.amazon.deequ:deequ:2.0.4-spark-3.3", "3.2": "com.amazon.deequ:deequ:2.0.1-spark-3.2", "3.1": "com.amazon.deequ:deequ:2.0.0-spark-3.1", "3.0": "com.amazon.deequ:deequ:1.2.2-spark-3.0", @@ -40,5 +40,6 @@ def _get_deequ_maven_config(): ) +SPARK_VERSION = _get_spark_version() DEEQU_MAVEN_COORD = _get_deequ_maven_config() IS_DEEQU_V1 = re.search("com\.amazon\.deequ\:deequ\:1.*", DEEQU_MAVEN_COORD) is not None diff --git a/pydeequ/profiles.py b/pydeequ/profiles.py index a4a2056..fbbfd84 100644 --- a/pydeequ/profiles.py +++ b/pydeequ/profiles.py @@ -241,7 +241,9 @@ def __init__(self, spark_session: SparkSession): self._profiles = [] self.columnProfileClasses = { "StandardColumnProfile": StandardColumnProfile, + "StringColumnProfile": StandardColumnProfile, "NumericColumnProfile": NumericColumnProfile, + } def _columnProfilesFromColumnRunBuilderRun(self, run):