[HOPSWORKS-3323] Fix TDS creation in PySpark client and add explicit caching #784

Merged
merged 5 commits on Sep 21, 2022
Changes from 3 commits
22 changes: 15 additions & 7 deletions python/hsfs/engine/spark.py
@@ -405,14 +405,17 @@ def write_training_dataset(
else:
raise ValueError("Dataset should be a query.")

if training_dataset.coalesce:
dataset = dataset.coalesce(1)

dataset = dataset.cache()
Contributor
@kennethmhc Sep 19, 2022

There is no split. Why cache here?

Contributor Author

I think I did this to cache the query. If we do the split, we cache the split result instead. I would have liked to cache the query as well when we split, but I was a bit concerned about the potential memory consumption of caching both before and after the split.

Contributor

If there is no split, I think we should not cache the result; users can cache it themselves. The purpose of caching is to return a consistent result.

Contributor Author

Well okay, the reasoning behind it was to prevent 2x execution in the case of (1) having transformation functions and needing to calculate statistics, and (2) having to write the df to disk. But I agree that placing the caching here without checking for this particular case is not great. Shall we still keep it for the special case though?
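As a rough illustration of the double-execution concern (a minimal PySpark sketch, not the hsfs code path; the DataFrame, column name, and output path below are made up):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Stand-in for the DataFrame that query.read() would produce in hsfs.
dataset = spark.range(1_000_000).withColumnRenamed("id", "feature_a")

# Without cache(), each action below re-executes the full query plan:
# once when statistics are computed for the transformation functions,
# and again when the training dataset is written to disk.
dataset = dataset.cache()

stats_rows = dataset.count()                        # statistics pass
dataset.write.mode("overwrite").parquet("/tmp/td")  # write pass reuses the cached data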

Contributor

Ah right. Need to calculate statistics. We should cache then.

Contributor Author

I created a Jira to add it later: https://hopsworks.atlassian.net/browse/FSTORE-317


transformation_function_engine.TransformationFunctionEngine.populate_builtin_transformation_functions(
training_dataset, feature_view_obj, dataset
)
if training_dataset.coalesce:
dataset = dataset.coalesce(1)
path = training_dataset.location + "/" + training_dataset.name
return self._write_training_dataset_single(
training_dataset,
training_dataset.transformation_functions,
dataset,
training_dataset.storage_connector,
training_dataset.data_format,
@@ -425,12 +428,15 @@ def write_training_dataset(
split_dataset = self._split_df(
query_obj, training_dataset, read_options=read_options
)
for key in split_dataset:
if training_dataset.coalesce:
split_dataset[key] = split_dataset[key].coalesce(1)

split_dataset[key] = split_dataset[key].cache()
Contributor

Should we not cache before splitting?

Contributor Author
@tdoehmen Sep 16, 2022

This article suggests adding it after the split: https://medium.com/udemy-engineering/pyspark-under-the-hood-randomsplit-and-sample-inconsistencies-examined-7c6ec62644bc. We need it afterwards because otherwise randomSplit() would be executed twice (once for the transformation function statistics, once for writing). If no seed is set for randomSplit, the splits are potentially different between the transformation statistics and the writing. Whether it is worth caching before the split as well is debatable: randomSplit scans the data once while creating the splits (because it samples while passing over it), so for the split itself we only have one pass. In the training dataset statistics we do a df.head() to determine the length of the dataframe; if we wanted to cache, we would need to do it before that. But that should be a separate PR in the future.
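A minimal sketch of the caching-after-randomSplit point (illustrative only; the DataFrame and split fractions are made up, and this is not the hsfs _split_df implementation):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

df = spark.range(1000)

# randomSplit() returns lazy DataFrames; each action on them re-evaluates
# the parent plan and re-samples it, so the rows used for the
# transformation-function statistics and the rows written afterwards are
# not guaranteed to form the same split, especially without a fixed seed.
train_df, test_df = df.randomSplit([0.8, 0.2])

# Caching right after the split materialises each split once and reuses
# it for every subsequent action (statistics, head()/count(), write).
train_df, test_df = train_df.cache(), test_df.cache()

print(train_df.count(), test_df.count())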


transformation_function_engine.TransformationFunctionEngine.populate_builtin_transformation_functions(
training_dataset, feature_view_obj, split_dataset
)
if training_dataset.coalesce:
for key in split_dataset:
split_dataset[key] = split_dataset[key].coalesce(1)
return self._write_training_dataset_splits(
training_dataset, split_dataset, write_options, save_mode, to_df=to_df
)
@@ -499,7 +505,7 @@ def _write_training_dataset_splits(
for split_name, feature_dataframe in feature_dataframes.items():
split_path = training_dataset.location + "/" + str(split_name)
feature_dataframes[split_name] = self._write_training_dataset_single(
training_dataset,
training_dataset.transformation_functions,
feature_dataframes[split_name],
training_dataset.storage_connector,
training_dataset.data_format,
@@ -539,6 +545,8 @@ def _write_training_dataset_single(
save_mode
).save(path)

feature_dataframe.unpersist()

def read(self, storage_connector, data_format, read_options, location):
if isinstance(location, str):
if data_format.lower() in ["delta", "parquet", "hudi", "orc", "bigquery"]:
124 changes: 123 additions & 1 deletion python/tests/engine/test_spark.py
@@ -45,6 +45,7 @@
expectation_suite,
training_dataset_feature,
)
from hsfs.core import training_dataset_engine
from hsfs.engine import spark
from hsfs.constructor import query, hudi_feature_group_alias
from hsfs.client import exceptions
@@ -1570,6 +1571,123 @@ def test_write_training_dataset(self, mocker):
assert mock_spark_engine_write_training_dataset_single.call_count == 0
assert mock_spark_engine_write_training_dataset_splits.call_count == 0

def test_write_training_dataset_to_df(self, mocker, backend_fixtures):
# Arrange
mocker.patch("hsfs.engine.get_type", return_value="python")
mocker.patch("hsfs.client.get_instance")

spark_engine = spark.Engine()

jsonq = backend_fixtures["query"]["get"]["response"]
q = query.Query.from_response_json(jsonq)

mock_query_read = mocker.patch("hsfs.constructor.query.Query.read")
d = {
"col_0": [1, 2],
"col_1": ["test_1", "test_2"],
"col_2": [3, 4],
"event_time": [1, 2],
}
df = pd.DataFrame(data=d)
query_df = spark_engine._spark_session.createDataFrame(df)
mock_query_read.side_effect = [query_df]

td = training_dataset.TrainingDataset(
name="test",
version=None,
splits={},
event_start_time=None,
event_end_time=None,
description="test",
storage_connector=None,
featurestore_id=10,
data_format="tsv",
location="",
statistics_config=None,
training_dataset_type=training_dataset.TrainingDataset.IN_MEMORY,
extra_filter=None,
transformation_functions={},
)

# Act
df_returned = spark_engine.write_training_dataset(
training_dataset=td,
query_obj=q,
user_write_options={},
save_mode=training_dataset_engine.TrainingDatasetEngine.OVERWRITE,
read_options={},
feature_view_obj=None,
to_df=True,
)

# Assert
assert set(df_returned.columns) == {"col_0", "col_1", "col_2", "event_time"}
assert df_returned.count() == 2
assert df_returned.exceptAll(query_df).rdd.isEmpty()

def test_write_training_dataset_split_to_df(self, mocker, backend_fixtures):
# Arrange
mocker.patch("hsfs.engine.get_type", return_value="python")
mocker.patch("hsfs.client.get_instance")

spark_engine = spark.Engine()

jsonq = backend_fixtures["query"]["get"]["response"]
q = query.Query.from_response_json(jsonq)

mock_query_read = mocker.patch("hsfs.constructor.query.Query.read")
d = {
"col_0": [1, 2],
"col_1": ["test_1", "test_2"],
"col_2": [3, 4],
"event_time": [1, 2],
}
df = pd.DataFrame(data=d)
query_df = spark_engine._spark_session.createDataFrame(df)
mock_query_read.side_effect = [query_df]

td = training_dataset.TrainingDataset(
name="test",
version=None,
splits={},
test_size=0.5,
train_start=None,
train_end=None,
test_start=None,
test_end=None,
time_split_size=2,
description="test",
storage_connector=None,
featurestore_id=12,
data_format="tsv",
location="",
statistics_config=None,
training_dataset_type=training_dataset.TrainingDataset.IN_MEMORY,
extra_filter=None,
seed=1,
transformation_functions={},
)

# Act
split_dfs_returned = spark_engine.write_training_dataset(
training_dataset=td,
query_obj=q,
user_write_options={},
save_mode=training_dataset_engine.TrainingDatasetEngine.OVERWRITE,
read_options={},
feature_view_obj=None,
to_df=True,
)

# Assert
sum_rows = 0
for key in split_dfs_returned:
df_returned = split_dfs_returned[key]
assert set(df_returned.columns) == {"col_0", "col_1", "col_2", "event_time"}
sum_rows += df_returned.count()

assert sum_rows == 2

def test_write_training_dataset_query(self, mocker):
# Arrange
mocker.patch("hsfs.engine.get_type")
@@ -1967,9 +2085,11 @@ def test_random_split(self, mocker):

# Assert
assert list(result) == ["test_split1", "test_split2"]
sum_rows = 0
for column in list(result):
assert result[column].schema == spark_df.schema
assert not result[column].rdd.isEmpty()
sum_rows += result[column].count()
assert sum_rows == 6

def test_time_series_split(self, mocker):
# Arrange
@@ -2203,6 +2323,7 @@ def test_write_training_dataset_splits(self, mocker):
featurestore_id=99,
splits={},
id=10,
transformation_functions={},
)

# Act
@@ -2234,6 +2355,7 @@ def test_write_training_dataset_splits_to_df(self, mocker):
featurestore_id=99,
splits={},
id=10,
transformation_functions={},
)

# Act
6 changes: 3 additions & 3 deletions python/tests/fixtures/training_dataset_fixtures.json
@@ -9,7 +9,7 @@
"location": "test_location",
"event_start_time": 1646438400000,
"event_end_time": 1646697600000,
"coalesce": "test_coalesce",
"coalesce": true,
"description": "test_description",
"storage_connector": {
"type": "featurestoreJdbcConnectorDTO",
@@ -43,8 +43,8 @@
"end_time": "test_end_time"
}
],
"validation_size": 2,
"test_size": 3,
"validation_size": 0.0,
"test_size": 0.5,
"train_start": 4,
"train_end": 5,
"validation_start": 6,
6 changes: 3 additions & 3 deletions python/tests/test_training_dataset.py
Expand Up @@ -43,15 +43,15 @@ def test_from_response_json(self, mocker, backend_fixtures):
assert td.data_format == "hudi"
assert td._start_time == 1646438400000
assert td._end_time == 1646697600000
assert td.validation_size == 2
assert td.test_size == 3
assert td.validation_size == 0.0
assert td.test_size == 0.5
assert td.train_start == 4
assert td.train_end == 5
assert td.validation_start == 6
assert td.validation_end == 7
assert td.test_start == 8
assert td.test_end == 9
assert td.coalesce == "test_coalesce"
assert td.coalesce is True
assert td.seed == 123
assert td.location == "test_location"
assert td._from_query == "test_from_query"