diff --git a/DIFF.md b/DIFF.md index cad58c39..463aafca 100644 --- a/DIFF.md +++ b/DIFF.md @@ -406,8 +406,6 @@ The latter variant is prefixed with `_with_options`. * `def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame:` * `def diffwith(self: DataFrame, other: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame` -Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server). - ## Diff Spark application There is also a Spark application that can be used to create a diff DataFrame. The application reads two DataFrames diff --git a/README.md b/README.md index 8e0052f3..5882299a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python: -**[Diff](DIFF.md) [[*]](#spark-connect-server):** A `diff` transformation and application for `Dataset`s that computes the differences between +**[Diff](DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between two datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other. **[SortedGroups](GROUPS.md):** A `groupByKey` transformation that groups rows by a key while providing diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py index 45fa6d6d..5800deac 100644 --- a/python/gresearch/spark/__init__.py +++ b/python/gresearch/spark/__init__.py @@ -21,7 +21,7 @@ import time from contextlib import contextmanager from pathlib import Path -from typing import Any, Union, List, Optional, Mapping, TYPE_CHECKING +from typing import Any, Union, List, Optional, Mapping, Iterable, TYPE_CHECKING from py4j.java_gateway import JVMView, JavaObject from pyspark import __version__ @@ -30,6 +30,7 @@ from pyspark.sql import DataFrame, DataFrameReader, SQLContext from pyspark.sql.column import Column, _to_java_column from pyspark.sql.context import SQLContext +from pyspark import SparkConf from pyspark.sql.functions import col, count, lit, when from pyspark.sql.session import SparkSession from pyspark.storagelevel import StorageLevel @@ -106,6 +107,46 @@ def _to_map(jvm: JVMView, map: Mapping[Any, Any]) -> JavaObject: return jvm.scala.collection.JavaConverters.mapAsScalaMap(map) +def backticks(*name_parts: str) -> str: + return '.'.join([f'`{part}`' + if '.' in part and not part.startswith('`') and not part.endswith('`') + else part + for part in name_parts]) + + +def distinct_prefix_for(existing: List[str]) -> str: + # count number of suffix _ for each existing column name + length = 1 + if existing: + length = max([len(name) - len(name.lstrip('_')) for name in existing]) + 1 + # return string with one more _ than that + return '_' * length + + +def handle_configured_case_sensitivity(column_name: str, case_sensitive: bool) -> str: + """ + Produces a column name that considers configured case-sensitivity of column names. When case sensitivity is + deactivated, it lower-cases the given column name and no-ops otherwise. + """ + if case_sensitive: + return column_name + return column_name.lower() + + +def list_contains_case_sensitivity(column_names: Iterable[str], columnName: str, case_sensitive: bool) -> bool: + return handle_configured_case_sensitivity(columnName, case_sensitive) in [handle_configured_case_sensitivity(c, case_sensitive) for c in column_names] + + +def list_filter_case_sensitivity(column_names: Iterable[str], filter: Iterable[str], case_sensitive: bool) -> List[str]: + filter_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in filter} + return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) in filter_set] + + +def list_diff_case_sensitivity(column_names: Iterable[str], other: Iterable[str], case_sensitive: bool) -> List[str]: + other_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in other} + return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) not in other_set] + + def dotnet_ticks_to_timestamp(tick_column: Union[str, Column]) -> Column: """ Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be @@ -386,12 +427,18 @@ def with_row_numbers(self: DataFrame, ConnectDataFrame.with_row_numbers = with_row_numbers +def session(self: DataFrame) -> SparkSession: + return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx.sparkSession + + def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]: return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx +DataFrame.session = session DataFrame.session_or_ctx = session_or_ctx if has_connect: + ConnectDataFrame.session = session ConnectDataFrame.session_or_ctx = session_or_ctx diff --git a/python/gresearch/spark/diff/__init__.py b/python/gresearch/spark/diff/__init__.py index 334e3368..0fe36e44 100644 --- a/python/gresearch/spark/diff/__init__.py +++ b/python/gresearch/spark/diff/__init__.py @@ -14,13 +14,16 @@ import dataclasses from dataclasses import dataclass from enum import Enum -from typing import Optional, Dict, Mapping, Any, Callable, Union, Iterable, overload +from functools import reduce +from typing import Optional, Dict, Mapping, Any, Callable, List, Tuple, Union, Iterable, overload from py4j.java_gateway import JavaObject, JVMView -from pyspark.sql import DataFrame -from pyspark.sql.types import DataType +from pyspark.sql import DataFrame, Column +from pyspark.sql.functions import col, lit, when, concat, coalesce, array, struct +from pyspark.sql.types import DataType, StructField, ArrayType -from gresearch.spark import _get_jvm, _to_seq, _to_map +from gresearch.spark import _get_jvm, _to_seq, _to_map, backticks, distinct_prefix_for, \ + handle_configured_case_sensitivity, list_contains_case_sensitivity, list_filter_case_sensitivity, list_diff_case_sensitivity from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator try: @@ -36,8 +39,8 @@ class DiffMode(Enum): LeftSide = "LeftSide" RightSide = "RightSide" - # the actual default enum value is defined in Java - Default = "Default" + # should be in sync with default defined in Java + Default = ColumnByColumn def _to_java(self, jvm: JVMView) -> JavaObject: return jvm.uk.co.gresearch.spark.diff.DiffMode.withNameOption(self.name).get() @@ -238,28 +241,13 @@ def with_column_name_comparator(self, comparator: DiffComparator, *column_name: column_name_comparators.update({dt: comparator for dt in column_name}) return dataclasses.replace(self, column_name_comparators=column_name_comparators) - def _to_java(self, jvm: JVMView) -> JavaObject: - return jvm.uk.co.gresearch.spark.diff.DiffOptions( - self.diff_column, - self.left_column_prefix, - self.right_column_prefix, - self.insert_diff_value, - self.change_diff_value, - self.delete_diff_value, - self.nochange_diff_value, - jvm.scala.Option.apply(self.change_column), - self.diff_mode._to_java(jvm), - self.sparse_mode, - self.default_comparator._to_java(jvm), - self._to_java_map(jvm, self.data_type_comparators, key_to_java=self._to_java_data_type), - self._to_java_map(jvm, self.column_name_comparators) - ) - - def _to_java_map(self, jvm: JVMView, map: Mapping[Any, DiffComparator], key_to_java: Callable[[JVMView, Any], Any] = lambda j, x: x) -> JavaObject: - return _to_map(jvm, {key_to_java(jvm, key): cmp._to_java(jvm) for key, cmp in map.items()}) - - def _to_java_data_type(self, jvm: JVMView, dt: DataType) -> JavaObject: - return jvm.org.apache.spark.sql.types.DataType.fromJson(dt.json()) + def comparator_for(self, column: StructField) -> DiffComparator: + cmp = self.column_name_comparators.get(column.name) + if cmp is None: + cmp = self.data_type_comparators.get(column.dataType) + if cmp is None: + cmp = self.default_comparator + return cmp class Differ: @@ -272,10 +260,6 @@ class Differ: def __init__(self, options: DiffOptions = None): self._options = options or DiffOptions() - def _to_java(self, jvm: JVMView) -> JavaObject: - jdo = self._options._to_java(jvm) - return jvm.uk.co.gresearch.spark.diff.Differ(jdo) - @overload def diff(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame: ... @@ -348,15 +332,18 @@ def diff(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Union[s :return: the diff DataFrame :rtype DataFrame """ - if len(id_or_ignore_columns) == 2 and all([isinstance(lst, Iterable) for lst in id_or_ignore_columns]): + if len(id_or_ignore_columns) == 2 and all([isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns]): id_columns, ignore_columns = id_or_ignore_columns else: id_columns, ignore_columns = (id_or_ignore_columns, []) - jvm = _get_jvm(left) - jdiffer = self._to_java(jvm) - jdf = jdiffer.diff(left._jdf, right._jdf, _to_seq(jvm, list(id_columns)), _to_seq(jvm, list(ignore_columns))) - return DataFrame(jdf, left.session_or_ctx()) + return self._do_diff(left, right, id_columns, ignore_columns) + + @staticmethod + def _columns_of_side(df: DataFrame, id_columns: List[str], side_prefix: str) -> List[Column]: + prefix = side_prefix + '_' + return [col(c) if c in id_columns else col(c).alias(c.replace(prefix, "")) + for c in df.columns if c in id_columns or c.startswith(side_prefix)] @overload def diffwith(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame: ... @@ -386,14 +373,246 @@ def diffwith(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Uni else: id_columns, ignore_columns = (id_or_ignore_columns, []) - jvm = _get_jvm(left) - jdiffer = self._to_java(jvm) - jdf = jdiffer.diffWith(left._jdf, right._jdf, _to_seq(jvm, list(id_columns)), _to_seq(jvm, list(ignore_columns))) - df = DataFrame(jdf, left.sql_ctx) - return df \ - .withColumnRenamed('_1', self._options.diff_column) \ - .withColumnRenamed('_2', self._options.left_column_prefix) \ - .withColumnRenamed('_3', self._options.right_column_prefix) + diff = self._do_diff(left, right, id_columns, ignore_columns) + left_columns = self._columns_of_side(diff, id_columns, self._options.left_column_prefix) + right_columns = self._columns_of_side(diff, id_columns, self._options.right_column_prefix) + diff_column = col(self._options.diff_column) + + left_struct = when(diff_column == self._options.insert_diff_value, lit(None)) \ + .otherwise(struct(*left_columns)) \ + .alias(self._options.left_column_prefix) + right_struct = when(diff_column == self._options.delete_diff_value, lit(None)) \ + .otherwise(struct(*right_columns)) \ + .alias(self._options.right_column_prefix) + return diff.select(diff_column, left_struct, right_struct) + + def _check_schema(self, left: DataFrame, right: DataFrame, id_columns: List[str], ignore_columns: List[str], case_sensitive: bool): + def require(result: bool, message: str) -> None: + if not result: + raise ValueError(message) + + require( + len(left.columns) == len(set(left.columns)) and len(right.columns) == len(set(right.columns)), + f"The datasets have duplicate columns.\n" + + f"Left column names: {', '.join(left.columns)}\n" + + f"Right column names: {', '.join(right.columns)}") + + left_non_ignored = list_diff_case_sensitivity(left.columns, ignore_columns, case_sensitive) + right_non_ignored = list_diff_case_sensitivity(right.columns, ignore_columns, case_sensitive) + + except_ignored_columns_msg = ' except ignored columns' if ignore_columns else '' + + require( + len(left_non_ignored) == len(right_non_ignored), + "The number of columns doesn't match.\n" + + f"Left column names{except_ignored_columns_msg} ({len(left_non_ignored)}): {', '.join(left_non_ignored)}\n" + + f"Right column names{except_ignored_columns_msg} ({len(right_non_ignored)}): {', '.join(right_non_ignored)}" + ) + + require(len(left_non_ignored) > 0, f"The schema{except_ignored_columns_msg} must not be empty") + + # column types must match but we ignore the nullability of columns + left_fields = {handle_configured_case_sensitivity(field.name, case_sensitive): field.dataType + for field in left.schema.fields + if not list_contains_case_sensitivity(ignore_columns, field.name, case_sensitive)} + right_fields = {handle_configured_case_sensitivity(field.name, case_sensitive): field.dataType + for field in right.schema.fields + if not list_contains_case_sensitivity(ignore_columns, field.name, case_sensitive)} + left_extra_schema = set(left_fields.items()) - set(right_fields.items()) + right_extra_schema = set(right_fields.items()) - set(left_fields.items()) + require( + len(left_extra_schema) == 0 and len(right_extra_schema) == 0, + "The datasets do not have the same schema.\n" + + f"Left extra columns: {', '.join([f'{f} ({t.typeName()})' for f, t in sorted(list(left_extra_schema))])}\n" + + f"Right extra columns: {', '.join([f'{f} ({t.typeName()})' for f, t in sorted(list(right_extra_schema))])}") + + columns = left_non_ignored + pk_columns = id_columns or columns + non_pk_columns = list_diff_case_sensitivity(columns, pk_columns, case_sensitive) + missing_id_columns = list_diff_case_sensitivity(pk_columns, columns, case_sensitive) + require( + len(missing_id_columns) == 0, + f"Some id columns do not exist: {', '.join(missing_id_columns)} missing among {', '.join(columns)}" + ) + + missing_ignore_columns = list_diff_case_sensitivity(ignore_columns, left.columns + right.columns, case_sensitive) + require( + len(missing_ignore_columns) == 0, + f"Some ignore columns do not exist: {', '.join(missing_ignore_columns)} " + + f"missing among {', '.join(sorted(list(set(left_non_ignored + right_non_ignored))))}" + ) + + require( + not list_contains_case_sensitivity(pk_columns, self._options.diff_column, case_sensitive), + f"The id columns must not contain the diff column name '{self._options.diff_column}': {', '.join(pk_columns)}" + ) + require( + self._options.change_column is None or not list_contains_case_sensitivity(pk_columns, self._options.change_column, case_sensitive), + f"The id columns must not contain the change column name '{self._options.change_column}': {', '.join(pk_columns)}" + ) + diff_value_columns = self._get_diff_value_columns(pk_columns, non_pk_columns, left, right, ignore_columns, case_sensitive) + diff_value_columns = {n for n, t in diff_value_columns} + + if self._options.diff_mode in [DiffMode.LeftSide, DiffMode.RightSide]: + require( + not list_contains_case_sensitivity(diff_value_columns, self._options.diff_column, case_sensitive), + f"The {'left' if self._options.diff_mode == DiffMode.LeftSide else 'right'} " + + f"non-id columns must not contain the diff column name '{self._options.diff_column}': " + + f"{', '.join(list_diff_case_sensitivity((left if self._options.diff_mode == DiffMode.LeftSide else right).columns, id_columns, case_sensitive))}" + ) + + require( + self._options.change_column is None or not list_contains_case_sensitivity(diff_value_columns, self._options.change_column, case_sensitive), + f"The {'left' if self._options.diff_mode == DiffMode.LeftSide else 'right'} " + + f"non-id columns must not contain the change column name '{self._options.change_column}': " + + f"{', '.join(list_diff_case_sensitivity((left if self._options.diff_mode == DiffMode.LeftSide else right).columns, id_columns, case_sensitive))}" + ) + else: + require( + not list_contains_case_sensitivity(diff_value_columns, self._options.diff_column, case_sensitive), + f"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', " + + f"together with these non-id columns must not produce the diff column name '{self._options.diff_column}': " + + f"{', '.join(non_pk_columns)}" + ) + + require( + self._options.change_column is None or not list_contains_case_sensitivity(diff_value_columns, self._options.change_column, case_sensitive), + f"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', " + + f"together with these non-id columns must not produce the change column name '{self._options.change_column}': " + + f"{', '.join(non_pk_columns)}" + ) + + require( + all(not list_contains_case_sensitivity(pk_columns, c, case_sensitive) for c in diff_value_columns), + f"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', " + + f"together with these non-id columns must not produce any id column name '{', '.join(pk_columns)}': " + + f"{', '.join(non_pk_columns)}" + ) + + def _get_change_column(self, + exists_column_name: str, + value_columns_with_comparator: List[Tuple[str, DiffComparator]], + left: DataFrame, + right: DataFrame) -> Optional[Column]: + if self._options.change_column is None: + return None + if not self._options.change_column: + return array().cast(ArrayType(StringType, containsNull = false)).alias(self._options.change_column) + return when(left[exists_column_name].isNull() | right[exists_column_name].isNull(), lit(None)) \ + .otherwise( + concat(*[when(cmp.equiv(left[c], right[c]), array()).otherwise(array(lit(c))) + for (c, cmp) in value_columns_with_comparator])) \ + .alias(self._options.change_column) + + def _do_diff(self, left: DataFrame, right: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame: + case_sensitive = left.session().conf.get("spark.sql.caseSensitive") == "true" + self._check_schema(left, right, id_columns, ignore_columns, case_sensitive) + + columns = list_diff_case_sensitivity(left.columns, ignore_columns, case_sensitive) + pk_columns = id_columns or columns + value_columns = list_diff_case_sensitivity(columns, pk_columns, case_sensitive) + value_struct_fields = {f.name: f for f in left.schema.fields} + value_columns_with_comparator = [(c, self._options.comparator_for(value_struct_fields[c])) for c in value_columns] + + exists_column_name = distinct_prefix_for(left.columns) + "exists" + left_with_exists = left.withColumn(exists_column_name, lit(1)) + right_with_exists = right.withColumn(exists_column_name, lit(1)) + join_condition = reduce(lambda l, r: l & r, + [left_with_exists[c].eqNullSafe(right_with_exists[c]) + for c in pk_columns]) + un_changed = reduce(lambda l, r: l & r, + [cmp.equiv(left_with_exists[c], right_with_exists[c]) + for (c, cmp) in value_columns_with_comparator], + lit(True)) + change_condition = ~un_changed + + diff_action_column = \ + when(left_with_exists[exists_column_name].isNull(), lit(self._options.insert_diff_value)) \ + .when(right_with_exists[exists_column_name].isNull(), lit(self._options.delete_diff_value)) \ + .when(change_condition, lit(self._options.change_diff_value)) \ + .otherwise(lit(self._options.nochange_diff_value)) \ + .alias(self._options.diff_column) + + diff_columns = [c[1] for c in self._get_diff_columns(pk_columns, value_columns, left, right, ignore_columns, case_sensitive)] + # turn this column into a list of one or none column so we can easily concat it below with diffActionColumn and diffColumns + change_column = self._get_change_column(exists_column_name, value_columns_with_comparator, left_with_exists, right_with_exists) + change_columns = [change_column] if change_column is not None else [] + + return left_with_exists \ + .join(right_with_exists, join_condition, "fullouter") \ + .select(*([diff_action_column] + change_columns + diff_columns)) + + def _get_diff_id_columns(self, pk_columns: List[str], + left: DataFrame, + right: DataFrame) -> List[Tuple[str, Column]]: + return [(c, coalesce(left[c], right[c]).alias(c)) for c in pk_columns] + + def _get_diff_value_columns(self, pk_columns: List[str], + value_columns: List[str], + left: DataFrame, + right: DataFrame, + ignore_columns: List[str], + case_sensitive: bool) -> List[Tuple[str, Column]]: + left_value_columns = list_filter_case_sensitivity(left.columns, value_columns, case_sensitive) + right_value_columns = list_filter_case_sensitivity(right.columns, value_columns, case_sensitive) + + left_non_pk_columns = list_diff_case_sensitivity(left.columns, pk_columns, case_sensitive) + right_non_pk_columns = list_diff_case_sensitivity(right.columns, pk_columns, case_sensitive) + + left_ignored_columns = list_filter_case_sensitivity(left.columns, ignore_columns, case_sensitive) + right_ignored_columns = list_filter_case_sensitivity(right.columns, ignore_columns, case_sensitive) + left_values = {handle_configured_case_sensitivity(c, case_sensitive): (c, when(~(left[c].eqNullSafe(right[c])), left[c]) if self._options.sparse_mode else left[c]) for c in left_non_pk_columns} + right_values = {handle_configured_case_sensitivity(c, case_sensitive): (c, when(~(left[c].eqNullSafe(right[c])), right[c]) if self._options.sparse_mode else right[c]) for c in right_non_pk_columns} + + def alias(prefix: Optional[str], values: Dict[str, Tuple[str, Column]]) -> Callable[[str], Tuple[str, Column]]: + def func(name: str) -> (str, Column): + name, column = values[handle_configured_case_sensitivity(name, case_sensitive)] + alias = name if prefix is None else f'{prefix}_{name}' + return alias, column.alias(alias) + + return func + + def alias_left(name: str) -> (str, Column): + return alias(self._options.left_column_prefix, left_values)(name) + + def alias_right(name: str) -> (str, Column): + return alias(self._options.right_column_prefix, right_values)(name) + + prefixed_left_ignored_columns = [alias_left(c) for c in left_ignored_columns] + prefixed_right_ignored_columns = [alias_right(c) for c in right_ignored_columns] + + if self._options.diff_mode == DiffMode.ColumnByColumn: + non_id_columns = \ + [c for vc in value_columns for c in [alias_left(vc), alias_right(vc)]] + \ + [c for ic in ignore_columns for c in ( + ([alias_left(ic)] if list_contains_case_sensitivity(left_ignored_columns, ic, case_sensitive) else []) + + ([alias_right(ic)] if list_contains_case_sensitivity(right_ignored_columns, ic, case_sensitive) else []) + )] + elif self._options.diff_mode == DiffMode.SideBySide: + non_id_columns = \ + [alias_left(c) for c in left_value_columns] + prefixed_left_ignored_columns + \ + [alias_right(c) for c in right_value_columns] + prefixed_right_ignored_columns + elif self._options.diff_mode == DiffMode.LeftSide: + non_id_columns = \ + [alias(None, left_values)(c) for c in value_columns] +\ + [alias(None, left_values)(c) for c in left_ignored_columns] + elif self._options.diff_mode == DiffMode.RightSide: + non_id_columns = \ + [alias(None, right_values)(c) for c in value_columns] + \ + [alias(None, right_values)(c) for c in right_ignored_columns] + else: + raise RuntimeError(f'Unsupported diff mode: {self._options.diff_mode}') + + return non_id_columns + + def _get_diff_columns(self, pk_columns: List[str], + value_columns: List[str], + left: DataFrame, + right: DataFrame, + ignore_columns: List[str], + case_sensitive: bool) -> List[Tuple[str, Column]]: + return self._get_diff_id_columns(pk_columns, left, right) + \ + self._get_diff_value_columns(pk_columns, value_columns, left, right, ignore_columns, case_sensitive) @overload diff --git a/python/gresearch/spark/diff/comparator/__init__.py b/python/gresearch/spark/diff/comparator/__init__.py index 3a7988c9..e7d85283 100644 --- a/python/gresearch/spark/diff/comparator/__init__.py +++ b/python/gresearch/spark/diff/comparator/__init__.py @@ -18,12 +18,14 @@ from py4j.java_gateway import JVMView, JavaObject +from pyspark.sql import Column +from pyspark.sql.functions import abs, greatest, lit from pyspark.sql.types import DataType class DiffComparator(abc.ABC): @abc.abstractmethod - def _to_java(self, jvm: JVMView) -> JavaObject: + def equiv(self, left: Column, right: Column) -> Column: pass @@ -53,14 +55,15 @@ def map(key_type: DataType, value_type: DataType, key_order_sensitive: bool = Fa return MapDiffComparator(key_type, value_type, key_order_sensitive) -class DefaultDiffComparator(DiffComparator): - def _to_java(self, jvm: JVMView) -> JavaObject: - return jvm.uk.co.gresearch.spark.diff.DiffComparators.default() +class NullSafeEqualDiffComparator(DiffComparator): + def equiv(self, left: Column, right: Column) -> Column: + return left.eqNullSafe(right) -class NullSafeEqualDiffComparator(DiffComparator): +class DefaultDiffComparator(NullSafeEqualDiffComparator): + # for testing only def _to_java(self, jvm: JVMView) -> JavaObject: - return jvm.uk.co.gresearch.spark.diff.DiffComparators.nullSafeEqual() + return jvm.uk.co.gresearch.spark.diff.DiffComparators.default() @dataclass(frozen=True) @@ -81,16 +84,25 @@ def as_inclusive(self) -> 'EpsilonDiffComparator': def as_exclusive(self) -> 'EpsilonDiffComparator': return dataclasses.replace(self, inclusive=False) - def _to_java(self, jvm: JVMView) -> JavaObject: - return jvm.uk.co.gresearch.spark.diff.comparator.EpsilonDiffComparator(self.epsilon, self.relative, self.inclusive) + def equiv(self, left: Column, right: Column) -> Column: + threshold = greatest(abs(left), abs(right)) * self.epsilon if self.relative else lit(self.epsilon) + + def inclusive_epsilon(diff: Column) -> Column: + return diff.__le__(threshold) + + def exclusive_epsilon(diff: Column) -> Column: + return diff.__lt__(threshold) + + in_epsilon = inclusive_epsilon if self.inclusive else exclusive_epsilon + return left.isNull() & right.isNull() | left.isNotNull() & right.isNotNull() & in_epsilon(abs(left - right)) @dataclass(frozen=True) class StringDiffComparator(DiffComparator): whitespace_agnostic: bool - def _to_java(self, jvm: JVMView) -> JavaObject: - return jvm.uk.co.gresearch.spark.diff.DiffComparators.string(self.whitespace_agnostic) + def equiv(self, left: Column, right: Column) -> Column: + return left.eqNullSafe(right) @dataclass(frozen=True) @@ -104,9 +116,8 @@ def as_inclusive(self) -> 'DurationDiffComparator': def as_exclusive(self) -> 'DurationDiffComparator': return dataclasses.replace(self, inclusive=False) - def _to_java(self, jvm: JVMView) -> JavaObject: - jduration = jvm.java.time.Duration.parse(self.duration) - return jvm.uk.co.gresearch.spark.diff.comparator.DurationDiffComparator(jduration, self.inclusive) + def equiv(self, left: Column, right: Column) -> Column: + return left.eqNullSafe(right) @dataclass(frozen=True) @@ -115,10 +126,5 @@ class MapDiffComparator(DiffComparator): value_type: DataType key_order_sensitive: bool - def _to_java(self, jvm: JVMView) -> JavaObject: - from pyspark.sql import SparkSession - - jfromjson = jvm.org.apache.spark.sql.types.__getattr__("DataType$").__getattr__("MODULE$").fromJson - jkeytype = jfromjson(self.key_type.json()) - jvaluetype = jfromjson(self.value_type.json()) - return jvm.uk.co.gresearch.spark.diff.DiffComparators.map(jkeytype, jvaluetype, self.key_order_sensitive) + def equiv(self, left: Column, right: Column) -> Column: + return left.eqNullSafe(right) diff --git a/python/test/spark_common.py b/python/test/spark_common.py index 76735beb..2716ca1e 100644 --- a/python/test/spark_common.py +++ b/python/test/spark_common.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import logging import os import sys import unittest +from contextlib import contextmanager from pathlib import Path from pyspark import SparkConf @@ -26,7 +26,7 @@ logger.level = logging.INFO -@contextlib.contextmanager +@contextmanager def spark_session(): session = SparkTest.get_spark_session() try: @@ -106,3 +106,30 @@ def tearDownClass(cls): logging.info('stopping Spark session') cls.spark.stop() super(SparkTest, cls).tearDownClass() + + @contextmanager + def sql_conf(self, pairs): + """ + Copied from pyspark/testing/sqlutils available from PySpark 3.5.0 and higher. + https://github.com/apache/spark/blob/v3.5.0/python/pyspark/testing/sqlutils.py#L171 + http://www.apache.org/licenses/LICENSE-2.0 + + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) diff --git a/python/test/test_diff.py b/python/test/test_diff.py index c9871c2d..d343f472 100644 --- a/python/test/test_diff.py +++ b/python/test/test_diff.py @@ -11,20 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import re from py4j.java_gateway import JavaObject from pyspark.sql import Row -from pyspark.sql.functions import col, when -from pyspark.sql.types import IntegerType, LongType, StringType, DateType +from pyspark.sql.functions import col, when, abs +from pyspark.sql.types import IntegerType, LongType, StringType, DateType, StructField, StructType, FloatType, DoubleType from unittest import skipIf from gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparators from spark_common import SparkTest -@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Diff") class DiffTest(SparkTest): expected_diff = None @@ -62,6 +61,25 @@ def setUpClass(cls): diff_row('I', 6, None, 6.0, None, 'six'), diff_row('D', 7, 7.0, None, 'seven', None), ] + diff_change_row = Row('diff', 'change', 'id', 'left_val', 'right_val', 'left_label', 'right_label') + cls.expected_diff_change = [ + diff_change_row('C', ['val'], 1, 1.0, 1.1, 'one', 'one'), + diff_change_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'), + diff_change_row('N', [], 3, 3.0, 3.0, 'three', 'three'), + diff_change_row('C', ['val', 'label'], 4, None, 4.0, None, 'four'), + diff_change_row('C', ['val', 'label'], 5, 5.0, None, 'five', None), + diff_change_row('I', None, 6, None, 6.0, None, 'six'), + diff_change_row('D', None, 7, 7.0, None, 'seven', None), + ] + cls.expected_diff_reversed = [ + diff_row('C', 1, 1.1, 1.0, 'one', 'one'), + diff_row('C', 2, 2.0, 2.0, 'Two', 'two'), + diff_row('N', 3, 3.0, 3.0, 'three', 'three'), + diff_row('C', 4, 4.0, None, 'four', None), + diff_row('C', 5, None, 5.0, None, 'five'), + diff_row('D', 6, 6.0, None, 'six', None), + diff_row('I', 7, None, 7.0, None, 'seven'), + ] cls.expected_diff_ignored = [ diff_row('C', 1, 1.0, 1.1, 'one', 'one'), diff_row('N', 2, 2.0, 2.0, 'two', 'Two'), @@ -169,6 +187,299 @@ def setUpClass(cls): diff_in_sparse_mode_row('D', 7, 7.0, None, 'seven', None), ] + def test_check_schema(self): + @contextlib.contextmanager + def test_requirement(error_message: str): + with self.assertRaises(ValueError) as e: + yield + self.assertEqual((error_message, ), e.exception.args) + + with self.subTest("duplicate columns"): + with test_requirement("The datasets have duplicate columns.\n" + "Left column names: id, id\nRight column names: id, id"): + self.left_df.select("id", "id").diff(self.right_df.select("id", "id"), "id") + + with self.subTest("case-sensitive id column"): + with test_requirement("Some id columns do not exist: ID missing among id, val, label"): + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + self.left_df.diff(self.right_df, "ID") + + left = self.left_df.withColumnRenamed("val", "diff") + right = self.right_df.withColumnRenamed("val", "diff") + + with self.subTest("id column 'diff'"): + with test_requirement("The id columns must not contain the diff column name 'diff': id, diff, label"): + left.diff(right) + with test_requirement("The id columns must not contain the diff column name 'diff': diff"): + left.diff(right, "diff") + with test_requirement("The id columns must not contain the diff column name 'diff': diff, id"): + left.diff(right, "diff", "id") + + with self.sql_conf({"spark.sql.caseSensitive": "false"}): + with test_requirement("The id columns must not contain the diff column name 'diff': Diff, id"): + left.withColumnRenamed("diff", "Diff") \ + .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id") + + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + left.withColumnRenamed("diff", "Diff") \ + .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id") + + with self.subTest("non-id column 'diff"): + actual = left.diff(right, "id").orderBy("id") + expected_columns = ["diff", "id", "left_diff", "right_diff", "left_label", "right_label"] + self.assertEqual(actual.columns, expected_columns) + self.assertEqual(actual.collect(), self.expected_diff) + + with self.subTest("non-id column produces diff column name"): + options = DiffOptions() \ + .with_diff_column("a_val") \ + .with_left_column_prefix("a") \ + .with_right_column_prefix("b") + + with test_requirement("The column prefixes 'a' and 'b', together with these non-id columns " + + "must not produce the diff column name 'a_val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + with test_requirement("The column prefixes 'a' and 'b', together with these non-id columns " + + "must not produce the diff column name 'b_val': val, label"): + self.left_df.diff_with_options(self.right_df, options.with_diff_column("b_val"), "id") + + with self.subTest("non-id column would produce diff column name unless in left-side mode"): + options = DiffOptions() \ + .with_diff_column("a_val") \ + .with_left_column_prefix("a") \ + .with_right_column_prefix("b") \ + .with_diff_mode(DiffMode.LeftSide) + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.subTest("non-id column would produce diff column name unless in right-side mode"): + options = DiffOptions() \ + .with_diff_column("b_val") \ + .with_left_column_prefix("a") \ + .with_right_column_prefix("b") \ + .with_diff_mode(DiffMode.RightSide) + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.sql_conf({"spark.sql.caseSensitive": "false"}): + with self.subTest("case-insensitive non-id column produces diff column name"): + options = DiffOptions() \ + .with_diff_column("a_val") \ + .with_left_column_prefix("A") \ + .with_right_column_prefix("b") + with test_requirement("The column prefixes 'A' and 'b', together with these non-id columns " + + "must not produce the diff column name 'a_val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.subTest("case-insensitive non-id column would produce diff column name unless in left-side mode"): + options = DiffOptions() \ + .with_diff_column("a_val") \ + .with_left_column_prefix("A") \ + .with_right_column_prefix("B") \ + .with_diff_mode(DiffMode.LeftSide) + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.subTest("case-insensitive non-id column would produce diff column name unless in right-side mode"): + options = DiffOptions() \ + .with_diff_column("b_val") \ + .with_left_column_prefix("A") \ + .with_right_column_prefix("B") \ + .with_diff_mode(DiffMode.RightSide) + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + with self.subTest("case-sensitive non-id column produces non-conflicting diff column name"): + options = DiffOptions() \ + .with_diff_column("a_val") \ + .with_left_column_prefix("A") \ + .with_right_column_prefix("B") \ + + actual = self.left_df.diff_with_options(self.right_df, options, "id").orderBy("id") + expected_columns = ["a_val", "id", "A_val", "B_val", "A_label", "B_label"] + self.assertEqual(actual.columns, expected_columns) + self.assertEqual(actual.collect(), self.expected_diff) + + left = self.left_df.withColumnRenamed("val", "change") + right = self.right_df.withColumnRenamed("val", "change") + + with self.subTest("id column 'change'"): + options = DiffOptions() \ + .with_change_column("change") + with test_requirement("The id columns must not contain the change column name 'change': id, change, label"): + left.diff_with_options(right, options) + with test_requirement("The id columns must not contain the change column name 'change': change"): + left.diff_with_options(right, options, "change") + with test_requirement("The id columns must not contain the change column name 'change': change, id"): + left.diff_with_options(right, options, "change", "id") + + with self.sql_conf({"spark.sql.caseSensitive": "false"}): + with test_requirement("The id columns must not contain the change column name 'change': Change, id"): + left.withColumnRenamed("change", "Change") \ + .diff_with_options(right.withColumnRenamed("change", "Change"), options, "Change", "id") + + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + left.withColumnRenamed("change", "Change") \ + .diff_with_options(right.withColumnRenamed("change", "Change"), options, "Change", "id") + + with self.subTest("non-id column 'change"): + actual = left.diff_with_options(right, options, "id").orderBy("id") + expected_columns = ["diff", "change", "id", "left_change", "right_change", "left_label", "right_label"] + diff_change_row = Row(*expected_columns) + expected_diff = [ + diff_change_row('C', ['change'], 1, 1.0, 1.1, 'one', 'one'), + diff_change_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'), + diff_change_row('N', [], 3, 3.0, 3.0, 'three', 'three'), + diff_change_row('C', ['change', 'label'], 4, None, 4.0, None, 'four'), + diff_change_row('C', ['change', 'label'], 5, 5.0, None, 'five', None), + diff_change_row('I', None, 6, None, 6.0, None, 'six'), + diff_change_row('D', None, 7, 7.0, None, 'seven', None), + ] + self.assertEqual(actual.columns, expected_columns) + self.assertEqual(actual.collect(), expected_diff) + + with self.subTest("non-id column produces change column name"): + options = DiffOptions() \ + .with_change_column("a_val") \ + .with_left_column_prefix("a") \ + .with_right_column_prefix("b") + with test_requirement("The column prefixes 'a' and 'b', together with these non-id columns " + + "must not produce the change column name 'a_val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.sql_conf({"spark.sql.caseSensitive": "false"}): + with self.subTest("case-insensitive non-id column produces change column name"): + options = DiffOptions() \ + .with_change_column("a_val") \ + .with_left_column_prefix("A") \ + .with_right_column_prefix("B") + with test_requirement("The column prefixes 'A' and 'B', together with these non-id columns " + + "must not produce the change column name 'a_val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + with self.subTest("case-sensitive non-id column produces non-conflicting change column name"): + options = DiffOptions() \ + .with_change_column("a_val") \ + .with_left_column_prefix("A") \ + .with_right_column_prefix("B") + actual = self.left_df.diff_with_options(self.right_df, options, "id").orderBy("id") + expected_columns = ["diff", "a_val", "id", "A_val", "B_val", "A_label", "B_label"] + self.assertEqual(actual.columns, expected_columns) + self.assertEqual(actual.collect(), self.expected_diff_change) + + left = self.left_df.select(col("id").alias("first_id"), col("val").alias("id"), "label") + right = self.right_df.select(col("id").alias("first_id"), col("val").alias("id"), "label") + with self.subTest("non-id column produces id column name"): + options = DiffOptions() \ + .with_left_column_prefix("first") \ + .with_right_column_prefix("second") + with test_requirement("The column prefixes 'first' and 'second', together with these non-id columns " + + "must not produce any id column name 'first_id': id, label"): + left.diff_with_options(right, options, "first_id") + + with self.sql_conf({"spark.sql.caseSensitive": "false"}): + with self.subTest("case-insensitive non-id column produces id column name"): + options = DiffOptions() \ + .with_left_column_prefix("FIRST") \ + .with_right_column_prefix("SECOND") + with test_requirement("The column prefixes 'FIRST' and 'SECOND', together with these non-id columns " + + "must not produce any id column name 'first_id': id, label"): + left.diff_with_options(right, options, "first_id") + + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + with self.subTest("case-sensitive non-id column produces non-conflicting id column name"): + options = DiffOptions() \ + .with_left_column_prefix("FIRST") \ + .with_right_column_prefix("SECOND") + actual = left.diff_with_options(right, options, "first_id").orderBy("first_id") + expected_columns = ["diff", "first_id", "FIRST_id", "SECOND_id", "FIRST_label", "SECOND_label"] + self.assertEqual(actual.columns, expected_columns) + self.assertEqual(actual.collect(), self.expected_diff) + + with self.subTest("empty schema"): + with test_requirement("The schema must not be empty"): + self.left_df.select().diff(self.right_df.select()) + + with self.subTest("empty schema after ignored columns"): + with test_requirement("The schema except ignored columns must not be empty"): + self.left_df.select("id", "val").diff(self.right_df.select("id", "label"), [], ["id", "val", "label"]) + + with self.subTest("different types"): + with test_requirement("The datasets do not have the same schema.\n" + + "Left extra columns: val (double)\n" + + "Right extra columns: val (string)"): + self.left_df.select("id", "val").diff(self.right_df.select("id", col("label").alias("val"))) + + with self.subTest("ignore columns with different types"): + actual = self.left_df.select("id", "val").diff(self.right_df.select("id", col("label").alias("val")), [], ["val"]) + expected_schema = [ + ("diff", StringType()), + ("id", LongType()), + ("left_val", DoubleType()), + ("right_val", StringType()), + ] + self.assertEqual([(f.name, f.dataType) for f in actual.schema], expected_schema) + + with self.subTest("diff with different column names"): + with test_requirement("The datasets do not have the same schema.\n" + + "Left extra columns: val (double)\n" + + "Right extra columns: label (string)"): + self.left_df.select("id", "val").diff(self.right_df.select("id", "label")) + + left = self.left_df.select("id", "val", "label") + right = self.right_df.select(col("id").alias("ID"), col("val").alias("VaL"), "label") + with self.sql_conf({"spark.sql.caseSensitive": "false"}): + with self.subTest("case-insensitive column names"): + actual = left.diff(right, "id").orderBy("id") + reverse = right.diff(left, "id").orderBy("id") + self.assertEqual(actual.columns, ["diff", "id", "left_val", "right_VaL", "left_label", "right_label"]) + self.assertEqual(actual.collect(), self.expected_diff) + self.assertEqual(reverse.columns, ["diff", "id", "left_VaL", "right_val", "left_label", "right_label"]) + self.assertEqual(reverse.collect(), self.expected_diff_reversed) + + with self.sql_conf({"spark.sql.caseSensitive": "true"}): + with self.subTest("case-sensitive column names"): + with test_requirement("The datasets do not have the same schema.\n" + + "Left extra columns: id (long), val (double)\n" + + "Right extra columns: ID (long), VaL (double)"): + left.diff(right, "id") + + with self.subTest("non-existing id column"): + with test_requirement("Some id columns do not exist: does not exists missing among id, val, label"): + self.left_df.diff(self.right_df, "does not exists") + + with self.subTest("different number of columns"): + with test_requirement("The number of columns doesn't match.\n" + + "Left column names (2): id, val\n" + + "Right column names (3): id, val, label"): + self.left_df.select("id", "val").diff(self.right_df, "id") + + with self.subTest("different number of columns after ignoring columns"): + left = self.left_df.select("id", "val", col("label").alias("meta")) + right = self.right_df.select("id", col("label").alias("seq"), "val") + with test_requirement("The number of columns doesn't match.\n" + + "Left column names except ignored columns (2): id, val\n" + + "Right column names except ignored columns (3): id, seq, val"): + left.diff(right, ["id"], ["meta"]) + + with self.subTest("diff column name in value columns in left-side diff mode"): + options = DiffOptions().with_diff_column("val").with_diff_mode(DiffMode.LeftSide) + with test_requirement("The left non-id columns must not contain the diff column name 'val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.subTest("diff column name in value columns in right-side diff mode"): + options = DiffOptions().with_diff_column("val").with_diff_mode(DiffMode.RightSide) + with test_requirement("The right non-id columns must not contain the diff column name 'val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.subTest("change column name in value columns in left-side diff mode"): + options = DiffOptions().with_change_column("val").with_diff_mode(DiffMode.LeftSide) + with test_requirement("The left non-id columns must not contain the change column name 'val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + + with self.subTest("change column name in value columns in right-side diff mode"): + options = DiffOptions().with_change_column("val").with_diff_mode(DiffMode.RightSide) + with test_requirement("The right non-id columns must not contain the change column name 'val': val, label"): + self.left_df.diff_with_options(self.right_df, options, "id") + def test_dataframe_diff(self): diff = self.left_df.diff(self.right_df, 'id').orderBy('id').collect() self.assertEqual(self.expected_diff, diff) @@ -280,6 +591,7 @@ def test_differ_diff_with_sparse_mode(self): diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect() self.assertEqual(self.expected_diff_in_sparse_mode, diff) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM") def test_diff_options_default(self): jvm = self.spark._jvm joptions = jvm.uk.co.gresearch.spark.diff.DiffOptions.default() @@ -304,6 +616,7 @@ def test_diff_options_default(self): else: self.assertEqual(expected, actual, '{} == {} ?'.format(attr, const)) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM") def test_diff_mode_consts(self): jvm = self.spark._jvm jmodes = jvm.uk.co.gresearch.spark.diff.DiffMode @@ -316,6 +629,21 @@ def test_diff_mode_consts(self): self.assertEqual(expected.toString(), actual.name, actual.name) self.assertIsNotNone(DiffMode.Default.name, jmodes.Default().toString()) + def test_diff_options_comparator_for(self): + cmp1 = DiffComparators.default() + cmp2 = DiffComparators.epsilon(0.01) + cmp3 = DiffComparators.string() + + opts = DiffOptions() \ + .with_column_name_comparator(cmp1, "abc", "def") \ + .with_data_type_comparator(cmp2, LongType()) \ + .with_default_comparator(cmp3) + + self.assertEqual(opts.comparator_for(StructField("abc", IntegerType())), cmp1) + self.assertEqual(opts.comparator_for(StructField("def", LongType())), cmp1) + self.assertEqual(opts.comparator_for(StructField("ghi", LongType())), cmp2) + self.assertEqual(opts.comparator_for(StructField("jkl", IntegerType())), cmp3) + def test_diff_fluent_setters(self): cmp1 = DiffComparators.default() cmp2 = DiffComparators.epsilon(0.01) @@ -376,17 +704,34 @@ def test_diff_fluent_setters(self): self.assertEqual(without_change.diff_mode, DiffMode.SideBySide) self.assertEqual(without_change.sparse_mode, True) - def test_diff_with_comparators(self): + def test_diff_with_epsilon_comparator(self): + # relative inclusive epsilon options = DiffOptions() \ - .with_column_name_comparator(DiffComparators.epsilon(0.1).as_relative(), 'val') - + .with_column_name_comparator(DiffComparators.epsilon(0.1).as_relative().as_inclusive(), 'val') diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect() expected = self.spark.createDataFrame(self.expected_diff) \ .withColumn("diff", when(col("id") == 1, "N").otherwise(col("diff"))) \ .collect() + self.assertEqual(expected, diff) + + # relative exclusive epsilon + options = DiffOptions() \ + .with_column_name_comparator(DiffComparators.epsilon(0.0909).as_relative().as_exclusive(), 'val') + diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect() + self.assertEqual(self.expected_diff, diff) + # absolute inclusive epsilon + options = DiffOptions() \ + .with_column_name_comparator(DiffComparators.epsilon(0.10000000000000009).as_absolute().as_inclusive(), 'val') + diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect() self.assertEqual(expected, diff) + # absolute exclusive epsilon + options = DiffOptions() \ + .with_column_name_comparator(DiffComparators.epsilon(0.10000000000000009).as_absolute().as_exclusive(), 'val') + diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect() + self.assertEqual(self.expected_diff, diff) + def test_diff_options_with_duplicate_comparators(self): options = DiffOptions() \ .with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType()) \ @@ -410,17 +755,6 @@ def test_diff_options_with_duplicate_comparators(self): with self.assertRaisesRegex(ValueError, "A comparator for column names col1, col2 exists already."): options.with_column_name_comparator(DiffComparators.default(), 'col1', 'col2') - def test_diff_comparators(self): - jvm = self.spark.sparkContext._jvm - self.assertIsNotNone(DiffComparators.default()._to_java(jvm)) - self.assertIsNotNone(DiffComparators.nullSafeEqual()._to_java(jvm)) - self.assertIsNotNone(DiffComparators.epsilon(0.01)._to_java(jvm)) - self.assertIsNotNone(DiffComparators.string()._to_java(jvm)) - if jvm.uk.co.gresearch.spark.diff.comparator.DurationDiffComparator.isSupportedBySpark(): - self.assertIsNotNone(DiffComparators.duration('PT24H')._to_java(jvm)) - self.assertIsNotNone(DiffComparators.map(IntegerType(), LongType())._to_java(jvm)) - self.assertIsNotNone(DiffComparators.map(IntegerType(), LongType(), True)._to_java(jvm)) - if __name__ == '__main__': SparkTest.main(__file__) diff --git a/python/test/test_jvm.py b/python/test/test_jvm.py index 5ea02957..af1b0762 100644 --- a/python/test/test_jvm.py +++ b/python/test/test_jvm.py @@ -75,19 +75,6 @@ def test_get_jvm_check_java_pkg_is_installed(self): finally: spark._java_pkg_is_installed = is_installed - @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") - def test_diff(self): - for label, func in { - 'diff': lambda: self.df.diff(self.df), - 'diff_with_options': lambda: self.df.diff_with_options(self.df, DiffOptions()), - 'diffwith': lambda: self.df.diffwith(self.df), - 'diffwith_with_options': lambda: self.df.diffwith_with_options(self.df, DiffOptions()), - }.items(): - with self.subTest(label): - with self.assertRaises(RuntimeError) as e: - func() - self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) - @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") def test_dotnet_ticks(self): for label, func in { diff --git a/python/test/test_package.py b/python/test/test_package.py index 96d46c06..9ad43bfe 100644 --- a/python/test/test_package.py +++ b/python/test/test_package.py @@ -17,14 +17,22 @@ from subprocess import CalledProcessError from unittest import skipUnless, skipIf -from pyspark import __version__ -from pyspark.sql import Row +from pyspark import __version__, SparkContext +from pyspark.sql import Row, SparkSession, SQLContext from pyspark.sql.functions import col, count -from gresearch.spark import dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \ +from gresearch.spark import backticks, distinct_prefix_for, handle_configured_case_sensitivity, \ + list_contains_case_sensitivity, list_filter_case_sensitivity, list_diff_case_sensitivity, \ + dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \ timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, count_null from spark_common import SparkTest +try: + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + has_connect = True +except ImportError: + has_connect = False + POETRY_PYTHON_ENV = "POETRY_PYTHON" RICH_SOURCES_ENV = "RICH_SOURCES" @@ -105,6 +113,72 @@ def compare_dfs(self, expected, actual): [row.asDict() for row in expected.collect()] ) + def test_backticks(self): + self.assertEqual(backticks("column"), "column") + self.assertEqual(backticks("a.column"), "`a.column`") + self.assertEqual(backticks("`a.column`"), "`a.column`") + self.assertEqual(backticks("column", "a.field"), "column.`a.field`") + self.assertEqual(backticks("a.column", "a.field"), "`a.column`.`a.field`") + self.assertEqual(backticks("the.alias", "a.column", "a.field"), "`the.alias`.`a.column`.`a.field`") + + def test_distinct_prefix_for(self): + self.assertEqual(distinct_prefix_for([]), "_") + self.assertEqual(distinct_prefix_for(["a"]), "_") + self.assertEqual(distinct_prefix_for(["abc"]), "_") + self.assertEqual(distinct_prefix_for(["a", "bc", "def"]), "_") + self.assertEqual(distinct_prefix_for(["_a"]), "__") + self.assertEqual(distinct_prefix_for(["_abc"]), "__") + self.assertEqual(distinct_prefix_for(["a", "_bc", "__def"]), "___") + + def test_handle_configured_case_sensitivity(self): + case_sensitive = False + with self.subTest(case_sensitive=case_sensitive): + self.assertEqual(handle_configured_case_sensitivity('abc', case_sensitive), 'abc') + self.assertEqual(handle_configured_case_sensitivity('AbC', case_sensitive), 'abc') + self.assertEqual(handle_configured_case_sensitivity('ABC', case_sensitive), 'abc') + + case_sensitive = True + with self.subTest(case_sensitive=case_sensitive): + self.assertEqual(handle_configured_case_sensitivity('abc', case_sensitive), 'abc') + self.assertEqual(handle_configured_case_sensitivity('AbC', case_sensitive), 'AbC') + self.assertEqual(handle_configured_case_sensitivity('ABC', case_sensitive), 'ABC') + + def test_list_contains_case_sensitivity(self): + the_list = ['abc', 'Def', 'GhI', 'JKL'] + self.assertEqual(list_contains_case_sensitivity(the_list, 'a', case_sensitive=False), False) + self.assertEqual(list_contains_case_sensitivity(the_list, 'abc', case_sensitive=False), True) + self.assertEqual(list_contains_case_sensitivity(the_list, 'deF', case_sensitive=False), True) + self.assertEqual(list_contains_case_sensitivity(the_list, 'JKL', case_sensitive=False), True) + + self.assertEqual(list_contains_case_sensitivity(the_list, 'a', case_sensitive=True), False) + self.assertEqual(list_contains_case_sensitivity(the_list, 'abc', case_sensitive=True), True) + self.assertEqual(list_contains_case_sensitivity(the_list, 'deF', case_sensitive=True), False) + self.assertEqual(list_contains_case_sensitivity(the_list, 'JKL', case_sensitive=True), True) + + def test_list_filter_case_sensitivity(self): + the_list = ['abc', 'Def', 'GhI', 'JKL'] + self.assertEqual(list_filter_case_sensitivity(the_list, ['a'], case_sensitive=False), []) + self.assertEqual(list_filter_case_sensitivity(the_list, ['abc'], case_sensitive=False), ['abc']) + self.assertEqual(list_filter_case_sensitivity(the_list, ['deF'], case_sensitive=False), ['Def']) + self.assertEqual(list_filter_case_sensitivity(the_list, ['JKL'], case_sensitive=False), ['JKL']) + + self.assertEqual(list_filter_case_sensitivity(the_list, ['a'], case_sensitive=True), []) + self.assertEqual(list_filter_case_sensitivity(the_list, ['abc'], case_sensitive=True), ['abc']) + self.assertEqual(list_filter_case_sensitivity(the_list, ['deF'], case_sensitive=True), []) + self.assertEqual(list_filter_case_sensitivity(the_list, ['JKL'], case_sensitive=True), ['JKL']) + + def test_list_diff_case_sensitivity(self): + the_list = ['abc', 'Def', 'GhI', 'JKL'] + self.assertEqual(list_diff_case_sensitivity(the_list, ['a'], case_sensitive=False), the_list) + self.assertEqual(list_diff_case_sensitivity(the_list, ['abc'], case_sensitive=False), ['Def', 'GhI', 'JKL']) + self.assertEqual(list_diff_case_sensitivity(the_list, ['deF'], case_sensitive=False), ['abc', 'GhI', 'JKL']) + self.assertEqual(list_diff_case_sensitivity(the_list, ['JKL'], case_sensitive=False), ['abc', 'Def', 'GhI']) + + self.assertEqual(list_diff_case_sensitivity(the_list, ['a'], case_sensitive=True), the_list) + self.assertEqual(list_diff_case_sensitivity(the_list, ['abc'], case_sensitive=True), ['Def', 'GhI', 'JKL']) + self.assertEqual(list_diff_case_sensitivity(the_list, ['deF'], case_sensitive=True), the_list) + self.assertEqual(list_diff_case_sensitivity(the_list, ['JKL'], case_sensitive=True), ['abc', 'Def', 'GhI']) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks") def test_dotnet_ticks_to_timestamp(self): for column in ["tick", self.ticks.tick]: @@ -165,6 +239,14 @@ def test_count_null(self): ).collect() self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual) + def test_session(self): + self.assertIsNotNone(self.ticks.session()) + self.assertIsInstance(self.ticks.session(), tuple(([SparkSession] + ([ConnectSparkSession] if has_connect else [])))) + + def test_session_or_ctx(self): + self.assertIsNotNone(self.ticks.session_or_ctx()) + self.assertIsInstance(self.ticks.session_or_ctx(), tuple(([SparkSession, SQLContext] + ([ConnectSparkSession] if has_connect else [])))) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by create_temp_dir") def test_create_temp_dir(self): from pyspark import SparkFiles diff --git a/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala b/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala index 8636bac3..f569ee5b 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala @@ -100,7 +100,7 @@ class Differ(options: DiffOptions) { s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(", ")}" ) - val diffValueColumns = getDiffColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1).diff(pkColumns) + val diffValueColumns = getDiffValueColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1) if (Seq(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode)) { require( @@ -145,7 +145,7 @@ class Differ(options: DiffOptions) { private def getChangeColumn( existsColumnName: String, - valueVolumnsWithComparator: Seq[(String, DiffComparator)], + valueColumnsWithComparator: Seq[(String, DiffComparator)], left: Dataset[_], right: Dataset[_] ): Option[Column] = { @@ -153,7 +153,7 @@ class Differ(options: DiffOptions) { .map(changeColumn => when(left(existsColumnName).isNull || right(existsColumnName).isNull, lit(null)) .otherwise( - Some(valueVolumnsWithComparator) + Some(valueColumnsWithComparator) .filter(_.nonEmpty) .map(columns => concat( @@ -171,15 +171,21 @@ class Differ(options: DiffOptions) { ) } - private[diff] def getDiffColumns[T, U]( + private[diff] def getDiffIdColumns[T, U]( + pkColumns: Seq[String], + left: Dataset[T], + right: Dataset[U], + ): Seq[(String, Column)] = { + pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c)) + } + + private[diff] def getDiffValueColumns[T, U]( pkColumns: Seq[String], valueColumns: Seq[String], left: Dataset[T], right: Dataset[U], ignoreColumns: Seq[String] ): Seq[(String, Column)] = { - val idColumns = pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c)) - val leftValueColumns = left.columns.filterIsInCaseSensitivity(valueColumns) val rightValueColumns = right.columns.filterIsInCaseSensitivity(valueColumns) @@ -230,7 +236,7 @@ class Differ(options: DiffOptions) { val prefixedLeftIgnoredColumns = leftIgnoredColumns.map(c => aliasLeft(c)) val prefixedRightIgnoredColumns = rightIgnoredColumns.map(c => aliasRight(c)) - val nonIdColumns = options.diffMode match { + options.diffMode match { case DiffMode.ColumnByColumn => valueColumns.flatMap(c => Seq( @@ -256,7 +262,16 @@ class Differ(options: DiffOptions) { else rightIgnoredColumns.map(alias(None, rightValues)) ) } - idColumns ++ nonIdColumns + } + + private[diff] def getDiffColumns[T, U]( + pkColumns: Seq[String], + valueColumns: Seq[String], + left: Dataset[T], + right: Dataset[U], + ignoreColumns: Seq[String] + ): Seq[(String, Column)] = { + getDiffIdColumns(pkColumns, left, right) ++ getDiffValueColumns(pkColumns, valueColumns, left, right, ignoreColumns) } private def doDiff[T, U]( @@ -271,14 +286,14 @@ class Differ(options: DiffOptions) { val pkColumns = if (idColumns.isEmpty) columns else idColumns val valueColumns = columns.diffCaseSensitivity(pkColumns) val valueStructFields = left.schema.fields.map(f => f.name -> f).toMap - val valueVolumnsWithComparator = valueColumns.map(c => c -> options.comparatorFor(valueStructFields(c))) + val valueColumnsWithComparator = valueColumns.map(c => c -> options.comparatorFor(valueStructFields(c))) val existsColumnName = distinctPrefixFor(left.columns) + "exists" val leftWithExists = left.withColumn(existsColumnName, lit(1)) val rightWithExists = right.withColumn(existsColumnName, lit(1)) val joinCondition = pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _) - val unChanged = valueVolumnsWithComparator + val unChanged = valueColumnsWithComparator .map { case (c, cmp) => cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c))) } @@ -294,7 +309,7 @@ class Differ(options: DiffOptions) { .as(options.diffColumn) val diffColumns = getDiffColumns(pkColumns, valueColumns, left, right, ignoreColumns).map(_._2) - val changeColumn = getChangeColumn(existsColumnName, valueVolumnsWithComparator, leftWithExists, rightWithExists) + val changeColumn = getChangeColumn(existsColumnName, valueColumnsWithComparator, leftWithExists, rightWithExists) // turn this column into a sequence of one or none column so we can easily concat it below with diffActionColumn and diffColumns .map(Seq(_)) .getOrElse(Seq.empty[Column]) diff --git a/src/main/scala/uk/co/gresearch/spark/package.scala b/src/main/scala/uk/co/gresearch/spark/package.scala index 682e1063..9a6453f2 100644 --- a/src/main/scala/uk/co/gresearch/spark/package.scala +++ b/src/main/scala/uk/co/gresearch/spark/package.scala @@ -40,6 +40,8 @@ package object spark extends Logging with SparkVersion with BuildVersion { * distinct prefix */ private[spark] def distinctPrefixFor(existing: Seq[String]): String = { + // count number of suffix _ for each existing column name + // return string with one more _ than that "_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1) } diff --git a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala index 1e8a2b78..270e12db 100644 --- a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala @@ -764,8 +764,17 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { } } + test("distinct prefix for") { + assert(distinctPrefixFor(Seq.empty[String]) === "_") + assert(distinctPrefixFor(Seq("a")) === "_") + assert(distinctPrefixFor(Seq("abc")) === "_") + assert(distinctPrefixFor(Seq("a", "bc", "def")) === "_") + assert(distinctPrefixFor(Seq("_a")) === "__") + assert(distinctPrefixFor(Seq("_abc")) === "__") + assert(distinctPrefixFor(Seq("a", "_bc", "__def")) === "___") + } + test("Spark temp dir") { - import uk.co.gresearch.spark.createTemporaryDir val dir = createTemporaryDir("test") assert(Paths.get(dir).toAbsolutePath.toString.startsWith(SparkFiles.getRootDirectory())) } diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala index 1fd8f2c4..da61d31b 100644 --- a/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala @@ -28,6 +28,7 @@ case class Value(id: Int, value: Option[String]) case class Value2(id: Int, seq: Option[Int], value: Option[String]) case class Value3(id: Int, left_value: String, right_value: String, value: String) case class Value4(id: Int, diff: String) +case class Value4b(id: Int, change: String) case class Value5(first_id: Int, id: String) case class Value6(id: Int, label: String) case class Value7(id: Int, value: Option[String], label: Option[String]) @@ -332,16 +333,6 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { lazy val expectedDiffWith8and9up: Seq[(String, Value8, Value9up)] = expectedDiffWith8and9.map(t => t.copy(_3 = Option(t._3).map(v => Value9up(v.id, v.seq, v.value, v.info)).orNull)) - test("distinct prefix for") { - assert(distinctPrefixFor(Seq.empty[String]) === "_") - assert(distinctPrefixFor(Seq("a")) === "_") - assert(distinctPrefixFor(Seq("abc")) === "_") - assert(distinctPrefixFor(Seq("a", "bc", "def")) === "_") - assert(distinctPrefixFor(Seq("_a")) === "__") - assert(distinctPrefixFor(Seq("_abc")) === "__") - assert(distinctPrefixFor(Seq("a", "_bc", "__def")) === "___") - } - test("diff dataframe with duplicate columns") { val df = Seq(1).toDF("id").select($"id", $"id") @@ -725,6 +716,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { "The column prefixes 'a' and 'b', together with these non-id columns " + "must not produce the diff column name 'a_value': value" ) + doTestRequirement( + left.diff(right, options.withDiffColumn("b_value"), "id"), + "The column prefixes 'a' and 'b', together with these non-id columns " + + "must not produce the diff column name 'b_value': value" + ) } test("diff with left-side mode where non-id column would produce diff column name") { @@ -739,7 +735,7 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { test("diff with right-side mode where non-id column would produce diff column name") { val options = DiffOptions.default - .withDiffColumn("a_value") + .withDiffColumn("b_value") .withLeftColumnPrefix("a") .withRightColumnPrefix("b") .withDiffMode(DiffMode.RightSide) @@ -759,6 +755,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { "The column prefixes 'A' and 'B', together with these non-id columns " + "must not produce the diff column name 'a_value': value" ) + doTestRequirement( + left.diff(right, options.withDiffColumn("b_value"), "id"), + "The column prefixes 'A' and 'B', together with these non-id columns " + + "must not produce the diff column name 'b_value': value" + ) } } @@ -806,6 +807,63 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { } } + test("diff with id column change in T") { + val left = Seq(Value4b(1, "change")).toDS() + val right = Seq(Value4b(1, "Change")).toDS() + + val options = DiffOptions.default.withChangeColumn("change") + + doTestRequirement( + left.diff(right, options), + "The id columns must not contain the change column name 'change': id, change" + ) + doTestRequirement( + left.diff(right, options, "change"), + "The id columns must not contain the change column name 'change': change" + ) + doTestRequirement( + left.diff(right, options, "change", "id"), + "The id columns must not contain the change column name 'change': change, id" + ) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + doTestRequirement( + left + .withColumnRenamed("change", "Change") + .diff(right.withColumnRenamed("change", "Change"), options, "Change", "id"), + "The id columns must not contain the change column name 'change': Change, id" + ) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + left + .withColumnRenamed("change", "Change") + .diff(right.withColumnRenamed("change", "Change"), options, "Change", "id") + } + } + + test("diff with non-id column change in T") { + val left = Seq(Value4b(1, "change")).toDS() + val right = Seq(Value4b(1, "Change")).toDS() + + val options = DiffOptions.default.withChangeColumn("change") + + val actual = left.diff(right, options, "id") + val expectedColumns = Seq( + "diff", + "change", + "id", + "left_change", + "right_change" + ) + val expectedDiff = Seq( + Row("C", Seq("change"), 1, "change", "Change") + ) + + assert(actual.columns === expectedColumns) + assert(actual.collect() === expectedDiff) + } + test("diff where non-id column produces change column name") { val options = DiffOptions.default .withChangeColumn("a_value")