@@ -614,7 +614,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session currently running on the connected server.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -627,7 +627,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session with the given operation tag.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -640,7 +640,7 @@ class SparkSession private[sql] (
    * Interrupt an operation of this session with the given operationId.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -42,7 +42,7 @@ private[connect] class ExecuteHolder(
     s"SparkConnect_Execute_" +
       s"User_${sessionHolder.userId}_" +
       s"Session_${sessionHolder.sessionId}_" +
-      s"Request_${operationId}"
+      s"Operation_${operationId}"
 
   /**
    * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
@@ -118,7 +118,7 @@ private[connect] class ExecuteHolder(
    * need to be combined with userId and sessionId.
    */
   def tagToSparkJobTag(tag: String): String = {
-    "SparkConnect_Tag_" +
-      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
+    "SparkConnect_Execute_" +
+      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"
   }
 }
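
For illustration only (not part of the patch): the updated helper namespaces a user-facing tag with the user and session IDs so the resulting Spark job tag is unique per session. A rough Python transliteration of the Scala logic, with made-up IDs:

```python
def tag_to_spark_job_tag(user_id: str, session_id: str, tag: str) -> str:
    # Mirrors the updated Scala helper: the user-facing tag is combined
    # with the user and session IDs so it cannot collide across sessions.
    return f"SparkConnect_Execute_User_{user_id}_Session_{session_id}_Tag_{tag}"

# Made-up IDs, purely for illustration:
print(tag_to_spark_job_tag("alice", "8f3c", "nightly-etl"))
# SparkConnect_Execute_User_alice_Session_8f3c_Tag_nightly-etl
```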
python/docs/source/reference/pyspark.sql/spark_session.rst (7 changes: 7 additions & 0 deletions)
@@ -63,3 +63,10 @@ Spark Connect Only
     SparkSession.addArtifacts
     SparkSession.copyFromLocalToFs
     SparkSession.client
+    SparkSession.interruptAll
+    SparkSession.interruptTag
+    SparkSession.interruptOperation
+    SparkSession.addTag
+    SparkSession.removeTag
+    SparkSession.getTags
+    SparkSession.clearTags
python/pyspark/sql/connect/client/core.py (104 changes: 91 additions & 13 deletions)
@@ -23,6 +23,7 @@
 
 check_dependencies(__name__)
 
+import threading
 import logging
 import os
 import platform
@@ -41,6 +42,7 @@
     List,
     Tuple,
     Dict,
+    Set,
     NoReturn,
     cast,
     Callable,
@@ -574,6 +576,8 @@ def __init__(
         the $USER environment. Defining the user ID as part of the connection string
         takes precedence.
         """
+        self.thread_local = threading.local()

Review comment from the PR author: BTW, there is a difference here. Python's `threading.local` isn't inheritable by child threads (and Python provides no built-in inheritable implementation), so we should somehow work around this ...

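A minimal sketch of the behavior the comment describes (not part of the PR): values stored in `threading.local` are invisible to child threads, so tags set on the session's thread would not propagate to threads it spawns. One hypothetical workaround, similar in spirit to `pyspark.InheritableThread`, is to snapshot the parent's tags when the thread object is created and re-apply them in the child:

```python
import threading

local = threading.local()

def show_tags(label: str) -> None:
    # Each thread sees its own attribute namespace on `local`.
    print(label, getattr(local, "tags", set()))

local.tags = {"my-tag"}
show_tags("parent:")   # parent: {'my-tag'}

t = threading.Thread(target=show_tags, args=("child:",))
t.start()
t.join()               # child: set()   (the tag did not propagate)

# Hypothetical workaround: capture the parent's tags in __init__ (which
# runs on the parent thread) and restore them in run() (the child thread).
class TagInheritingThread(threading.Thread):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._inherited = set(getattr(local, "tags", set()))  # captured in parent

    def run(self) -> None:
        local.tags = set(self._inherited)  # re-applied in child
        super().run()

t2 = TagInheritingThread(target=show_tags, args=("child2:",))
t2.start()
t2.join()              # child2: {'my-tag'}
```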
         # Parse the connection string.
         self._builder = (
             connection
@@ -903,9 +907,11 @@ def token(self) -> Optional[str]:
         return self._builder._token
 
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
-        req = pb2.ExecutePlanRequest()
-        req.session_id = self._session_id
-        req.client_type = self._builder.userAgent
+        req = pb2.ExecutePlanRequest(
+            session_id=self._session_id,
+            client_type=self._builder.userAgent,
+            tags=list(self.get_tags()),
+        )
         if self._user_id:
             req.user_context.user_id = self._user_id
         return req
@@ -1204,12 +1210,22 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
         except Exception as error:
             self._handle_error(error)
 
-    def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+    def _interrupt_request(
+        self, interrupt_type: str, id_or_tag: Optional[str] = None
+    ) -> pb2.InterruptRequest:
         req = pb2.InterruptRequest()
         req.session_id = self._session_id
         req.client_type = self._builder.userAgent
         if interrupt_type == "all":
             req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+        elif interrupt_type == "tag":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG
+            req.operation_tag = id_or_tag
+        elif interrupt_type == "operation":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID
+            req.operation_id = id_or_tag
         else:
             raise PySparkValueError(
                 error_class="UNKNOWN_INTERRUPT_TYPE",
@@ -1221,14 +1237,7 @@ def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
             req.user_context.user_id = self._user_id
         return req
 
-    def interrupt_all(self) -> None:
-        """
-        Call the interrupt RPC of Spark Connect to interrupt all executions in this session.
-
-        Returns
-        -------
-        None
-        """
+    def interrupt_all(self) -> Optional[List[str]]:
         req = self._interrupt_request("all")
         try:
             for attempt in Retrying(
@@ -1241,11 +1250,80 @@ def interrupt_all(self) -> None:
                         "Received incorrect session identifier for request:"
                         f"{resp.session_id} != {self._session_id}"
                     )
-                    return
+                    return list(resp.interrupted_ids)
             raise SparkConnectException("Invalid state during retry exception handling.")
         except Exception as error:
             self._handle_error(error)
 
+    def interrupt_tag(self, tag: str) -> Optional[List[str]]:
+        req = self._interrupt_request("tag", tag)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
+        req = self._interrupt_request("operation", op_id)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def add_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.add(tag)
+
+    def remove_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.remove(tag)
+
+    def get_tags(self) -> Set[str]:
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        return self.thread_local.tags
+
+    def clear_tags(self) -> None:
+        self.thread_local.tags = set()
+
+    def _throw_if_invalid_tag(self, tag: str) -> None:
+        """
+        Validate if a tag for ExecutePlanRequest.tags is valid. Throw ``ValueError`` if
+        not.
+        """
+        spark_job_tags_sep = ","
+        if tag is None:
+            raise ValueError("Spark Connect tag cannot be null.")
+        if spark_job_tags_sep in tag:
+            raise ValueError(f"Spark Connect tag cannot contain '{spark_job_tags_sep}'.")
+        if len(tag) == 0:
+            raise ValueError("Spark Connect tag cannot be an empty string.")
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
python/pyspark/sql/connect/session.py (43 changes: 41 additions & 2 deletions)
@@ -31,6 +31,7 @@
     Dict,
     List,
     Tuple,
+    Set,
     cast,
     overload,
     Iterable,
@@ -550,8 +551,46 @@ def __del__(self) -> None:
         except Exception:
             pass
 
-    def interrupt_all(self) -> None:
-        self.client.interrupt_all()
+    def interruptAll(self) -> List[str]:
+        op_ids = self.client.interrupt_all()
+        assert op_ids is not None
+        return op_ids
+
+    interruptAll.__doc__ = PySparkSession.interruptAll.__doc__
+
+    def interruptTag(self, tag: str) -> List[str]:
+        op_ids = self.client.interrupt_tag(tag)
+        assert op_ids is not None
+        return op_ids
+
+    interruptTag.__doc__ = PySparkSession.interruptTag.__doc__
+
+    def interruptOperation(self, op_id: str) -> List[str]:
+        op_ids = self.client.interrupt_operation(op_id)
+        assert op_ids is not None
+        return op_ids
+
+    interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__
+
+    def addTag(self, tag: str) -> None:
+        self.client.add_tag(tag)
+
+    addTag.__doc__ = PySparkSession.addTag.__doc__
+
+    def removeTag(self, tag: str) -> None:
+        self.client.remove_tag(tag)
+
+    removeTag.__doc__ = PySparkSession.removeTag.__doc__
+
+    def getTags(self) -> Set[str]:
+        return self.client.get_tags()
+
+    getTags.__doc__ = PySparkSession.getTags.__doc__
+
+    def clearTags(self) -> None:
+        return self.client.clear_tags()
+
+    clearTags.__doc__ = PySparkSession.clearTags.__doc__
+
     def stop(self) -> None:
         # Stopping the session will only close the connection to the current session (and
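
To show how the new session-level API fits together, here is a hedged usage sketch. The server address, tag name, and workload are made up; it assumes a reachable Spark Connect server and the methods added in this PR:

```python
from pyspark.sql import SparkSession

# Assumes a Spark Connect server at this (made-up) address.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Tag every operation started from this thread until the tag is removed.
spark.addTag("nightly-etl")
assert "nightly-etl" in spark.getTags()

df = spark.range(10 ** 9)
# ... suppose a long-running action on df was started from another thread ...

# From any thread of the same session, interrupt everything tagged above.
# The returned list contains the operation IDs that were interrupted,
# though an operation may still finish just as it is interrupted.
interrupted = spark.interruptTag("nightly-etl")
print(interrupted)

spark.removeTag("nightly-etl")
spark.clearTags()

# interruptAll() and interruptOperation(op_id) cover the other two RPC modes.
```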