Adds partition support to SparkHiveDataSet #745

Closed
wants to merge 7 commits
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -19,6 +19,7 @@
## Major features and improvements
* Kedro plugins can now override built-in CLI commands.
* Added a `before_command_run` hook for plugins to add extra behaviour before Kedro CLI commands run.
* Added support for `partition` to `SparkHiveDataSet`.

## Bug fixes and other changes
* `TemplatedConfigLoader` now correctly inserts default values when no globals are supplied.
@@ -32,6 +33,7 @@
## Minor breaking changes to the API

## Thanks for supporting contributions
[Breno Silva](https://github.com/brendalf)

# Release 0.17.2

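The RELEASE.md entry above announces the new `partition` option. As a rough usage sketch (the database, table, and partition names here are illustrative, mirroring the tests added later in this diff; the target table must already exist and be partitioned):

```python
from kedro.extras.datasets.spark import SparkHiveDataSet

# Hypothetical Hive table, e.g. created with `partitioned by (ref integer)`.
dataset = SparkHiveDataSet(
    database="default_1",   # illustrative names
    table="sales",
    write_mode="insert",
    partition="ref = 1",    # static partition spec: rows land in the ref=1 partition
)
# dataset.save(spark_df)   # spark_df is a pyspark DataFrame that does not carry `ref` itself
```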
35 changes: 30 additions & 5 deletions kedro/extras/datasets/spark/spark_hive_dataset.py
@@ -85,6 +85,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
)


# pylint: disable=too-many-instance-attributes
class SparkHiveDataSet(AbstractDataSet):
"""``SparkHiveDataSet`` loads and saves Spark dataframes stored on Hive.
This data set also handles some incompatible file types such as using partitioned parquet on
@@ -121,8 +122,13 @@ class SparkHiveDataSet(AbstractDataSet):
>>> reloaded.take(4)
"""

def __init__(
self, database: str, table: str, write_mode: str, table_pk: List[str] = None
def __init__( # pylint: disable=too-many-arguments
self,
database: str,
table: str,
write_mode: str,
table_pk: List[str] = None,
partition: str = None,
) -> None:
"""Creates a new instance of ``SparkHiveDataSet``.

@@ -132,7 +138,8 @@ def __init__(
write_mode: ``insert``, ``upsert`` or ``overwrite`` are supported.
table_pk: If performing an upsert, this identifies the primary key columns used to
resolve preexisting data. Is required for ``write_mode="upsert"``.

partition: The partition into which the data should be inserted. This only works
if the table already exists.
Raises:
DataSetError: Invalid configuration supplied
"""
@@ -151,6 +158,13 @@ def __init__(
self._database = database
self._table = table
self._stage_table = "_temp_" + table
self._partition = partition

# get the name of each partition
self._partitions: List[str] = []
if self._partition is not None:
for partition_set in self._partition.split(","):
self._partitions.append(partition_set.split("=")[0].strip())
Comment on lines +161 to +167

Contributor:
Would it make sense to just make the function argument `partition: List[str] = None` directly, as is done for `table_pk`? That would do away with the need for string manipulation here.

@brendalf (Author, Apr 14, 2021):
That makes sense, but with just the list of column names we can't identify a partition. Maybe we could make `partition` take a list of tuples, where each entry gives the column name and value. What do you think?

@antonymilne (Contributor, Apr 14, 2021):
Ah, I see. That could indeed work, but now I'm wondering exactly what the partition string in the insert can look like: we can have `partition_1 = x, partition_2 = y`, but we could also have just `partition_1, partition_2`, right? Do you know the full specification of this string's syntax?

Now I'm also wondering why SparkHiveDataSet builds SQL queries dynamically in the first place rather than using the Python API, which would presumably make this sort of thing much easier. Is that impossible when working with Hive? @jiriklein

@brendalf (Author):
Hi @AntonyMilneQB, sorry for the late reply.
That's right: the partition string will be something like `partition_1 = x, partition_2 = y` if the partition is static, or `partition_1, partition_2` if the partition is dynamic.

# self._table_columns is set up in _save() to speed up initialization
self._table_columns = [] # type: List[str]
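To make the parsing in `__init__` concrete, here is a small standalone sketch (not part of the diff) of how a partition spec string is reduced to its column names, covering both the static and dynamic forms discussed in the review thread above; the spec strings are made up for illustration:

```python
def partition_columns(partition: str) -> list:
    """Extract partition column names from a Hive partition spec string."""
    columns = []
    for partition_set in partition.split(","):
        # "ref = 1" -> "ref"; a dynamic entry like "ref" is already just the name
        columns.append(partition_set.split("=")[0].strip())
    return columns


print(partition_columns("ref = 1, dt = 20210414"))  # ['ref', 'dt'] (static)
print(partition_columns("ref, dt"))                 # ['ref', 'dt'] (dynamic)
```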
@@ -169,6 +183,7 @@ def _get_spark() -> SparkSession:

def _create_empty_hive_table(self, data):
data.createOrReplaceTempView("tmp")

self._get_spark().sql(
f"create table {self._database}.{self._table} select * from tmp limit 1" # nosec
)
@@ -188,7 +203,9 @@ def _save(self, data: DataFrame) -> None:
self._create_empty_hive_table(data)
self._table_columns = data.columns
else:
self._table_columns = self._load().columns
self._table_columns = list(
set(self._load().columns) - set(self._partitions)
)
if self._write_mode == "upsert":
non_existent_columns = set(self._table_pk) - set(self._table_columns)
if non_existent_columns:
@@ -209,8 +226,13 @@
def _insert_save(self, data: DataFrame) -> None:
data.createOrReplaceTempView("tmp")
columns = ", ".join(self._table_columns)

partition = f"partition ({self._partition})" if self._partition else ""
self._get_spark().sql(
f"insert into {self._database}.{self._table} select {columns} from tmp" # nosec
f"""
insert into {self._database}.{self._table} {partition}
select {columns} from tmp
""" # nosec
)

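For illustration, a sketch of the statement `_insert_save` composes when a static partition spec is supplied (the database, table, and column names here are hypothetical):

```python
# Sketch only: mirrors how the insert query is assembled above.
database, table = "default_1", "sales"
partition = "ref = 1"
table_columns = ["name", "age"]

columns = ", ".join(table_columns)
partition_clause = f"partition ({partition})" if partition else ""
query = f"insert into {database}.{table} {partition_clause} select {columns} from tmp"
print(query)
# insert into default_1.sales partition (ref = 1) select name, age from tmp
```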
def _upsert_save(self, data: DataFrame) -> None:
@@ -242,7 +264,10 @@ def _overwrite_save(self, data: DataFrame) -> None:

def _validate_save(self, data: DataFrame):
hive_dtypes = set(self._load().dtypes)
if self._partitions:
hive_dtypes = {(k, v) for k, v in hive_dtypes if k not in self._partitions}
data_dtypes = set(data.dtypes)

if data_dtypes != hive_dtypes:
new_cols = data_dtypes - hive_dtypes
missing_cols = hive_dtypes - data_dtypes
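The change to `_validate_save` drops the partition columns before comparing schemas, because a partitioned Hive table reports its partition column as part of its schema while the DataFrame being saved does not carry it. A minimal sketch with made-up dtypes:

```python
# Hive's schema includes the partition column "ref"; the incoming data does not.
hive_dtypes = {("name", "string"), ("age", "int"), ("ref", "int")}
data_dtypes = {("name", "string"), ("age", "int")}
partitions = ["ref"]

# Same filtering as in _validate_save: ignore partition columns in the comparison.
hive_dtypes = {(k, v) for k, v in hive_dtypes if k not in partitions}
assert data_dtypes == hive_dtypes  # schema check now passes
```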
76 changes: 76 additions & 0 deletions tests/extras/datasets/spark/test_spark_hive_dataset.py
@@ -322,3 +322,79 @@ def test_read_from_non_existent_table(self):
match="Requested table not found: default_1.table_doesnt_exist",
):
dataset.load()

def test_insert_empty_table_with_partition(self, spark_hive_session):
spark_hive_session.sql(
"""
create table default_1.test_insert_empty_table_with_partition
(name string, age integer)
partitioned by (ref integer)
"""
).take(1)
dataset = SparkHiveDataSet(
database="default_1",
table="test_insert_empty_table_with_partition",
write_mode="insert",
partition="ref = 1",
)
dataset.save(_generate_spark_df_one())
assert_df_equal(dataset.load().drop("ref"), _generate_spark_df_one())

def test_insert_to_non_existent_table_with_partition(self):
dataset = SparkHiveDataSet(
database="default_1",
table="table_with_partition_doesnt_exist",
write_mode="insert",
partition="ref = 1",
)
with pytest.raises(
DataSetError,
match=r"Failed while saving data to data set SparkHiveDataSet"
r"\(database\=default_1, table\=table_with_partition_doesnt_exist, "
r"table_pk\=\[\], write_mode\=insert\)\.\n"
r"ref is not a valid partition column in table"
r"\`default_1\`\.`table_with_partition_doesnt_exist`",
):
dataset.save(_generate_spark_df_one())

def test_upsert_not_empty_table_with_partition(self, spark_hive_session):
spark_hive_session.sql(
"""
create table default_1.test_upsert_not_empty_table_with_partition
(name string, age integer)
partitioned by (ref integer)
"""
).take(1)
dataset = SparkHiveDataSet(
database="default_1",
table="test_upsert_not_empty_table_with_partition",
write_mode="upsert",
partition="ref = 1",
table_pk=["name"],
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_upsert())

assert_df_equal(
dataset.load().drop("ref").sort("name"),
_generate_spark_df_upsert_expected().sort("name"),
)

def test_overwrite_not_empty_table_with_partition(self, spark_hive_session):
spark_hive_session.sql(
"""
create table default_1.test_overwrite_not_empty_table_with_partition
(name string, age integer)
partitioned by (ref integer)
"""
).take(1)
dataset = SparkHiveDataSet(
database="default_1",
table="test_overwrite_not_empty_table_with_partition",
write_mode="overwrite",
partition="ref = 1",
table_pk=["name"],
)
dataset.save(_generate_spark_df_one())
dataset.save(_generate_spark_df_one())
assert_df_equal(dataset.load().drop("ref"), _generate_spark_df_one())