
Commit 5ee462d

[SPARK-44509][PYTHON][CONNECT] Add job cancellation API set in Spark Connect Python client
### What changes were proposed in this pull request?

This PR proposes the Python implementation of the APIs added in #42009.

### Why are the changes needed?

For feature parity, and for better control of query cancellation in Spark Connect.

### Does this PR introduce _any_ user-facing change?

Yes. New APIs in the Spark Connect Python client:

```
SparkSession.addTag
SparkSession.removeTag
SparkSession.getTags
SparkSession.clearTags
SparkSession.interruptTag
SparkSession.interruptOperation
```

### How was this patch tested?

Unit tests were added, and the change was manually tested too.

Closes #42120 from HyukjinKwon/SPARK-44509.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
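For orientation, a minimal usage sketch of the new API set. This is not part of the commit; the `sc://localhost` endpoint is an assumption, and any reachable Spark Connect server works:

```python
from pyspark.sql import SparkSession

# Assumption: a Spark Connect server is reachable at sc://localhost.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

spark.addTag("nightly-report")        # tag executions issued by this thread
spark.getTags()                       # {'nightly-report'}

# ... run queries; every execution sent while the tag is set carries it ...

spark.interruptTag("nightly-report")  # returns the interrupted operation IDs
spark.removeTag("nightly-report")
spark.clearTags()                     # drop all tags for this thread
```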
1 parent 2f7a9a1

File tree: 7 files changed (+377, −22 lines)


connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -615,7 +615,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session currently running on the connected server.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -628,7 +628,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session with the given operation tag.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -641,7 +641,7 @@ class SparkSession private[sql] (
    * Interrupt an operation of this session with the given operationId.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
```

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -42,7 +42,7 @@ private[connect] class ExecuteHolder(
     s"SparkConnect_Execute_" +
       s"User_${sessionHolder.userId}_" +
       s"Session_${sessionHolder.sessionId}_" +
-      s"Request_${operationId}"
+      s"Operation_${operationId}"
 
   /**
    * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
@@ -118,7 +118,7 @@ private[connect] class ExecuteHolder(
    * need to be combined with userId and sessionId.
    */
   def tagToSparkJobTag(tag: String): String = {
-    "SparkConnect_Tag_" +
-      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
+    "SparkConnect_Execute_" +
+      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"
   }
 }
```
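To make the renamed formats above concrete, a small Python mock-up of the two tag strings the server composes. The helper names are hypothetical; the string formats are taken from the diff:

```python
# Hypothetical helpers mirroring the Scala string composition above.
def job_tag(user_id: str, session_id: str, operation_id: str) -> str:
    # Tag attached to the Spark jobs of a single execution; the suffix
    # changed from Request_ to Operation_ to match the operationId naming.
    return (
        "SparkConnect_Execute_"
        f"User_{user_id}_"
        f"Session_{session_id}_"
        f"Operation_{operation_id}"
    )


def tag_to_spark_job_tag(user_id: str, session_id: str, tag: str) -> str:
    # User-supplied tag, scoped by user and session, and now actually
    # including the tag itself so distinct tags map to distinct job tags.
    return (
        "SparkConnect_Execute_"
        f"User_{user_id}_Session_{session_id}_Tag_{tag}"
    )
```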

python/docs/source/reference/pyspark.sql/spark_session.rst

Lines changed: 7 additions & 0 deletions
```diff
@@ -63,3 +63,10 @@ Spark Connect Only
     SparkSession.addArtifacts
     SparkSession.copyFromLocalToFs
     SparkSession.client
+    SparkSession.interruptAll
+    SparkSession.interruptTag
+    SparkSession.interruptOperation
+    SparkSession.addTag
+    SparkSession.removeTag
+    SparkSession.getTags
+    SparkSession.clearTags
```

python/pyspark/sql/connect/client/core.py

Lines changed: 91 additions & 13 deletions
```diff
@@ -23,6 +23,7 @@
 
 check_dependencies(__name__)
 
+import threading
 import logging
 import os
 import platform
@@ -41,6 +42,7 @@
     List,
     Tuple,
     Dict,
+    Set,
     NoReturn,
     cast,
     Callable,
@@ -574,6 +576,8 @@ def __init__(
         the $USER environment. Defining the user ID as part of the connection string
         takes precedence.
         """
+        self.thread_local = threading.local()
+
         # Parse the connection string.
         self._builder = (
             connection
@@ -922,9 +926,11 @@ def token(self) -> Optional[str]:
         return self._builder._token
 
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
-        req = pb2.ExecutePlanRequest()
-        req.session_id = self._session_id
-        req.client_type = self._builder.userAgent
+        req = pb2.ExecutePlanRequest(
+            session_id=self._session_id,
+            client_type=self._builder.userAgent,
+            tags=list(self.get_tags()),
+        )
         if self._user_id:
             req.user_context.user_id = self._user_id
         return req
@@ -1243,12 +1249,22 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
         except Exception as error:
             self._handle_error(error)
 
-    def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+    def _interrupt_request(
+        self, interrupt_type: str, id_or_tag: Optional[str] = None
+    ) -> pb2.InterruptRequest:
         req = pb2.InterruptRequest()
         req.session_id = self._session_id
         req.client_type = self._builder.userAgent
         if interrupt_type == "all":
             req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+        elif interrupt_type == "tag":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG
+            req.operation_tag = id_or_tag
+        elif interrupt_type == "operation":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID
+            req.operation_id = id_or_tag
         else:
             raise PySparkValueError(
                 error_class="UNKNOWN_INTERRUPT_TYPE",
@@ -1260,14 +1276,7 @@ def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
             req.user_context.user_id = self._user_id
         return req
 
-    def interrupt_all(self) -> None:
-        """
-        Call the interrupt RPC of Spark Connect to interrupt all executions in this session.
-
-        Returns
-        -------
-        None
-        """
+    def interrupt_all(self) -> Optional[List[str]]:
         req = self._interrupt_request("all")
         try:
             for attempt in Retrying(
@@ -1280,11 +1289,80 @@ def interrupt_all(self) -> None:
                             "Received incorrect session identifier for request:"
                             f"{resp.session_id} != {self._session_id}"
                         )
-            return
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def interrupt_tag(self, tag: str) -> Optional[List[str]]:
+        req = self._interrupt_request("tag", tag)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
             raise SparkConnectException("Invalid state during retry exception handling.")
         except Exception as error:
             self._handle_error(error)
 
+    def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
+        req = self._interrupt_request("operation", op_id)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def add_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.add(tag)
+
+    def remove_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.remove(tag)
+
+    def get_tags(self) -> Set[str]:
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        return self.thread_local.tags
+
+    def clear_tags(self) -> None:
+        self.thread_local.tags = set()
+
+    def _throw_if_invalid_tag(self, tag: str) -> None:
+        """
+        Validate if a tag for ExecutePlanRequest.tags is valid. Throw ``ValueError`` if
+        not.
+        """
+        spark_job_tags_sep = ","
+        if tag is None:
+            raise ValueError("Spark Connect tag cannot be null.")
+        if spark_job_tags_sep in tag:
+            raise ValueError(f"Spark Connect tag cannot contain '{spark_job_tags_sep}'.")
+        if len(tag) == 0:
+            raise ValueError("Spark Connect tag cannot be an empty string.")
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
```
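The thread-local storage and `_throw_if_invalid_tag` rules above imply the following client-side behavior. A sketch, assuming `client` is an already-connected `SparkConnectClient` instance:

```python
client.add_tag("etl")          # valid: stored in this thread's tag set
assert "etl" in client.get_tags()

try:
    client.add_tag("a,b")      # ',' is the job-tag separator, so it is rejected
except ValueError as e:
    print(e)                   # Spark Connect tag cannot contain ','.

try:
    client.add_tag("")         # empty tags are rejected as well
except ValueError:
    pass

client.clear_tags()            # resets the current thread's tags to an empty set
```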

python/pyspark/sql/connect/session.py

Lines changed: 41 additions & 2 deletions
```diff
@@ -31,6 +31,7 @@
     Dict,
     List,
     Tuple,
+    Set,
     cast,
     overload,
     Iterable,
@@ -550,8 +551,46 @@ def __del__(self) -> None:
         except Exception:
             pass
 
-    def interrupt_all(self) -> None:
-        self.client.interrupt_all()
+    def interruptAll(self) -> List[str]:
+        op_ids = self.client.interrupt_all()
+        assert op_ids is not None
+        return op_ids
+
+    interruptAll.__doc__ = PySparkSession.interruptAll.__doc__
+
+    def interruptTag(self, tag: str) -> List[str]:
+        op_ids = self.client.interrupt_tag(tag)
+        assert op_ids is not None
+        return op_ids
+
+    interruptTag.__doc__ = PySparkSession.interruptTag.__doc__
+
+    def interruptOperation(self, op_id: str) -> List[str]:
+        op_ids = self.client.interrupt_operation(op_id)
+        assert op_ids is not None
+        return op_ids
+
+    interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__
+
+    def addTag(self, tag: str) -> None:
+        self.client.add_tag(tag)
+
+    addTag.__doc__ = PySparkSession.addTag.__doc__
+
+    def removeTag(self, tag: str) -> None:
+        self.client.remove_tag(tag)
+
+    removeTag.__doc__ = PySparkSession.removeTag.__doc__
+
+    def getTags(self) -> Set[str]:
+        return self.client.get_tags()
+
+    getTags.__doc__ = PySparkSession.getTags.__doc__
+
+    def clearTags(self) -> None:
+        return self.client.clear_tags()
+
+    clearTags.__doc__ = PySparkSession.clearTags.__doc__
 
     def stop(self) -> None:
         # Stopping the session will only close the connection to the current session (and
```
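Since tags live in a `threading.local()`, each thread tags only its own executions, while interrupts act session-wide. A hedged end-to-end sketch; the endpoint and the workload are assumptions:

```python
import threading
from pyspark.sql import SparkSession

# Assumption: a Spark Connect server is reachable at sc://localhost.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()


def long_job() -> None:
    # Tags set here apply only to queries sent from this worker thread.
    spark.addTag("heavy")
    spark.range(10_000_000_000).selectExpr("sum(id)").collect()


worker = threading.Thread(target=long_job)
worker.start()

print(spark.getTags())             # set() - the main thread has no tags

# Interrupt every running operation tagged "heavy", from any thread.
interrupted = spark.interruptTag("heavy")
print(interrupted)                 # operation IDs that were interrupted

worker.join()
```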
