[SPARK-40453][SPARK-41715][CONNECT] Take super class into account when throwing an exception #39947
File: SparkConnectService.scala (Spark Connect server, Scala)
```diff
@@ -19,7 +19,9 @@ package org.apache.spark.sql.connect.service
 
 import java.util.concurrent.TimeUnit
 
+import scala.annotation.tailrec
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import com.google.common.base.Ticker
```
```diff
@@ -31,12 +33,16 @@ import io.grpc.netty.NettyServerBuilder
 import io.grpc.protobuf.StatusProto
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
+import org.apache.commons.lang3.StringUtils
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
 
-import org.apache.spark.{SparkEnv, SparkThrowable}
+import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
+import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
 import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
 import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode}
```
```diff
@@ -53,22 +59,46 @@ class SparkConnectService(debug: Boolean)
     extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
     with Logging {
 
-  private def buildStatusFromThrowable[A <: Throwable with SparkThrowable](st: A): RPCStatus = {
-    val t = Option(st.getCause).getOrElse(st)
+  private def allClasses(cl: Class[_]): Seq[Class[_]] = {
+    val classes = ArrayBuffer.empty[Class[_]]
+    if (cl != null && !cl.equals(classOf[java.lang.Object])) {
+      classes.append(cl) // Includes itself.
+    }
+
+    @tailrec
+    def appendSuperClasses(clazz: Class[_]): Unit = {
+      if (clazz == null || clazz.equals(classOf[java.lang.Object])) return
+      classes.append(clazz.getSuperclass)
+      appendSuperClasses(clazz.getSuperclass)
+    }
+
+    appendSuperClasses(cl)
+    classes.toSeq
+  }
+
+  private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
     RPCStatus
       .newBuilder()
       .setCode(RPCCode.INTERNAL_VALUE)
       .addDetails(
         ProtoAny.pack(
           ErrorInfo
             .newBuilder()
-            .setReason(t.getClass.getName)
+            .setReason(st.getClass.getName)
             .setDomain("org.apache.spark")
+            .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
             .build()))
-      .setMessage(t.getLocalizedMessage)
+      .setMessage(StringUtils.abbreviate(st.getMessage, 2048))
       .build()
   }
 
+  private def isPythonExecutionException(se: SparkException): Boolean = {
+    // See also pyspark.errors.exceptions.captured.convert_exception in PySpark.
+    se.getCause != null && se.getCause
+      .isInstanceOf[PythonException] && se.getCause.getStackTrace
+      .exists(_.toString.contains("org.apache.spark.sql.execution.python"))
+  }
+
   /**
    * Common exception handling function for the Analysis and Execution methods. Closes the stream
    * after the error has been sent.
```
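For intuition, here is a rough Python analogue of the new `allClasses` helper (a sketch only: the class names below are made up for illustration, and the Scala version walks `getSuperclass`, so Scala traits and interfaces never appear in the list):

```python
def all_classes(cls: type) -> list:
    # Rough analogue of allClasses above: the class itself plus its super
    # classes, with the universal base filtered out (object here, playing
    # the role of java.lang.Object in the Scala code).
    return [c.__name__ for c in cls.__mro__ if c is not object]


class AnalysisError(Exception):  # made-up stand-ins for Spark's exceptions
    pass


class ParseError(AnalysisError):
    pass


print(all_classes(ParseError))
# ['ParseError', 'AnalysisError', 'Exception', 'BaseException']
```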
|
|
```diff
@@ -83,46 +113,22 @@ class SparkConnectService(debug: Boolean)
   private def handleError[V](
       opType: String,
       observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
-    case ae: AnalysisException =>
-      logError(s"Error during: $opType", ae)
-      val status = RPCStatus
-        .newBuilder()
-        .setCode(RPCCode.INTERNAL_VALUE)
-        .addDetails(
-          ProtoAny.pack(
-            ErrorInfo
-              .newBuilder()
-              .setReason(ae.getClass.getName)
-              .setDomain("org.apache.spark")
-              .putMetadata("message", ae.getSimpleMessage)
-              .putMetadata("plan", Option(ae.plan).flatten.map(p => s"$p").getOrElse(""))
-              .build()))
-        .setMessage(ae.getLocalizedMessage)
-        .build()
-      observer.onError(StatusProto.toStatusRuntimeException(status))
-    case st: SparkThrowable =>
-      logError(s"Error during: $opType", st)
-      val status = buildStatusFromThrowable(st)
-      observer.onError(StatusProto.toStatusRuntimeException(status))
-    case NonFatal(nf) =>
-      logError(s"Error during: $opType", nf)
-      val status = RPCStatus
-        .newBuilder()
-        .setCode(RPCCode.INTERNAL_VALUE)
-        .addDetails(
-          ProtoAny.pack(
-            ErrorInfo
-              .newBuilder()
-              .setReason(nf.getClass.getName)
-              .setDomain("org.apache.spark")
-              .build()))
-        .setMessage(nf.getLocalizedMessage)
-        .build()
-      observer.onError(StatusProto.toStatusRuntimeException(status))
+    case se: SparkException if isPythonExecutionException(se) =>
+      logError(s"Error during: $opType", se)
+      observer.onError(
+        StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause)))
+
+    case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
+      logError(s"Error during: $opType", e)
+      observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e)))
+
     case e: Throwable =>
       logError(s"Error during: $opType", e)
       observer.onError(
-        Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
+        Status.UNKNOWN
+          .withCause(e)
+          .withDescription(StringUtils.abbreviate(e.getMessage, 2048))
+          .asRuntimeException())
   }
 
   /**
```
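The `RPCStatus` built above reaches the client in the `grpc-status-details-bin` trailer via `StatusProto.toStatusRuntimeException`. A hedged sketch of how a Python gRPC client could dig the `ErrorInfo` back out (assumes the `grpcio-status` package; the helper name is ours, not PySpark's):

```python
import grpc
from google.rpc import error_details_pb2
from grpc_status import rpc_status  # provided by the grpcio-status package


def error_info_from_rpc_error(rpc_error: grpc.RpcError):
    """Return the ErrorInfo detail the server attached, or None."""
    # Decodes the google.rpc.Status from the grpc-status-details-bin trailer.
    status = rpc_status.from_call(rpc_error)
    if status is None:
        return None
    for detail in status.details:
        if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
            info = error_details_pb2.ErrorInfo()
            detail.Unpack(info)
            return info  # carries reason, domain, and metadata["classes"]
    return None
```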
File: python/pyspark/errors/exceptions/connect.py
```diff
@@ -14,25 +14,58 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
+from typing import Dict, Optional, TYPE_CHECKING
 
-from typing import Dict, Optional
-
 from pyspark.errors.exceptions.base import (
     AnalysisException as BaseAnalysisException,
     IllegalArgumentException as BaseIllegalArgumentException,
     ParseException as BaseParseException,
     PySparkException,
-    TempTableAlreadyExistsException as BaseTempTableAlreadyExistsException,
+    PythonException as BasePythonException,
+    StreamingQueryException as BaseStreamingQueryException,
+    QueryExecutionException as BaseQueryExecutionException,
+    SparkUpgradeException as BaseSparkUpgradeException,
 )
 
+if TYPE_CHECKING:
+    from google.rpc.error_details_pb2 import ErrorInfo
+
 
 class SparkConnectException(PySparkException):
     """
     Exception thrown from Spark Connect.
     """
 
 
+def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
+    classes = []
+    if "classes" in info.metadata:
+        classes = json.loads(info.metadata["classes"])
+
+    if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
+        return ParseException(message)
+    # Order matters. ParseException inherits AnalysisException.
+    elif "org.apache.spark.sql.AnalysisException" in classes:
+        return AnalysisException(message)
+    elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
+        return StreamingQueryException(message)
+    elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
+        return QueryExecutionException(message)
+    elif "java.lang.IllegalArgumentException" in classes:
+        return IllegalArgumentException(message)
+    elif "org.apache.spark.SparkUpgradeException" in classes:
+        return SparkUpgradeException(message)
+    elif "org.apache.spark.api.python.PythonException" in classes:
+        return PythonException(
+            "\n An exception was thrown from the Python worker. "
+            "Please see the stack trace below.\n%s" % message
+        )
+    else:
+        return SparkConnectGrpcException(message, reason=info.reason)
+
 
 class SparkConnectGrpcException(SparkConnectException):
     """
     Base class to handle the errors from GRPC.
```
|
```diff
@@ -61,41 +94,34 @@ class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
     Failed to analyze a SQL query plan from Spark Connect server.
     """
 
-    def __init__(
-        self,
-        message: Optional[str] = None,
-        error_class: Optional[str] = None,
-        message_parameters: Optional[Dict[str, str]] = None,
-        plan: Optional[str] = None,
-        reason: Optional[str] = None,
-    ) -> None:
-        self.message = message  # type: ignore[assignment]
-        if plan is not None:
-            self.message = f"{self.message}\nPlan: {plan}"
-
-        super().__init__(
-            message=self.message,
-            error_class=error_class,
-            message_parameters=message_parameters,
-            reason=reason,
-        )
 
-
-class TempTableAlreadyExistsException(AnalysisException, BaseTempTableAlreadyExistsException):
+class ParseException(SparkConnectGrpcException, BaseParseException):
     """
-    Failed to create temp view from Spark Connect server since it is already exists.
+    Failed to parse a SQL command from Spark Connect server.
     """
 
 
-class ParseException(SparkConnectGrpcException, BaseParseException):
+class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
     """
-    Failed to parse a SQL command from Spark Connect server.
+    Passed an illegal or inappropriate argument from Spark Connect server.
     """
 
 
-class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
+class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException):
     """
-    Passed an illegal or inappropriate argument from Spark Connect server.
+    Exception that stopped a :class:`StreamingQuery` from Spark Connect server.
     """
+
+
+class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException):
+    """
+    Failed to execute a query from Spark Connect server.
+    """
+
+
+class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
+    """
+    Exception thrown because of Spark upgrade from Spark Connect
+    """
```
File: python/pyspark/sql/dataframe.py
```diff
@@ -356,7 +356,7 @@ def createTempView(self, name: str) -> None:
 
         Throw an exception if the table already exists.
 
-        >>> df.createTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP
+        >>> df.createTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
         AnalysisException: "Temporary table 'people' already exists;"
@@ -439,7 +439,7 @@ def createGlobalTempView(self, name: str) -> None:
 
         Throws an exception if the global temporary view already exists.
 
-        >>> df.createGlobalTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP
+        >>> df.createGlobalTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
         AnalysisException: "Temporary table 'people' already exists;"
```
```diff
@@ -4598,7 +4598,7 @@ def freqItems(
         Examples
         --------
         >>> df = spark.createDataFrame([(1, 11), (1, 11), (3, 10), (4, 8), (4, 8)], ["c1", "c2"])
-        >>> df.freqItems(["c1", "c2"]).show()  # doctest: +SKIP
+        >>> df.freqItems(["c1", "c2"]).show()
         +------------+------------+
         |c1_freqItems|c2_freqItems|
         +------------+------------+
```

Member comment: Hi, @HyukjinKwon and @ueshin. Unfortunately, this broke Scala 2.13 CI. I made a followup.
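Net effect on the client, as the re-enabled doctests above exercise it. A sketch (the connect address is an assumption; the `AnalysisException` raised by the connect client subclasses the one exported from `pyspark.errors`):

```python
from pyspark.sql import SparkSession
from pyspark.errors import AnalysisException

# Connect to a running Spark Connect server (address is illustrative).
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(1)
df.createTempView("people")
try:
    df.createTempView("people")  # the second creation fails on the server
except AnalysisException as e:
    # Previously this surfaced as a generic gRPC error; with the "classes"
    # metadata the client raises the proper exception type.
    print(type(e).__name__)
```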
Review comment: `getLocalizedMessage` is not used in our codebase.

Review comment (follow-up): ...and the doc of `setMessage` mentions that it's fine to send non-localized errors (and expect the client to localize it).