diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b37e3884038b..0f2f074026e4 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -614,7 +614,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session currently running on the connected server.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -627,7 +627,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session with the given operation tag.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -640,7 +640,7 @@ class SparkSession private[sql] (
    * Interrupt an operation of this session with the given operationId.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 74530ad032f1..36c96b2617fb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -42,7 +42,7 @@ private[connect] class ExecuteHolder(
     s"SparkConnect_Execute_" +
       s"User_${sessionHolder.userId}_" +
       s"Session_${sessionHolder.sessionId}_" +
-      s"Request_${operationId}"
+      s"Operation_${operationId}"
 
   /**
    * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
@@ -118,7 +118,7 @@ private[connect] class ExecuteHolder(
    * need to be combined with userId and sessionId.
    */
   def tagToSparkJobTag(tag: String): String = {
-    "SparkConnect_Tag_" +
-      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
+    "SparkConnect_Execute_" +
+      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"
   }
 }
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 6a46db7576b7..c16ca4f162f5 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -63,3 +63,10 @@ Spark Connect Only
     SparkSession.addArtifacts
     SparkSession.copyFromLocalToFs
     SparkSession.client
+    SparkSession.interruptAll
+    SparkSession.interruptTag
+    SparkSession.interruptOperation
+    SparkSession.addTag
+    SparkSession.removeTag
+    SparkSession.getTags
+    SparkSession.clearTags
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 167a713a0a35..42de0a8ef9f9 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -23,6 +23,7 @@
 
 check_dependencies(__name__)
 
+import threading
 import logging
 import os
 import platform
@@ -41,6 +42,7 @@
     List,
     Tuple,
     Dict,
+    Set,
     NoReturn,
     cast,
     Callable,
@@ -574,6 +576,8 @@ def __init__(
             the $USER environment. Defining the user ID as part of the connection string takes
             precedence.
         """
+        self.thread_local = threading.local()
+
         # Parse the connection string.
         self._builder = (
             connection
@@ -903,9 +907,11 @@ def token(self) -> Optional[str]:
         return self._builder._token
 
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
-        req = pb2.ExecutePlanRequest()
-        req.session_id = self._session_id
-        req.client_type = self._builder.userAgent
+        req = pb2.ExecutePlanRequest(
+            session_id=self._session_id,
+            client_type=self._builder.userAgent,
+            tags=list(self.get_tags()),
+        )
         if self._user_id:
             req.user_context.user_id = self._user_id
         return req
@@ -1204,12 +1210,22 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
         except Exception as error:
             self._handle_error(error)
 
-    def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+    def _interrupt_request(
+        self, interrupt_type: str, id_or_tag: Optional[str] = None
+    ) -> pb2.InterruptRequest:
         req = pb2.InterruptRequest()
         req.session_id = self._session_id
         req.client_type = self._builder.userAgent
         if interrupt_type == "all":
             req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+        elif interrupt_type == "tag":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG
+            req.operation_tag = id_or_tag
+        elif interrupt_type == "operation":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID
+            req.operation_id = id_or_tag
         else:
             raise PySparkValueError(
                 error_class="UNKNOWN_INTERRUPT_TYPE",
@@ -1221,14 +1237,7 @@ def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
             req.user_context.user_id = self._user_id
         return req
 
-    def interrupt_all(self) -> None:
-        """
-        Call the interrupt RPC of Spark Connect to interrupt all executions in this session.
-
-        Returns
-        -------
-        None
-        """
+    def interrupt_all(self) -> Optional[List[str]]:
         req = self._interrupt_request("all")
         try:
             for attempt in Retrying(
@@ -1241,11 +1250,80 @@ def interrupt_all(self) -> None:
                 can_retry=SparkConnectClient.retry_exception, **self._retry_policy
             ):
                 with attempt:
                     resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
                     if resp.session_id != self._session_id:
                         raise SparkConnectException(
                             "Received incorrect session identifier for request:"
                             f"{resp.session_id} != {self._session_id}"
                         )
-                    return
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def interrupt_tag(self, tag: str) -> Optional[List[str]]:
+        req = self._interrupt_request("tag", tag)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
             raise SparkConnectException("Invalid state during retry exception handling.")
         except Exception as error:
             self._handle_error(error)
 
+    def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
+        req = self._interrupt_request("operation", op_id)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def add_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.add(tag)
+
+    def remove_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.remove(tag)
+
+    def get_tags(self) -> Set[str]:
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        return self.thread_local.tags
+
+    def clear_tags(self) -> None:
+        self.thread_local.tags = set()
+
+    def _throw_if_invalid_tag(self, tag: str) -> None:
+        """
+        Validate if a tag for ExecutePlanRequest.tags is valid. Throw ``ValueError`` if
+        not.
+        """
+        spark_job_tags_sep = ","
+        if tag is None:
+            raise ValueError("Spark Connect tag cannot be null.")
+        if spark_job_tags_sep in tag:
+            raise ValueError(f"Spark Connect tag cannot contain '{spark_job_tags_sep}'.")
+        if len(tag) == 0:
+            raise ValueError("Spark Connect tag cannot be an empty string.")
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
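Side note on the client change above: `add_tag`/`get_tags` keep the tags in a `threading.local()`, and `_execute_plan_request_with_metadata` copies the calling thread's tags into each `ExecutePlanRequest`, so tags are tracked per thread rather than per session object. Below is a minimal, standalone sketch of that thread-local pattern, not code from this patch; it uses plain Python with no Spark Connect involved, and the `TagRegistry`/`worker` names are illustrative only.

```python
import threading


class TagRegistry:
    """Per-thread tag bookkeeping, mirroring the pattern used by the client above."""

    def __init__(self) -> None:
        self._local = threading.local()

    def add_tag(self, tag: str) -> None:
        if tag == "" or "," in tag:
            raise ValueError("tag cannot be empty or contain ','")
        if not hasattr(self._local, "tags"):
            self._local.tags = set()
        self._local.tags.add(tag)

    def get_tags(self) -> set:
        if not hasattr(self._local, "tags"):
            self._local.tags = set()
        return self._local.tags


registry = TagRegistry()


def worker(name: str) -> None:
    # Each thread sees only the tags it added itself.
    registry.add_tag(name)
    print(name, registry.get_tags())


threads = [threading.Thread(target=worker, args=(f"job-{i}",)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The main thread never added a tag, so its view stays empty.
print("main thread sees:", registry.get_tags())
```

Because each thread keeps its own set, tags added by one worker thread never leak into requests issued from other threads, which is what makes interrupting jobs by tag on a per-thread basis possible.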
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 37a5bdd9f9fd..7a5f56777709 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -31,6 +31,7 @@
     Dict,
     List,
     Tuple,
+    Set,
     cast,
     overload,
     Iterable,
@@ -550,8 +551,46 @@ def __del__(self) -> None:
         except Exception:
             pass
 
-    def interrupt_all(self) -> None:
-        self.client.interrupt_all()
+    def interruptAll(self) -> List[str]:
+        op_ids = self.client.interrupt_all()
+        assert op_ids is not None
+        return op_ids
+
+    interruptAll.__doc__ = PySparkSession.interruptAll.__doc__
+
+    def interruptTag(self, tag: str) -> List[str]:
+        op_ids = self.client.interrupt_tag(tag)
+        assert op_ids is not None
+        return op_ids
+
+    interruptTag.__doc__ = PySparkSession.interruptTag.__doc__
+
+    def interruptOperation(self, op_id: str) -> List[str]:
+        op_ids = self.client.interrupt_operation(op_id)
+        assert op_ids is not None
+        return op_ids
+
+    interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__
+
+    def addTag(self, tag: str) -> None:
+        self.client.add_tag(tag)
+
+    addTag.__doc__ = PySparkSession.addTag.__doc__
+
+    def removeTag(self, tag: str) -> None:
+        self.client.remove_tag(tag)
+
+    removeTag.__doc__ = PySparkSession.removeTag.__doc__
+
+    def getTags(self) -> Set[str]:
+        return self.client.get_tags()
+
+    getTags.__doc__ = PySparkSession.getTags.__doc__
+
+    def clearTags(self) -> None:
+        return self.client.clear_tags()
+
+    clearTags.__doc__ = PySparkSession.clearTags.__doc__
 
     def stop(self) -> None:
         # Stopping the session will only close the connection to the current session (and
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 00a0047dfd15..834b0307238a 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -31,6 +31,7 @@
     Tuple,
     Type,
     Union,
+    Set,
     cast,
     no_type_check,
     overload,
@@ -1858,6 +1859,127 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
             "however, the current Spark session does not use Spark Connect."
         )
 
+    def interruptAll(self) -> List[str]:
+        """
+        Interrupt all operations of this session currently running on the connected server.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        list of str
+            List of operationIds of interrupted operations.
+
+        Notes
+        -----
+        There is still a possibility of operation finishing just as it is interrupted.
+        """
+        raise RuntimeError(
+            "SparkSession.interruptAll is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def interruptTag(self, tag: str) -> List[str]:
+        """
+        Interrupt all operations of this session with the given operation tag.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        list of str
+            List of operationIds of interrupted operations.
+
+        Notes
+        -----
+        There is still a possibility of operation finishing just as it is interrupted.
+        """
+        raise RuntimeError(
+            "SparkSession.interruptTag is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def interruptOperation(self, op_id: str) -> List[str]:
+        """
+        Interrupt an operation of this session with the given operationId.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        list of str
+            List of operationIds of interrupted operations.
+
+        Notes
+        -----
+        There is still a possibility of operation finishing just as it is interrupted.
+        """
+        raise RuntimeError(
+            "SparkSession.interruptOperation is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def addTag(self, tag: str) -> None:
+        """
+        Add a tag to be assigned to all the operations started by this thread in this session.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+        """
+        raise RuntimeError(
+            "SparkSession.addTag is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def removeTag(self, tag: str) -> None:
+        """
+        Remove a tag previously added to be assigned to all the operations started by this
+        thread in this session. Noop if such a tag was not added earlier.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+        """
+        raise RuntimeError(
+            "SparkSession.removeTag is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def getTags(self) -> Set[str]:
+        """
+        Get the tags that are currently set to be assigned to all the operations started by
+        this thread.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        set of str
+            Set of tags currently set for this thread's operations.
+        """
+        raise RuntimeError(
+            "SparkSession.getTags is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def clearTags(self) -> None:
+        """
+        Clear the current thread's operation tags.
+
+        .. versionadded:: 3.5.0
+        """
+        raise RuntimeError(
+            "SparkSession.clearTags is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
 
 def _test() -> None:
     import os
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 2f14eeddc1e4..0482f119d63b 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -14,12 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import threading
+import time
 import unittest
 from typing import Optional
 
 from pyspark.sql.connect.client import ChannelBuilder
 from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+from pyspark.testing.connectutils import should_test_connect
+
+if should_test_connect:
+    from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class CustomChannelBuilder(ChannelBuilder):
@@ -70,3 +75,107 @@ def test_session_create_sets_active_session(self):
         self.assertIs(session, session2)
 
         session.stop()
+
+
+class ArrowParityTests(ReusedConnectTestCase):
+    def test_tags(self):
+        self.spark.clearTags()
+        self.spark.addTag("a")
+        self.assertEqual(self.spark.getTags(), {"a"})
+        self.spark.addTag("b")
+        self.spark.removeTag("a")
+        self.assertEqual(self.spark.getTags(), {"b"})
+        self.spark.addTag("c")
+        self.spark.clearTags()
+        self.assertEqual(self.spark.getTags(), set())
+        self.spark.clearTags()
+
+    def test_interrupt_tag(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: self.spark.addTag(job_group),
+            lambda job_group: self.spark.interruptTag(job_group),
+            thread_ids,
+            [i for i in thread_ids if i % 2 == 0],
+            [i for i in thread_ids if i % 2 != 0],
+        )
+        self.spark.clearTags()
+
+    def test_interrupt_all(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: None,
+            lambda job_group: self.spark.interruptAll(),
+            thread_ids,
+            thread_ids,
+            [],
+        )
+        self.spark.clearTags()
+
+    def check_job_cancellation(
+        self, setter, canceller, thread_ids, thread_ids_to_cancel, thread_ids_to_run
+    ):
+
+        job_id_a = "job_ids_to_cancel"
+        job_id_b = "job_ids_to_run"
+        threads = []
+
+        # A list that records whether each job was cancelled.
+        # The index of the list is the index of the thread the job ran in.
+        is_job_cancelled = [False for _ in thread_ids]
+
+        def run_job(job_id, index):
+            """
+            Executes a job with the tag ``job_id``. Each job sleeps for 20 seconds
+            and then exits.
+            """
+            try:
+                setter(job_id)
+
+                def func(itr):
+                    for pdf in itr:
+                        time.sleep(pdf._1.iloc[0])
+                        yield pdf
+
+                self.spark.createDataFrame([[20]]).repartition(1).mapInPandas(
+                    func, schema="_1 LONG"
+                ).collect()
+                is_job_cancelled[index] = False
+            except Exception:
+                # Assume that an exception means job cancellation.
+                is_job_cancelled[index] = True
+
+        # Test if the job succeeds when not cancelled.
+        run_job(job_id_a, 0)
+        self.assertFalse(is_job_cancelled[0])
+        self.spark.clearTags()
+
+        # Run jobs
+        for i in thread_ids_to_cancel:
+            t = threading.Thread(target=run_job, args=(job_id_a, i))
+            t.start()
+            threads.append(t)
+
+        for i in thread_ids_to_run:
+            t = threading.Thread(target=run_job, args=(job_id_b, i))
+            t.start()
+            threads.append(t)
+
+        # Wait to make sure all jobs are executed.
+        time.sleep(10)
+        # And then, cancel the first group of jobs.
+        canceller(job_id_a)
+
+        # Wait until all threads launching jobs are finished.
+        for t in threads:
+            t.join()
+
+        for i in thread_ids_to_cancel:
+            self.assertTrue(
+                is_job_cancelled[i], "Thread {i}: Job in group A was not cancelled.".format(i=i)
+            )
+
+        for i in thread_ids_to_run:
+            self.assertFalse(
+                is_job_cancelled[i], "Thread {i}: Job in group B did not succeed.".format(i=i)
+            )
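Taken together, the changes above expose a tag/interrupt API on the Spark Connect SparkSession. The test exercises it through `mapInPandas`; the sketch below is a hedged end-to-end usage example rather than anything from this patch. It assumes a reachable Spark Connect server at `sc://localhost`, and the long-running query and the `slow-job` tag are placeholders.

```python
import threading
import time

from pyspark.sql import SparkSession

# Assumed endpoint; the tag APIs added in this patch require a Spark Connect session.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()


def run_tagged_query() -> None:
    # Tags are tracked per thread, so the tag must be added on the thread
    # that submits the query.
    spark.addTag("slow-job")
    try:
        # Placeholder long-running query.
        spark.range(10_000_000_000).selectExpr("sum(id)").collect()
    except Exception:
        # Interruption surfaces as an exception on the submitting thread.
        print("query was interrupted")


t = threading.Thread(target=run_tagged_query)
t.start()
time.sleep(5)  # give the query time to reach the server

# Returns the operationIds of the operations that were interrupted.
interrupted = spark.interruptTag("slow-job")
print(interrupted)
t.join()
```

As in the test above, `interruptTag` returns the ids of the operations it interrupted, and an operation may still complete just as it is interrupted.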