diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index 9b5ac1e6c7..69447fb07b 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -425,12 +425,15 @@ def write_training_dataset( split_dataset = self._split_df( query_obj, training_dataset, read_options=read_options ) + for key in split_dataset: + if training_dataset.coalesce: + split_dataset[key] = split_dataset[key].coalesce(1) + + split_dataset[key] = split_dataset[key].cache() + transformation_function_engine.TransformationFunctionEngine.populate_builtin_transformation_functions( training_dataset, feature_view_obj, split_dataset ) - if training_dataset.coalesce: - for key in split_dataset: - split_dataset[key] = split_dataset[key].coalesce(1) return self._write_training_dataset_splits( training_dataset, split_dataset, write_options, save_mode, to_df=to_df ) @@ -539,6 +542,8 @@ def _write_training_dataset_single( save_mode ).save(path) + feature_dataframe.unpersist() + def read(self, storage_connector, data_format, read_options, location): if isinstance(location, str): if data_format.lower() in ["delta", "parquet", "hudi", "orc", "bigquery"]: diff --git a/python/tests/engine/test_spark.py b/python/tests/engine/test_spark.py index b4725720ad..736e6064b6 100644 --- a/python/tests/engine/test_spark.py +++ b/python/tests/engine/test_spark.py @@ -45,6 +45,7 @@ expectation_suite, training_dataset_feature, ) +from hsfs.core import training_dataset_engine from hsfs.engine import spark from hsfs.constructor import query, hudi_feature_group_alias from hsfs.client import exceptions @@ -1570,6 +1571,123 @@ def test_write_training_dataset(self, mocker): assert mock_spark_engine_write_training_dataset_single.call_count == 0 assert mock_spark_engine_write_training_dataset_splits.call_count == 0 + def test_write_training_dataset_to_df(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.engine.get_type", return_value="python") + mocker.patch("hsfs.client.get_instance") + + spark_engine = spark.Engine() + + jsonq = backend_fixtures["query"]["get"]["response"] + q = query.Query.from_response_json(jsonq) + + mock_query_read = mocker.patch("hsfs.constructor.query.Query.read") + d = { + "col_0": [1, 2], + "col_1": ["test_1", "test_2"], + "col_2": [3, 4], + "event_time": [1, 2], + } + df = pd.DataFrame(data=d) + query_df = spark_engine._spark_session.createDataFrame(df) + mock_query_read.side_effect = [query_df] + + td = training_dataset.TrainingDataset( + name="test", + version=None, + splits={}, + event_start_time=None, + event_end_time=None, + description="test", + storage_connector=None, + featurestore_id=10, + data_format="tsv", + location="", + statistics_config=None, + training_dataset_type=training_dataset.TrainingDataset.IN_MEMORY, + extra_filter=None, + transformation_functions={}, + ) + + # Act + df_returned = spark_engine.write_training_dataset( + training_dataset=td, + query_obj=q, + user_write_options={}, + save_mode=training_dataset_engine.TrainingDatasetEngine.OVERWRITE, + read_options={}, + feature_view_obj=None, + to_df=True, + ) + + # Assert + assert set(df_returned.columns) == {"col_0", "col_1", "col_2", "event_time"} + assert df_returned.count() == 2 + assert df_returned.exceptAll(query_df).rdd.isEmpty() + + def test_write_training_dataset_split_to_df(self, mocker, backend_fixtures): + # Arrange + mocker.patch("hsfs.engine.get_type", return_value="python") + mocker.patch("hsfs.client.get_instance") + + spark_engine = spark.Engine() + + jsonq = backend_fixtures["query"]["get"]["response"] + q = query.Query.from_response_json(jsonq) + + mock_query_read = mocker.patch("hsfs.constructor.query.Query.read") + d = { + "col_0": [1, 2], + "col_1": ["test_1", "test_2"], + "col_2": [3, 4], + "event_time": [1, 2], + } + df = pd.DataFrame(data=d) + query_df = spark_engine._spark_session.createDataFrame(df) + mock_query_read.side_effect = [query_df] + + td = training_dataset.TrainingDataset( + name="test", + version=None, + splits={}, + test_size=0.5, + train_start=None, + train_end=None, + test_start=None, + test_end=None, + time_split_size=2, + description="test", + storage_connector=None, + featurestore_id=12, + data_format="tsv", + location="", + statistics_config=None, + training_dataset_type=training_dataset.TrainingDataset.IN_MEMORY, + extra_filter=None, + seed=1, + transformation_functions={}, + ) + + # Act + split_dfs_returned = spark_engine.write_training_dataset( + training_dataset=td, + query_obj=q, + user_write_options={}, + save_mode=training_dataset_engine.TrainingDatasetEngine.OVERWRITE, + read_options={}, + feature_view_obj=None, + to_df=True, + ) + + # Assert + sum_rows = 0 + for key in split_dfs_returned: + df_returned = split_dfs_returned[key] + assert set(df_returned.columns) == {"col_0", "col_1", "col_2", "event_time"} + sum_rows += df_returned.count() + + assert sum_rows == 2 + def test_write_training_dataset_query(self, mocker): # Arrange mocker.patch("hsfs.engine.get_type") @@ -1967,9 +2085,11 @@ def test_random_split(self, mocker): # Assert assert list(result) == ["test_split1", "test_split2"] + sum_rows = 0 for column in list(result): assert result[column].schema == spark_df.schema - assert not result[column].rdd.isEmpty() + sum_rows += result[column].count() + assert sum_rows == 6 def test_time_series_split(self, mocker): # Arrange diff --git a/python/tests/fixtures/training_dataset_fixtures.json b/python/tests/fixtures/training_dataset_fixtures.json index 2d96b827c7..026c06c242 100644 --- a/python/tests/fixtures/training_dataset_fixtures.json +++ b/python/tests/fixtures/training_dataset_fixtures.json @@ -9,7 +9,7 @@ "location": "test_location", "event_start_time": 1646438400000, "event_end_time": 1646697600000, - "coalesce": "test_coalesce", + "coalesce": true, "description": "test_description", "storage_connector": { "type": "featurestoreJdbcConnectorDTO", @@ -43,8 +43,8 @@ "end_time": "test_end_time" } ], - "validation_size": 2, - "test_size": 3, + "validation_size": 0.0, + "test_size": 0.5, "train_start": 4, "train_end": 5, "validation_start": 6, diff --git a/python/tests/test_training_dataset.py b/python/tests/test_training_dataset.py index 9750b90c1d..caf79916e5 100644 --- a/python/tests/test_training_dataset.py +++ b/python/tests/test_training_dataset.py @@ -43,15 +43,15 @@ def test_from_response_json(self, mocker, backend_fixtures): assert td.data_format == "hudi" assert td._start_time == 1646438400000 assert td._end_time == 1646697600000 - assert td.validation_size == 2 - assert td.test_size == 3 + assert td.validation_size == 0.0 + assert td.test_size == 0.5 assert td.train_start == 4 assert td.train_end == 5 assert td.validation_start == 6 assert td.validation_end == 7 assert td.test_start == 8 assert td.test_end == 9 - assert td.coalesce == "test_coalesce" + assert td.coalesce is True assert td.seed == 123 assert td.location == "test_location" assert td._from_query == "test_from_query"