diff --git a/RELEASE.md b/RELEASE.md
index 92e4fb89ac..a410754fbb 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -19,6 +19,7 @@
 ## Major features and improvements
 * Kedro plugins can now override built-in CLI commands.
 * Added a `before_command_run` hook for plugins to add extra behaviour before Kedro CLI commands run.
+* Added support for `partition` to `SparkHiveDataSet`.
 
 ## Bug fixes and other changes
 * `TemplatedConfigLoader` now correctly inserts default values when no globals are supplied.
@@ -32,6 +33,7 @@
 ## Minor breaking changes to the API
 
 ## Thanks for supporting contributions
+[Breno Silva](https://github.com/brendalf)
 
 # Release 0.17.2
 
diff --git a/kedro/extras/datasets/spark/spark_hive_dataset.py b/kedro/extras/datasets/spark/spark_hive_dataset.py
index 6b08166c21..7efc682c62 100644
--- a/kedro/extras/datasets/spark/spark_hive_dataset.py
+++ b/kedro/extras/datasets/spark/spark_hive_dataset.py
@@ -85,6 +85,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         )
 
 
+# pylint: disable=too-many-instance-attributes
 class SparkHiveDataSet(AbstractDataSet):
     """``SparkHiveDataSet`` loads and saves Spark dataframes stored on Hive.
     This data set also handles some incompatible file types such as using partitioned parquet on
@@ -121,8 +122,13 @@ class SparkHiveDataSet(AbstractDataSet):
         >>> reloaded.take(4)
     """
 
-    def __init__(
-        self, database: str, table: str, write_mode: str, table_pk: List[str] = None
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        database: str,
+        table: str,
+        write_mode: str,
+        table_pk: List[str] = None,
+        partition: str = None,
     ) -> None:
         """Creates a new instance of ``SparkHiveDataSet``.
 
@@ -132,7 +138,8 @@ def __init__(
             write_mode: ``insert``, ``upsert`` or ``overwrite`` are supported.
             table_pk: If performing an upsert, this identifies the primary key columns used to
                 resolve preexisting data. Is required for ``write_mode="upsert"``.
-
+            partition: The partition in which the data should be inserted. It only works
+                if the table already exists.
         Raises:
             DataSetError: Invalid configuration supplied
         """
@@ -151,6 +158,13 @@ def __init__(
         self._database = database
         self._table = table
         self._stage_table = "_temp_" + table
+        self._partition = partition
+
+        # get the name of each partition column
+        self._partitions: List[str] = []
+        if self._partition is not None:
+            for partition_set in self._partition.split(","):
+                self._partitions.append(partition_set.split("=")[0].strip())
 
         # self._table_columns is set up in _save() to speed up initialization
         self._table_columns = []  # type: List[str]
@@ -169,6 +183,7 @@ def _get_spark() -> SparkSession:
 
     def _create_empty_hive_table(self, data):
         data.createOrReplaceTempView("tmp")
+
         self._get_spark().sql(
             f"create table {self._database}.{self._table} select * from tmp limit 1"  # nosec
         )
@@ -188,7 +203,9 @@ def _save(self, data: DataFrame) -> None:
             self._create_empty_hive_table(data)
             self._table_columns = data.columns
         else:
-            self._table_columns = self._load().columns
+            self._table_columns = list(
+                set(self._load().columns) - set(self._partitions)
+            )
             if self._write_mode == "upsert":
                 non_existent_columns = set(self._table_pk) - set(self._table_columns)
                 if non_existent_columns:
@@ -209,8 +226,13 @@ def _save(self, data: DataFrame) -> None:
     def _insert_save(self, data: DataFrame) -> None:
         data.createOrReplaceTempView("tmp")
         columns = ", ".join(self._table_columns)
+
+        partition = f"partition ({self._partition})" if self._partition else ""
         self._get_spark().sql(
-            f"insert into {self._database}.{self._table} select {columns} from tmp"  # nosec
+            f"""
+            insert into {self._database}.{self._table} {partition}
+            select {columns} from tmp
+            """  # nosec
         )
 
     def _upsert_save(self, data: DataFrame) -> None:
@@ -242,7 +264,10 @@ def _overwrite_save(self, data: DataFrame) -> None:
 
     def _validate_save(self, data: DataFrame):
         hive_dtypes = set(self._load().dtypes)
+        if self._partitions:
+            hive_dtypes = {(k, v) for k, v in hive_dtypes if k not in self._partitions}
         data_dtypes = set(data.dtypes)
+
         if data_dtypes != hive_dtypes:
             new_cols = data_dtypes - hive_dtypes
             missing_cols = hive_dtypes - data_dtypes
diff --git a/tests/extras/datasets/spark/test_spark_hive_dataset.py b/tests/extras/datasets/spark/test_spark_hive_dataset.py
index 8db7b9aa00..dc2515f926 100644
--- a/tests/extras/datasets/spark/test_spark_hive_dataset.py
+++ b/tests/extras/datasets/spark/test_spark_hive_dataset.py
@@ -322,3 +322,79 @@ def test_read_from_non_existent_table(self):
             match="Requested table not found: default_1.table_doesnt_exist",
         ):
             dataset.load()
+
+    def test_insert_empty_table_with_partition(self, spark_hive_session):
+        spark_hive_session.sql(
+            """
+            create table default_1.test_insert_empty_table_with_partition
+            (name string, age integer)
+            partitioned by (ref integer)
+            """
+        ).take(1)
+        dataset = SparkHiveDataSet(
+            database="default_1",
+            table="test_insert_empty_table_with_partition",
+            write_mode="insert",
+            partition="ref = 1",
+        )
+        dataset.save(_generate_spark_df_one())
+        assert_df_equal(dataset.load().drop("ref"), _generate_spark_df_one())
+
+    def test_insert_to_non_existent_table_with_partition(self):
+        dataset = SparkHiveDataSet(
+            database="default_1",
+            table="table_with_partition_doesnt_exist",
+            write_mode="insert",
+            partition="ref = 1",
+        )
+        with pytest.raises(
+            DataSetError,
+            match=r"Failed while saving data to data set SparkHiveDataSet"
+            r"\(database\=default_1, table\=table_with_partition_doesnt_exist, "
+            r"table_pk\=\[\], write_mode\=insert\)\.\n"
+            r"ref is not a valid partition column in table"
+            r"\`default_1\`.`table_with_partition_doesnt_exist`",
+        ):
+            dataset.save(_generate_spark_df_one())
+
+    def test_upsert_not_empty_table_with_partition(self, spark_hive_session):
+        spark_hive_session.sql(
+            """
+            create table default_1.test_upsert_not_empty_table_with_partition
+            (name string, age integer)
+            partitioned by (ref integer)
+            """
+        ).take(1)
+        dataset = SparkHiveDataSet(
+            database="default_1",
+            table="test_upsert_not_empty_table_with_partition",
+            write_mode="upsert",
+            partition="ref = 1",
+            table_pk=["name"],
+        )
+        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_upsert())
+
+        assert_df_equal(
+            dataset.load().drop("ref").sort("name"),
+            _generate_spark_df_upsert_expected().sort("name"),
+        )
+
+    def test_overwrite_not_empty_table_with_partition(self, spark_hive_session):
+        spark_hive_session.sql(
+            """
+            create table default_1.test_overwrite_not_empty_table_with_partition
+            (name string, age integer)
+            partitioned by (ref integer)
+            """
+        ).take(1)
+        dataset = SparkHiveDataSet(
+            database="default_1",
+            table="test_overwrite_not_empty_table_with_partition",
+            write_mode="overwrite",
+            partition="ref = 1",
+            table_pk=["name"],
+        )
+        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one())
+        assert_df_equal(dataset.load().drop("ref"), _generate_spark_df_one())
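For reference, a minimal end-to-end sketch of how the new `partition` argument is exercised, pieced together from the docstring and the tests above. The database name, table name and DataFrame contents are illustrative only; what the diff itself requires is that the target table already exists with a matching `partitioned by` clause, and that `partition` is given in the `"column = value"` form that is forwarded to Hive's `insert into ... partition (...)` clause.

    from pyspark.sql import SparkSession

    from kedro.extras.datasets.spark import SparkHiveDataSet

    spark = SparkSession.builder.enableHiveSupport().getOrCreate()

    # Hypothetical partitioned target table, mirroring the test fixtures above;
    # per the docstring, the table must exist before inserting into a partition.
    spark.sql(
        """
        create table if not exists default_1.people
        (name string, age integer)
        partitioned by (ref integer)
        """
    )

    data_set = SparkHiveDataSet(
        database="default_1",
        table="people",  # illustrative name, not taken from the PR
        write_mode="insert",
        partition="ref = 1",  # static partition spec, same form as in the tests
    )

    # The saved DataFrame carries only the non-partition columns; the value of
    # `ref` comes from the partition spec above.
    data_set.save(spark.createDataFrame([("Alex", 31), ("Bob", 12)], ["name", "age"]))

    reloaded = data_set.load()  # the `ref` partition column is included on load

Because `load()` still returns the partition column, the tests above call `.drop("ref")` before comparing the reloaded DataFrame with the original one.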