@@ -19,7 +19,9 @@ package org.apache.spark.sql.connect.service

import java.util.concurrent.TimeUnit

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import com.google.common.base.Ticker
@@ -31,12 +33,16 @@ import io.grpc.netty.NettyServerBuilder
import io.grpc.protobuf.StatusProto
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.StringUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.{SparkEnv, SparkThrowable}
import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode}
@@ -53,22 +59,46 @@ class SparkConnectService(debug: Boolean)
extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
with Logging {

private def buildStatusFromThrowable[A <: Throwable with SparkThrowable](st: A): RPCStatus = {
val t = Option(st.getCause).getOrElse(st)
private def allClasses(cl: Class[_]): Seq[Class[_]] = {
val classes = ArrayBuffer.empty[Class[_]]
if (cl != null && !cl.equals(classOf[java.lang.Object])) {
classes.append(cl) // Includes itself.
}

@tailrec
def appendSuperClasses(clazz: Class[_]): Unit = {
if (clazz == null || clazz.equals(classOf[java.lang.Object])) return
classes.append(clazz.getSuperclass)
appendSuperClasses(clazz.getSuperclass)
}

appendSuperClasses(cl)
classes.toSeq
}

private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
.addDetails(
ProtoAny.pack(
ErrorInfo
.newBuilder()
.setReason(t.getClass.getName)
.setReason(st.getClass.getName)
.setDomain("org.apache.spark")
.putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
.build()))
.setMessage(t.getLocalizedMessage)
Member Author

getLocalizedMessage is not used in our codebase.

Member Author

And the doc of setMessage mentions that it's fine to send non-localized errors (and to expect the client to localize them).
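
For illustration only (not part of this change): a rough sketch of the ErrorInfo detail that buildStatusFromThrowable now attaches, as the Python client would see it after unpacking. The exact "classes" list depends on the JVM class hierarchy of the thrown exception; the names below are just an example.

# Sketch only: an ErrorInfo roughly like the one the server builds for a ParseException.
from google.rpc.error_details_pb2 import ErrorInfo

info = ErrorInfo(
    reason="org.apache.spark.sql.catalyst.parser.ParseException",
    domain="org.apache.spark",
)
# "classes" carries the exception's class hierarchy as a JSON array; the client's
# convert_exception matches against it to pick a Python exception type.
info.metadata["classes"] = (
    '["org.apache.spark.sql.catalyst.parser.ParseException", '
    '"org.apache.spark.sql.AnalysisException"]'
)
# The message itself is abbreviated to at most 2048 characters on the server side.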

.setMessage(StringUtils.abbreviate(st.getMessage, 2048))
.build()
}

private def isPythonExecutionException(se: SparkException): Boolean = {
// See also pyspark.errors.exceptions.captured.convert_exception in PySpark.
se.getCause != null && se.getCause
.isInstanceOf[PythonException] && se.getCause.getStackTrace
.exists(_.toString.contains("org.apache.spark.sql.execution.python"))
}

/**
* Common exception handling function for the Analysis and Execution methods. Closes the stream
* after the error has been sent.
@@ -83,46 +113,22 @@ class SparkConnectService(debug: Boolean)
private def handleError[V](
opType: String,
observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
case ae: AnalysisException =>
logError(s"Error during: $opType", ae)
val status = RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
.addDetails(
ProtoAny.pack(
ErrorInfo
.newBuilder()
.setReason(ae.getClass.getName)
.setDomain("org.apache.spark")
.putMetadata("message", ae.getSimpleMessage)
.putMetadata("plan", Option(ae.plan).flatten.map(p => s"$p").getOrElse(""))
.build()))
.setMessage(ae.getLocalizedMessage)
.build()
observer.onError(StatusProto.toStatusRuntimeException(status))
case st: SparkThrowable =>
logError(s"Error during: $opType", st)
val status = buildStatusFromThrowable(st)
observer.onError(StatusProto.toStatusRuntimeException(status))
case NonFatal(nf) =>
logError(s"Error during: $opType", nf)
val status = RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
.addDetails(
ProtoAny.pack(
ErrorInfo
.newBuilder()
.setReason(nf.getClass.getName)
.setDomain("org.apache.spark")
.build()))
.setMessage(nf.getLocalizedMessage)
.build()
observer.onError(StatusProto.toStatusRuntimeException(status))
case se: SparkException if isPythonExecutionException(se) =>
logError(s"Error during: $opType", se)
observer.onError(
StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause)))
Member Author
@HyukjinKwon HyukjinKwon Feb 9, 2023

Example:

from pyspark.sql.functions import udf
@udf
def aa(a):
    1/0

spark.range(1).select(aa("id")).show()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/connect/dataframe.py", line 776, in show
    print(self._show_string(n, truncate, vertical))
  File "/.../spark/python/pyspark/sql/connect/dataframe.py", line 619, in _show_string
    pdf = DataFrame.withPlan(
  File "/.../spark/python/pyspark/sql/connect/dataframe.py", line 1325, in toPandas
    return self._session.client.to_pandas(query)
  File "/.../spark/python/pyspark/sql/connect/client.py", line 449, in to_pandas
    table, metrics = self._execute_and_fetch(req)
  File "/.../spark/python/pyspark/sql/connect/client.py", line 636, in _execute_and_fetch
    self._handle_error(rpc_error)
  File "/.../spark/python/pyspark/sql/connect/client.py", line 670, in _handle_error
    raise convert_exception(info, status.message) from None
pyspark.errors.exceptions.connect.PythonException:
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "<stdin>", line 3, in aa
ZeroDivisionError: division by zero


case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
logError(s"Error during: $opType", e)
observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e)))

case e: Throwable =>
logError(s"Error during: $opType", e)
observer.onError(
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
Status.UNKNOWN
.withCause(e)
.withDescription(StringUtils.abbreviate(e.getMessage, 2048))
.asRuntimeException())
}

/**
76 changes: 51 additions & 25 deletions python/pyspark/errors/exceptions/connect.py
@@ -14,25 +14,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from typing import Dict, Optional, TYPE_CHECKING

from typing import Dict, Optional

from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
IllegalArgumentException as BaseIllegalArgumentException,
ParseException as BaseParseException,
PySparkException,
PythonException as BasePythonException,
TempTableAlreadyExistsException as BaseTempTableAlreadyExistsException,
StreamingQueryException as BaseStreamingQueryException,
QueryExecutionException as BaseQueryExecutionException,
SparkUpgradeException as BaseSparkUpgradeException,
)

if TYPE_CHECKING:
from google.rpc.error_details_pb2 import ErrorInfo


class SparkConnectException(PySparkException):
"""
Exception thrown from Spark Connect.
"""


def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
classes = []
if "classes" in info.metadata:
classes = json.loads(info.metadata["classes"])

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(message)
# Order matters. ParseException inherits AnalysisException.
elif "org.apache.spark.sql.AnalysisException" in classes:
return AnalysisException(message)
elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
return StreamingQueryException(message)
elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
return QueryExecutionException(message)
elif "java.lang.IllegalArgumentException" in classes:
return IllegalArgumentException(message)
elif "org.apache.spark.SparkUpgradeException" in classes:
return SparkUpgradeException(message)
elif "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
"\n An exception was thrown from the Python worker. "
"Please see the stack trace below.\n%s" % message
)
else:
return SparkConnectGrpcException(message, reason=info.reason)


class SparkConnectGrpcException(SparkConnectException):
"""
Base class to handle the errors from GRPC.
@@ -61,41 +94,34 @@ class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
Failed to analyze a SQL query plan from Spark Connect server.
"""

def __init__(
self,
message: Optional[str] = None,
error_class: Optional[str] = None,
message_parameters: Optional[Dict[str, str]] = None,
plan: Optional[str] = None,
reason: Optional[str] = None,
) -> None:
self.message = message # type: ignore[assignment]
if plan is not None:
self.message = f"{self.message}\nPlan: {plan}"
Member Author

The original AnalysisException.getMessage contains the string representation of the underlying plan.

Member Author

Example:

spark.range(1).select("a").show()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/connect/dataframe.py", line 776, in show
    print(self._show_string(n, truncate, vertical))
  File "/.../spark/python/pyspark/sql/connect/dataframe.py", line 619, in _show_string
    pdf = DataFrame.withPlan(
  File "/.../spark/python/pyspark/sql/connect/dataframe.py", line 1325, in toPandas
    return self._session.client.to_pandas(query)
  File "/.../spark/python/pyspark/sql/connect/client.py", line 449, in to_pandas
    table, metrics = self._execute_and_fetch(req)
  File "/.../spark/python/pyspark/sql/connect/client.py", line 636, in _execute_and_fetch
    self._handle_error(rpc_error)
  File "/.../spark/python/pyspark/sql/connect/client.py", line 670, in _handle_error
    raise convert_exception(info, status.message) from None
pyspark.errors.exceptions.connect.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `a` cannot be resolved. Did you mean one of the following? [`id`].;
'Project ['a]
+- Range (0, 1, step=1, splits=Some(16))

Member

What's the captured one's stack trace like?

Member Author

Captured one:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/dataframe.py", line 2987, in select
    jdf = self._jdf.select(self._jcols(*cols))
  File "/.../spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1322, in __call__
  File "/.../sparkk/python/pyspark/errors/exceptions/captured.py", line 159, in deco
    raise converted from None
pyspark.errors.exceptions.captured.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `a` cannot be resolved. Did you mean one of the following? [`id`].;
'Project ['a]
+- Range (0, 1, step=1, splits=Some(16))

Member

In that case, should we still show the plan to be consistent?

Member Author

Eh, yeah. It still shows the plan; that's part of getMessage.

Member

Ah, got it. 👍
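
For completeness, a minimal usage sketch (not from this PR; it assumes a Spark Connect session bound to spark): the client now raises the same exception class as classic PySpark, so calling code can stay identical across both modes.

from pyspark.errors.exceptions.connect import AnalysisException

try:
    spark.range(1).select("a").collect()
except AnalysisException as e:
    # The message includes the error class and the unresolved plan,
    # as in the tracebacks above.
    print(e)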


super().__init__(
message=self.message,
error_class=error_class,
message_parameters=message_parameters,
reason=reason,
)
class ParseException(SparkConnectGrpcException, BaseParseException):
"""
Failed to parse a SQL command from Spark Connect server.
"""


class TempTableAlreadyExistsException(AnalysisException, BaseTempTableAlreadyExistsException):
class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
"""
Failed to create temp view from Spark Connect server since it is already exists.
Passed an illegal or inappropriate argument from Spark Connect server.
"""


class ParseException(SparkConnectGrpcException, BaseParseException):
class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException):
"""
Failed to parse a SQL command from Spark Connect server.
Exception that stopped a :class:`StreamingQuery` from Spark Connect server.
"""


class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException):
"""
Passed an illegal or inappropriate argument from Spark Connect server.
Failed to execute a query from Spark Connect server.
"""


class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
"""
Exception thrown because of Spark upgrade from Spark Connect
"""


8 changes: 4 additions & 4 deletions python/pyspark/sql/catalog.py
@@ -387,7 +387,7 @@ def getTable(self, tableName: str) -> Table:

Throw an analysis exception when the table does not exist.

>>> spark.catalog.getTable("tbl1") # doctest: +SKIP
>>> spark.catalog.getTable("tbl1")
Traceback (most recent call last):
...
AnalysisException: ...
@@ -548,7 +548,7 @@ def getFunction(self, functionName: str) -> Function:

Throw an analysis exception when the function does not exists.

>>> spark.catalog.getFunction("my_func2") # doctest: +SKIP
>>> spark.catalog.getFunction("my_func2")
Traceback (most recent call last):
...
AnalysisException: ...
@@ -867,7 +867,7 @@ def dropTempView(self, viewName: str) -> bool:

Throw an exception if the temporary view does not exists.

>>> spark.table("my_table") # doctest: +SKIP
>>> spark.table("my_table")
Traceback (most recent call last):
...
AnalysisException: ...
@@ -907,7 +907,7 @@ def dropGlobalTempView(self, viewName: str) -> bool:

Throw an exception if the global view does not exists.

>>> spark.table("global_temp.my_table") # doctest: +SKIP
>>> spark.table("global_temp.my_table")
Traceback (most recent call last):
...
AnalysisException: ...
33 changes: 2 additions & 31 deletions python/pyspark/sql/connect/client.py
@@ -60,13 +60,9 @@
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
from pyspark.errors.exceptions.connect import (
AnalysisException,
ParseException,
PythonException,
convert_exception,
SparkConnectException,
SparkConnectGrpcException,
TempTableAlreadyExistsException,
IllegalArgumentException,
)
from pyspark.sql.connect.expressions import (
PythonUDF,
@@ -730,32 +726,7 @@ def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
info = error_details_pb2.ErrorInfo()
d.Unpack(info)
reason = info.reason
if reason == "org.apache.spark.sql.AnalysisException":
raise AnalysisException(
info.metadata["message"], plan=info.metadata["plan"]
) from None
elif reason == "org.apache.spark.sql.catalyst.parser.ParseException":
raise ParseException(info.metadata["message"]) from None
elif (
reason
== "org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException"
):
raise TempTableAlreadyExistsException(
info.metadata["message"], plan=info.metadata["plan"]
) from None
elif reason == "java.lang.IllegalArgumentException":
message = info.metadata["message"]
message = message if message != "" else status.message
raise IllegalArgumentException(message) from None
elif reason == "org.apache.spark.api.python.PythonException":
message = info.metadata["message"]
message = message if message != "" else status.message
raise PythonException(message) from None
else:
raise SparkConnectGrpcException(
status.message, reason=info.reason
) from None
raise convert_exception(info, status.message) from None

raise SparkConnectGrpcException(status.message) from None
else:
6 changes: 3 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -356,7 +356,7 @@ def createTempView(self, name: str) -> None:

Throw an exception if the table already exists.

>>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP
>>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
AnalysisException: "Temporary table 'people' already exists;"
@@ -439,7 +439,7 @@ def createGlobalTempView(self, name: str) -> None:

Throws an exception if the global temporary view already exists.

>>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP
>>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
AnalysisException: "Temporary table 'people' already exists;"
@@ -4598,7 +4598,7 @@ def freqItems(
Examples
--------
>>> df = spark.createDataFrame([(1, 11), (1, 11), (3, 10), (4, 8), (4, 8)], ["c1", "c2"])
>>> df.freqItems(["c1", "c2"]).show() # doctest: +SKIP
>>> df.freqItems(["c1", "c2"]).show()
Member

+------------+------------+
|c1_freqItems|c2_freqItems|
+------------+------------+
5 changes: 2 additions & 3 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -53,7 +53,6 @@
AnalysisException,
ParseException,
SparkConnectException,
TempTableAlreadyExistsException,
)

if should_test_connect:
Expand Down Expand Up @@ -1244,7 +1243,7 @@ def test_create_global_temp_view(self):

# Test when creating a view which is already exists but
self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
with self.assertRaises(TempTableAlreadyExistsException):
with self.assertRaises(AnalysisException):
self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")

def test_create_session_local_temp_view(self):
@@ -1256,7 +1255,7 @@ def test_create_session_local_temp_view(self):
self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0)

# Test when creating a view which is already exists but
with self.assertRaises(TempTableAlreadyExistsException):
with self.assertRaises(AnalysisException):
self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp")

def test_to_pandas(self):