Commit 06aafb1

[SPARK-48258][PYTHON][CONNECT][FOLLOW-UP] Bind relation ID to the plan instead of DataFrame
### What changes were proposed in this pull request?

This PR addresses #46683 (comment) within Python by holding the relation ID in the plan instead of in the DataFrame itself.

### Why are the changes needed?

Because the DataFrame holds the relation ID, if DataFrame B is derived from DataFrame A and DataFrame A is garbage-collected, the cached relation might no longer exist on the server. See the example below:

```python
df = spark.range(1).localCheckpoint()
df2 = df.repartition(10)
del df
df2.collect()
```

```
pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.sql.connect.common.InvalidPlanInput) No DataFrame with id a4efa660-897c-4500-bd4e-bd57cd0263d2 is found in the session cd4764b4-90a9-4249-9140-12a6e4a98cd3
```

### Does this PR introduce _any_ user-facing change?

No, the main change has not been released yet.

### How was this patch tested?

Manually tested, and added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46694 from HyukjinKwon/SPARK-48258-followup.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 6be3560 commit 06aafb1
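
The key consequence of binding the relation ID to the plan is that derived DataFrames keep the server-side cache alive through ordinary Python references: a derived plan (for example a repartition) holds its `CachedRemoteRelation` child, so the relation's `__del__` only fires once no plan tree refers to it anymore. A minimal sketch of that reference-counting idea with plain Python objects (the class names below are illustrative stand-ins, not the actual PySpark internals):

```python
import gc


class CachedRelation:
    """Stand-in for plan.CachedRemoteRelation: releases the server-side cache on GC."""

    def __init__(self, relation_id):
        self.relation_id = relation_id

    def __del__(self):
        # In the real plan node this sends a RemoveRemoteCachedRelation command.
        print(f"released {self.relation_id}")


class Repartition:
    """Stand-in for a derived plan node; it keeps a reference to its child plan."""

    def __init__(self, child):
        self.child = child


root = CachedRelation("r-1")
derived = Repartition(root)

del root        # nothing is released: derived.child still references the relation
gc.collect()

del derived     # the relation is now unreachable, so "released r-1" is printed
gc.collect()
```

Running this prints `released r-1` only after the second `del`, mirroring how the cached remote relation now outlives the original DataFrame for as long as a derived one exists.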

File tree

5 files changed: +97 −55 lines changed


python/pyspark/sql/connect/conversion.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -577,7 +577,8 @@ def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "Dat
     from pyspark.sql.connect.session import SparkSession
     import pyspark.sql.connect.plan as plan
 
+    session = SparkSession.active()
     return DataFrame(
-        plan=plan.CachedRemoteRelation(relation.relation_id),
-        session=SparkSession.active(),
+        plan=plan.CachedRemoteRelation(relation.relation_id, session),
+        session=session,
     )
```

python/pyspark/sql/connect/dataframe.py

Lines changed: 0 additions & 38 deletions
```diff
@@ -16,7 +16,6 @@
 #
 
 # mypy: disable-error-code="override"
-from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.errors.exceptions.base import (
     SessionNotSameException,
     PySparkIndexError,
@@ -138,41 +137,6 @@ def __init__(
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
-        self._cached_remote_relation_id: Optional[str] = None
-
-    def __del__(self) -> None:
-        # If session is already closed, all cached DataFrame should be released.
-        if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
-            try:
-                command = plan.RemoveRemoteCachedRelation(
-                    plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
-                ).command(session=self._session.client)
-                req = self._session.client._execute_plan_request_with_metadata()
-                if self._session.client._user_id:
-                    req.user_context.user_id = self._session.client._user_id
-                req.plan.command.CopyFrom(command)
-
-                for attempt in self._session.client._retrying():
-                    with attempt:
-                        # !!HACK ALERT!!
-                        # unary_stream does not work on Python's exit for an unknown reasons
-                        # Therefore, here we open unary_unary channel instead.
-                        # See also :class:`SparkConnectServiceStub`.
-                        request_serializer = (
-                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
-                        )
-                        response_deserializer = (
-                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
-                        )
-                        channel = self._session.client._channel.unary_unary(
-                            "/spark.connect.SparkConnectService/ExecutePlan",
-                            request_serializer=request_serializer,
-                            response_deserializer=response_deserializer,
-                        )
-                        metadata = self._session.client._builder.metadata()
-                        channel(req, metadata=metadata)  # type: ignore[arg-type]
-            except Exception as e:
-                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
     def __reduce__(self) -> Tuple:
         """
@@ -2137,7 +2101,6 @@ def checkpoint(self, eager: bool = True) -> "DataFrame":
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
-        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
         return checkpointed
 
     def localCheckpoint(self, eager: bool = True) -> "DataFrame":
@@ -2146,7 +2109,6 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame":
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
-        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
         return checkpointed
 
 if not is_remote_only():
```
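
With the finalizer and `_cached_remote_relation_id` removed from `DataFrame`, the cache lifetime is now owned entirely by the `CachedRemoteRelation` plan node, so the failing scenario from the description works after this change. A quick sketch of the expected behavior, assuming `spark` is a Spark Connect session:

```python
df = spark.range(1).localCheckpoint()
df2 = df.repartition(10)

del df          # df2's plan tree still references the CachedRemoteRelation plan
df2.collect()   # succeeds: the server-side cache has not been released yet
```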

python/pyspark/sql/connect/plan.py

Lines changed: 46 additions & 8 deletions
```diff
@@ -40,6 +40,7 @@
 import pickle
 from threading import Lock
 from inspect import signature, isclass
+import warnings
 
 import pyarrow as pa
 
@@ -49,6 +50,7 @@
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.column import Column
+from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.expressions import Expression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
@@ -62,6 +64,7 @@
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.sql.connect.udf import UserDefinedFunction
     from pyspark.sql.connect.observation import Observation
+    from pyspark.sql.connect.session import SparkSession
 
 
 class LogicalPlan:
@@ -547,14 +550,49 @@ class CachedRemoteRelation(LogicalPlan):
     """Logical plan object for a DataFrame reference which represents a DataFrame that's been
     cached on the server with a given id."""
 
-    def __init__(self, relationId: str):
+    def __init__(self, relation_id: str, spark_session: "SparkSession"):
         super().__init__(None)
-        self._relationId = relationId
-
-    def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = self._create_proto_relation()
-        plan.cached_remote_relation.relation_id = self._relationId
-        return plan
+        self._relation_id = relation_id
+        # Needs to hold the session to make a request itself.
+        self._spark_session = spark_session
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
+        plan.cached_remote_relation.relation_id = self._relation_id
+        return plan
+
+    def __del__(self) -> None:
+        session = self._spark_session
+        # If session is already closed, all cached DataFrame should be released.
+        if session is not None and not session.client.is_closed and self._relation_id is not None:
+            try:
+                command = RemoveRemoteCachedRelation(self).command(session=session.client)
+                req = session.client._execute_plan_request_with_metadata()
+                if session.client._user_id:
+                    req.user_context.user_id = session.client._user_id
+                req.plan.command.CopyFrom(command)
+
+                for attempt in session.client._retrying():
+                    with attempt:
+                        # !!HACK ALERT!!
+                        # unary_stream does not work on Python's exit for an unknown reasons
+                        # Therefore, here we open unary_unary channel instead.
+                        # See also :class:`SparkConnectServiceStub`.
+                        request_serializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
+                        )
+                        response_deserializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
+                        )
+                        channel = session.client._channel.unary_unary(
+                            "/spark.connect.SparkConnectService/ExecutePlan",
+                            request_serializer=request_serializer,
+                            response_deserializer=response_deserializer,
+                        )
+                        metadata = session.client._builder.metadata()
+                        channel(req, metadata=metadata)  # type: ignore[arg-type]
+            except Exception as e:
+                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
 
 class Hint(LogicalPlan):
@@ -1792,7 +1830,7 @@ def __init__(self, relation: CachedRemoteRelation) -> None:
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         plan = self._create_proto_relation()
-        plan.cached_remote_relation.relation_id = self._relation._relationId
+        plan.cached_remote_relation.relation_id = self._relation._relation_id
         cmd = proto.Command()
         cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation)
         return cmd
```
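
The new `CachedRemoteRelation.__del__` is deliberately defensive: it skips the RPC when the client is already closed (for example at interpreter shutdown), and it converts any failure into a warning, since an exception escaping a finalizer cannot be handled by the caller. A rough sketch of that pattern with a hypothetical client object (`client.is_closed` and `client.release` here are illustrative, not the SparkConnectClient API):

```python
import warnings


class RemoteHandle:
    """Illustrative handle that best-effort releases a server-side resource on GC."""

    def __init__(self, client, relation_id):
        self._client = client
        self._relation_id = relation_id

    def __del__(self):
        client = getattr(self, "_client", None)
        # Skip cleanup if the connection is already gone or there is nothing to release.
        if client is None or client.is_closed or self._relation_id is None:
            return
        try:
            client.release(self._relation_id)  # hypothetical release RPC
        except Exception as e:
            # A failed best-effort cleanup should only warn, never raise out of __del__.
            warnings.warn(f"release of {self._relation_id} failed: {e}")
```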

python/pyspark/sql/connect/session.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -926,7 +926,7 @@ def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame":
         This is used in ForeachBatch() runner, where the remote DataFrame refers to the
         output of a micro batch.
         """
-        return DataFrame(CachedRemoteRelation(remote_id), self)
+        return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self)
 
     @staticmethod
     def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
```

python/pyspark/sql/tests/connect/test_connect_basic.py

Lines changed: 47 additions & 6 deletions
```diff
@@ -16,10 +16,10 @@
 #
 
 import os
+import gc
 import unittest
 import shutil
 import tempfile
-import time
 
 from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -34,6 +34,7 @@
     ArrayType,
     Row,
 )
+from pyspark.testing.utils import eventually
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.testing.connectutils import (
     should_test_connect,
@@ -1379,8 +1380,8 @@ def test_garbage_collection_checkpoint(self):
         # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state
         # in Spark Connect server
         df = self.connect.range(10).localCheckpoint()
-        self.assertIsNotNone(df._cached_remote_relation_id)
-        cached_remote_relation_id = df._cached_remote_relation_id
+        self.assertIsNotNone(df._plan._relation_id)
+        cached_remote_relation_id = df._plan._relation_id
 
         jvm = self.spark._jvm
         session_holder = getattr(
@@ -1397,14 +1398,54 @@ def test_garbage_collection_checkpoint(self):
         )
 
         del df
+        gc.collect()
 
-        time.sleep(3)  # Make sure removing is triggered, and executed in the server.
+        def condition():
+            # Check the state was removed up on garbage-collection.
+            self.assertIsNone(
+                session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
+            )
+
+        eventually(catch_assertions=True)(condition)()
+
+    def test_garbage_collection_derived_checkpoint(self):
+        # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist
+        df = self.connect.range(10).localCheckpoint()
+        self.assertIsNotNone(df._plan._relation_id)
+        derived = df.repartition(10)
+        cached_remote_relation_id = df._plan._relation_id
 
-        # Check the state was removed up on garbage-collection.
-        self.assertIsNone(
+        jvm = self.spark._jvm
+        session_holder = getattr(
+            getattr(
+                jvm.org.apache.spark.sql.connect.service,
+                "SparkConnectService$",
+            ),
+            "MODULE$",
+        ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id)
+
+        # Check the state exists.
+        self.assertIsNotNone(
             session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
         )
 
+        del df
+        gc.collect()
+
+        def condition():
+            self.assertIsNone(
+                session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
+            )
+
+        # Should not remove the cache
+        with self.assertRaises(AssertionError):
+            eventually(catch_assertions=True, timeout=5)(condition)()
+
+        del derived
+        gc.collect()
+
+        eventually(catch_assertions=True)(condition)()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
```
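
The tests replace the fixed `time.sleep(3)` with `eventually` from `pyspark.testing.utils`, which retries the wrapped callable until it stops raising or a timeout elapses; wrapping the call in `assertRaises(AssertionError)` is how the derived-DataFrame test asserts that the cache is still present after roughly five seconds. A simplified sketch of such a polling decorator (illustrative only, not the actual `pyspark.testing.utils` implementation; the default values are assumptions):

```python
import time


def eventually_sketch(timeout=30.0, interval=0.5, catch_assertions=False):
    """Retry the decorated callable until it succeeds or the timeout elapses."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            deadline = time.monotonic() + timeout
            while True:
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    # Give up by re-raising once retries are disabled or time is up.
                    if not catch_assertions or time.monotonic() >= deadline:
                        raise
                    time.sleep(interval)

        return wrapper

    return decorator
```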
