Support diff for Spark Connect (Dataset API) (#251)
Adds support for `diff` in Spark Connect environments by implementing the Dataset diff logic in Python.
EnricoMi authored Nov 4, 2024
1 parent 4d35579 commit e50ddfb
Showing 13 changed files with 912 additions and 128 deletions.
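A minimal usage sketch of what this commit enables; the Connect endpoint, column names, and data are illustrative assumptions, while the import follows the documented `gresearch.spark.diff` package:

```python
from pyspark.sql import SparkSession
from gresearch.spark.diff import *  # patches DataFrame with diff / diffwith

# connect to a Spark Connect server; the endpoint is an illustrative assumption
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

left = spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"])
right = spark.createDataFrame([(1, "one"), (2, "Two"), (4, "four")], ["id", "value"])

# previously unsupported against Connect sessions (see the DIFF.md note removed below);
# with this commit the diff is computed purely in Python
left.diff(right, "id").show()
```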
2 changes: 0 additions & 2 deletions DIFF.md
@@ -406,8 +406,6 @@ The latter variant is prefixed with `_with_options`.
* `def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame:`
* `def diffwith(self: DataFrame, other: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`

Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).

## Diff Spark application

There is also a Spark application that can be used to create a diff DataFrame. The application reads two DataFrames
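For the `diffwith` signatures listed in the DIFF.md excerpt above, a brief sketch of both call shapes, reusing `left` and `right` from the earlier snippet (illustrative, not taken from the project's documentation examples):

```python
# variadic id columns
left.diffwith(right, "id").show()

# explicit id-column and ignore-column lists (the second overload listed above)
left.diffwith(right, ["id"], ["value"]).show()
```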
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

This project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python:

**[Diff](DIFF.md) [<sup>[*]</sup>](#spark-connect-server):** A `diff` transformation and application for `Dataset`s that computes the differences between
**[Diff](DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between
two datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other.

**[SortedGroups](GROUPS.md):** A `groupByKey` transformation that groups rows by a key while providing
49 changes: 48 additions & 1 deletion python/gresearch/spark/__init__.py
@@ -21,7 +21,7 @@
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Union, List, Optional, Mapping, TYPE_CHECKING
from typing import Any, Union, List, Optional, Mapping, Iterable, TYPE_CHECKING

from py4j.java_gateway import JVMView, JavaObject
from pyspark import __version__
@@ -30,6 +30,7 @@
from pyspark.sql import DataFrame, DataFrameReader, SQLContext
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.context import SQLContext
from pyspark import SparkConf
from pyspark.sql.functions import col, count, lit, when
from pyspark.sql.session import SparkSession
from pyspark.storagelevel import StorageLevel
@@ -106,6 +107,46 @@ def _to_map(jvm: JVMView, map: Mapping[Any, Any]) -> JavaObject:
    return jvm.scala.collection.JavaConverters.mapAsScalaMap(map)


def backticks(*name_parts: str) -> str:
    return '.'.join([f'`{part}`'
                     if '.' in part and not part.startswith('`') and not part.endswith('`')
                     else part
                     for part in name_parts])
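Illustrative values for the helper above, traced from its implementation rather than taken from the project's tests:

```python
>>> backticks("id")
'id'
>>> backticks("a.column")
'`a.column`'
>>> backticks("prefix", "a.column", "x")
'prefix.`a.column`.x'
```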


def distinct_prefix_for(existing: List[str]) -> str:
    # count the number of leading _ for each existing column name
    length = 1
    if existing:
        length = max([len(name) - len(name.lstrip('_')) for name in existing]) + 1
    # return a string with one more _ than that
    return '_' * length
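Expected behaviour of `distinct_prefix_for`, again traced from the code above rather than taken from tests:

```python
>>> distinct_prefix_for([])
'_'
>>> distinct_prefix_for(["id", "_internal"])
'__'
>>> distinct_prefix_for(["__tmp"])
'___'
```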


def handle_configured_case_sensitivity(column_name: str, case_sensitive: bool) -> str:
    """
    Produces a column name that considers the configured case-sensitivity of column names. When case sensitivity is
    deactivated, it lower-cases the given column name, and no-ops otherwise.
    """
    if case_sensitive:
        return column_name
    return column_name.lower()


def list_contains_case_sensitivity(column_names: Iterable[str], columnName: str, case_sensitive: bool) -> bool:
    return handle_configured_case_sensitivity(columnName, case_sensitive) in [handle_configured_case_sensitivity(c, case_sensitive) for c in column_names]


def list_filter_case_sensitivity(column_names: Iterable[str], filter: Iterable[str], case_sensitive: bool) -> List[str]:
    filter_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in filter}
    return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) in filter_set]


def list_diff_case_sensitivity(column_names: Iterable[str], other: Iterable[str], case_sensitive: bool) -> List[str]:
    other_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in other}
    return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) not in other_set]
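A small sketch of how the case-sensitivity helpers above behave for both settings; the column names are illustrative:

```python
cols = ["Id", "Value", "extra"]

# membership test honours the configured case sensitivity
list_contains_case_sensitivity(cols, "id", case_sensitive=False)  # True
list_contains_case_sensitivity(cols, "id", case_sensitive=True)   # False

# keep only the columns named in the filter
list_filter_case_sensitivity(cols, ["ID", "EXTRA"], case_sensitive=False)  # ['Id', 'extra']

# drop the columns named in other
list_diff_case_sensitivity(cols, ["value"], case_sensitive=False)  # ['Id', 'extra']
```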


def dotnet_ticks_to_timestamp(tick_column: Union[str, Column]) -> Column:
    """
    Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be
@@ -386,12 +427,18 @@ def with_row_numbers(self: DataFrame,
ConnectDataFrame.with_row_numbers = with_row_numbers


def session(self: DataFrame) -> SparkSession:
    return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx.sparkSession


def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
    return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx


DataFrame.session = session
DataFrame.session_or_ctx = session_or_ctx
if has_connect:
    ConnectDataFrame.session = session
    ConnectDataFrame.session_or_ctx = session_or_ctx
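These helpers are monkey-patched onto both the classic and the Connect `DataFrame` classes; a hedged sketch of the resulting call sites, assuming an illustrative Connect endpoint:

```python
from pyspark.sql import SparkSession
import gresearch.spark  # importing the module applies the DataFrame patches above

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # endpoint is illustrative
df = spark.range(3)

df.session()         # the owning SparkSession, on both classic and Connect DataFrames
df.session_or_ctx()  # SparkSession here; an SQLContext on older classic DataFrames
```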



