From 25ab26bea145329d28a058b6b551e460f4749c2f Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 9 Feb 2023 12:12:10 +0900 Subject: [PATCH] Better exception --- .../connect/service/SparkConnectService.scala | 92 ++++++++++--------- python/pyspark/errors/exceptions/connect.py | 76 ++++++++++----- python/pyspark/sql/catalog.py | 8 +- python/pyspark/sql/connect/client.py | 33 +------ python/pyspark/sql/dataframe.py | 6 +- .../sql/tests/connect/test_connect_basic.py | 5 +- .../sql/tests/pandas/test_pandas_udf.py | 6 +- python/pyspark/sql/tests/test_catalog.py | 16 ++-- 8 files changed, 122 insertions(+), 120 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 25b7009860b71..05aa2428140b0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.TimeUnit +import scala.annotation.tailrec import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Ticker @@ -31,12 +33,16 @@ import io.grpc.netty.NettyServerBuilder import io.grpc.protobuf.StatusProto import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver +import org.apache.commons.lang3.StringUtils +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.{SparkEnv, SparkThrowable} +import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} +import org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession} +import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode} @@ -53,8 +59,24 @@ class SparkConnectService(debug: Boolean) extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { - private def buildStatusFromThrowable[A <: Throwable with SparkThrowable](st: A): RPCStatus = { - val t = Option(st.getCause).getOrElse(st) + private def allClasses(cl: Class[_]): Seq[Class[_]] = { + val classes = ArrayBuffer.empty[Class[_]] + if (cl != null && !cl.equals(classOf[java.lang.Object])) { + classes.append(cl) // Includes itself. 
+ } + + @tailrec + def appendSuperClasses(clazz: Class[_]): Unit = { + if (clazz == null || clazz.equals(classOf[java.lang.Object])) return + classes.append(clazz.getSuperclass) + appendSuperClasses(clazz.getSuperclass) + } + + appendSuperClasses(cl) + classes.toSeq + } + + private def buildStatusFromThrowable(st: Throwable): RPCStatus = { RPCStatus .newBuilder() .setCode(RPCCode.INTERNAL_VALUE) @@ -62,13 +84,21 @@ class SparkConnectService(debug: Boolean) ProtoAny.pack( ErrorInfo .newBuilder() - .setReason(t.getClass.getName) + .setReason(st.getClass.getName) .setDomain("org.apache.spark") + .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName)))) .build())) - .setMessage(t.getLocalizedMessage) + .setMessage(StringUtils.abbreviate(st.getMessage, 2048)) .build() } + private def isPythonExecutionException(se: SparkException): Boolean = { + // See also pyspark.errors.exceptions.captured.convert_exception in PySpark. + se.getCause != null && se.getCause + .isInstanceOf[PythonException] && se.getCause.getStackTrace + .exists(_.toString.contains("org.apache.spark.sql.execution.python")) + } + /** * Common exception handling function for the Analysis and Execution methods. Closes the stream * after the error has been sent. @@ -83,46 +113,22 @@ class SparkConnectService(debug: Boolean) private def handleError[V]( opType: String, observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = { - case ae: AnalysisException => - logError(s"Error during: $opType", ae) - val status = RPCStatus - .newBuilder() - .setCode(RPCCode.INTERNAL_VALUE) - .addDetails( - ProtoAny.pack( - ErrorInfo - .newBuilder() - .setReason(ae.getClass.getName) - .setDomain("org.apache.spark") - .putMetadata("message", ae.getSimpleMessage) - .putMetadata("plan", Option(ae.plan).flatten.map(p => s"$p").getOrElse("")) - .build())) - .setMessage(ae.getLocalizedMessage) - .build() - observer.onError(StatusProto.toStatusRuntimeException(status)) - case st: SparkThrowable => - logError(s"Error during: $opType", st) - val status = buildStatusFromThrowable(st) - observer.onError(StatusProto.toStatusRuntimeException(status)) - case NonFatal(nf) => - logError(s"Error during: $opType", nf) - val status = RPCStatus - .newBuilder() - .setCode(RPCCode.INTERNAL_VALUE) - .addDetails( - ProtoAny.pack( - ErrorInfo - .newBuilder() - .setReason(nf.getClass.getName) - .setDomain("org.apache.spark") - .build())) - .setMessage(nf.getLocalizedMessage) - .build() - observer.onError(StatusProto.toStatusRuntimeException(status)) + case se: SparkException if isPythonExecutionException(se) => + logError(s"Error during: $opType", se) + observer.onError( + StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause))) + + case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => + logError(s"Error during: $opType", e) + observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e))) + case e: Throwable => logError(s"Error during: $opType", e) observer.onError( - Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) + Status.UNKNOWN + .withCause(e) + .withDescription(StringUtils.abbreviate(e.getMessage, 2048)) + .asRuntimeException()) } /** diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index ba3bc9f7576b7..f5f1d42ca5d5f 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -14,8 +14,9 @@ # See the License for the specific language 
governing permissions and # limitations under the License. # +import json +from typing import Dict, Optional, TYPE_CHECKING -from typing import Dict, Optional from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, @@ -23,9 +24,14 @@ ParseException as BaseParseException, PySparkException, PythonException as BasePythonException, - TempTableAlreadyExistsException as BaseTempTableAlreadyExistsException, + StreamingQueryException as BaseStreamingQueryException, + QueryExecutionException as BaseQueryExecutionException, + SparkUpgradeException as BaseSparkUpgradeException, ) +if TYPE_CHECKING: + from google.rpc.error_details_pb2 import ErrorInfo + class SparkConnectException(PySparkException): """ @@ -33,6 +39,33 @@ class SparkConnectException(PySparkException): """ +def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException: + classes = [] + if "classes" in info.metadata: + classes = json.loads(info.metadata["classes"]) + + if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: + return ParseException(message) + # Order matters. ParseException inherits AnalysisException. + elif "org.apache.spark.sql.AnalysisException" in classes: + return AnalysisException(message) + elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: + return StreamingQueryException(message) + elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: + return QueryExecutionException(message) + elif "java.lang.IllegalArgumentException" in classes: + return IllegalArgumentException(message) + elif "org.apache.spark.SparkUpgradeException" in classes: + return SparkUpgradeException(message) + elif "org.apache.spark.api.python.PythonException" in classes: + return PythonException( + "\n An exception was thrown from the Python worker. " + "Please see the stack trace below.\n%s" % message + ) + else: + return SparkConnectGrpcException(message, reason=info.reason) + + class SparkConnectGrpcException(SparkConnectException): """ Base class to handle the errors from GRPC. @@ -61,41 +94,34 @@ class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): Failed to analyze a SQL query plan from Spark Connect server. """ - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - plan: Optional[str] = None, - reason: Optional[str] = None, - ) -> None: - self.message = message # type: ignore[assignment] - if plan is not None: - self.message = f"{self.message}\nPlan: {plan}" - super().__init__( - message=self.message, - error_class=error_class, - message_parameters=message_parameters, - reason=reason, - ) +class ParseException(SparkConnectGrpcException, BaseParseException): + """ + Failed to parse a SQL command from Spark Connect server. + """ -class TempTableAlreadyExistsException(AnalysisException, BaseTempTableAlreadyExistsException): +class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException): """ - Failed to create temp view from Spark Connect server since it is already exists. + Passed an illegal or inappropriate argument from Spark Connect server. """ -class ParseException(SparkConnectGrpcException, BaseParseException): +class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException): """ - Failed to parse a SQL command from Spark Connect server. + Exception that stopped a :class:`StreamingQuery` from Spark Connect server. 
""" -class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException): +class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException): """ - Passed an illegal or inappropriate argument from Spark Connect server. + Failed to execute a query from Spark Connect server. + """ + + +class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException): + """ + Exception thrown because of Spark upgrade from Spark Connect """ diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 6deee786164de..a7f3e761f3f46 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -387,7 +387,7 @@ def getTable(self, tableName: str) -> Table: Throw an analysis exception when the table does not exist. - >>> spark.catalog.getTable("tbl1") # doctest: +SKIP + >>> spark.catalog.getTable("tbl1") Traceback (most recent call last): ... AnalysisException: ... @@ -548,7 +548,7 @@ def getFunction(self, functionName: str) -> Function: Throw an analysis exception when the function does not exists. - >>> spark.catalog.getFunction("my_func2") # doctest: +SKIP + >>> spark.catalog.getFunction("my_func2") Traceback (most recent call last): ... AnalysisException: ... @@ -867,7 +867,7 @@ def dropTempView(self, viewName: str) -> bool: Throw an exception if the temporary view does not exists. - >>> spark.table("my_table") # doctest: +SKIP + >>> spark.table("my_table") Traceback (most recent call last): ... AnalysisException: ... @@ -907,7 +907,7 @@ def dropGlobalTempView(self, viewName: str) -> bool: Throw an exception if the global view does not exists. - >>> spark.table("global_temp.my_table") # doctest: +SKIP + >>> spark.table("global_temp.my_table") Traceback (most recent call last): ... AnalysisException: ... 
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 903981a015bed..943a7e70464a0 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -60,13 +60,9 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib import pyspark.sql.connect.types as types from pyspark.errors.exceptions.connect import ( - AnalysisException, - ParseException, - PythonException, + convert_exception, SparkConnectException, SparkConnectGrpcException, - TempTableAlreadyExistsException, - IllegalArgumentException, ) from pyspark.sql.connect.expressions import ( PythonUDF, @@ -730,32 +726,7 @@ def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn: if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): info = error_details_pb2.ErrorInfo() d.Unpack(info) - reason = info.reason - if reason == "org.apache.spark.sql.AnalysisException": - raise AnalysisException( - info.metadata["message"], plan=info.metadata["plan"] - ) from None - elif reason == "org.apache.spark.sql.catalyst.parser.ParseException": - raise ParseException(info.metadata["message"]) from None - elif ( - reason - == "org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException" - ): - raise TempTableAlreadyExistsException( - info.metadata["message"], plan=info.metadata["plan"] - ) from None - elif reason == "java.lang.IllegalArgumentException": - message = info.metadata["message"] - message = message if message != "" else status.message - raise IllegalArgumentException(message) from None - elif reason == "org.apache.spark.api.python.PythonException": - message = info.metadata["message"] - message = message if message != "" else status.message - raise PythonException(message) from None - else: - raise SparkConnectGrpcException( - status.message, reason=info.reason - ) from None + raise convert_exception(info, status.message) from None raise SparkConnectGrpcException(status.message) from None else: diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e794bb94e7509..5649d362b8b34 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -356,7 +356,7 @@ def createTempView(self, name: str) -> None: Throw an exception if the table already exists. - >>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP + >>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... AnalysisException: "Temporary table 'people' already exists;" @@ -439,7 +439,7 @@ def createGlobalTempView(self, name: str) -> None: Throws an exception if the global temporary view already exists. - >>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP + >>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... 
AnalysisException: "Temporary table 'people' already exists;" @@ -4598,7 +4598,7 @@ def freqItems( Examples -------- >>> df = spark.createDataFrame([(1, 11), (1, 11), (3, 10), (4, 8), (4, 8)], ["c1", "c2"]) - >>> df.freqItems(["c1", "c2"]).show() # doctest: +SKIP + >>> df.freqItems(["c1", "c2"]).show() +------------+------------+ |c1_freqItems|c2_freqItems| +------------+------------+ diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index b8e2c7b151a04..b3b241b2d4e3a 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -53,7 +53,6 @@ AnalysisException, ParseException, SparkConnectException, - TempTableAlreadyExistsException, ) if should_test_connect: @@ -1244,7 +1243,7 @@ def test_create_global_temp_view(self): # Test when creating a view which is already exists but self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) - with self.assertRaises(TempTableAlreadyExistsException): + with self.assertRaises(AnalysisException): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") def test_create_session_local_temp_view(self): @@ -1256,7 +1255,7 @@ def test_create_session_local_temp_view(self): self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0) # Test when creating a view which is already exists but - with self.assertRaises(TempTableAlreadyExistsException): + with self.assertRaises(AnalysisException): self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp") def test_to_pandas(self): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 1b3b4555d7ffd..0f92711313040 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -171,7 +171,7 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - exc_message = "Caught StopIteration thrown from user's code; failing the task" + exc_message = "StopIteration" df = self.spark.range(0, 100) # plain udf (test for SPARK-23754) @@ -193,7 +193,7 @@ def foo(x): def foofoo(x, y): raise StopIteration() - exc_message = "Caught StopIteration thrown from user's code; failing the task" + exc_message = "StopIteration" df = self.spark.range(0, 100) # pandas grouped map @@ -215,7 +215,7 @@ def test_stopiteration_in_grouped_agg(self): def foo(x): raise StopIteration() - exc_message = "Caught StopIteration thrown from user's code; failing the task" + exc_message = "StopIteration" df = self.spark.range(0, 100) # pandas grouped agg diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 4ab11c460717a..10f3ec12c9cf7 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +from pyspark.errors import AnalysisException from pyspark.sql.types import StructType, StructField, IntegerType from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -28,9 +28,7 @@ def test_current_database(self): spark.catalog.setCurrentDatabase("some_db") self.assertEqual(spark.catalog.currentDatabase(), "some_db") self.assertRaisesRegex( - # TODO(SPARK-41715): Should catch specific exceptions for both - # Spark Connect and PySpark - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.setCurrentDatabase("does_not_exist"), ) @@ -181,7 +179,7 @@ def compareTables(t1, t2): ) ) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.listTables("does_not_exist"), ) @@ -236,7 +234,7 @@ def test_list_functions(self): self.assertTrue("func1" not in newFunctionsSomeDb) self.assertTrue("func2" in newFunctionsSomeDb) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.listFunctions("does_not_exist"), ) @@ -333,9 +331,11 @@ def test_list_columns(self): isBucket=False, ), ) - self.assertRaisesRegex(Exception, "tab2", lambda: spark.catalog.listColumns("tab2")) self.assertRaisesRegex( - Exception, + AnalysisException, "tab2", lambda: spark.catalog.listColumns("tab2") + ) + self.assertRaisesRegex( + AnalysisException, "does_not_exist", lambda: spark.catalog.listColumns("does_not_exist"), )
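[Reviewer note - not part of the patch] A hedged end-user sketch of what this change enables on the client side. It assumes a Spark Connect server is already running at the default local address; the connection URL, SQL text, and table name are illustrative only.

from pyspark.sql import SparkSession
from pyspark.errors import AnalysisException, ParseException

# The Spark Connect exception classes added in connect.py subclass the regular
# PySpark exceptions, so existing except clauses keep working over Connect.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

try:
    spark.sql("SELEC 1").collect()           # malformed SQL
except ParseException as e:
    print("parse error:", e)

try:
    spark.table("does_not_exist").collect()  # unresolved table
except AnalysisException as e:
    print("analysis error:", e)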