ProtoUtils.scala

@@ -27,24 +27,24 @@ private[connect] object ProtoUtils {
   private val MAX_BYTES_SIZE = 8
   private val MAX_STRING_SIZE = 1024
 
-  def abbreviate(message: Message): Message = {
+  def abbreviate(message: Message, maxStringSize: Int = MAX_STRING_SIZE): Message = {
     val builder = message.toBuilder
 
     message.getAllFields.asScala.iterator.foreach {
       case (field: FieldDescriptor, string: String)
           if field.getJavaType == FieldDescriptor.JavaType.STRING && string != null =>
         val size = string.size
-        if (size > MAX_STRING_SIZE) {
-          builder.setField(field, createString(string.take(MAX_STRING_SIZE), size))
+        if (size > maxStringSize) {
+          builder.setField(field, createString(string.take(maxStringSize), size))
         } else {
           builder.setField(field, string)
         }
 
       case (field: FieldDescriptor, byteString: ByteString)
           if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null =>
         val size = byteString.size
-        if (size > MAX_BYTES_SIZE) {
-          val prefix = Array.tabulate(MAX_BYTES_SIZE)(byteString.byteAt)
+        if (size > maxStringSize) {
+          val prefix = Array.tabulate(maxStringSize)(byteString.byteAt)
           builder.setField(field, createByteString(prefix, size))
         } else {
           builder.setField(field, byteString)
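
For context, the added maxStringSize parameter lets a caller pick a truncation limit other than the 1024-character default when abbreviating a message, and the second branch now applies the same limit to byte-string fields. A minimal sketch of a call site, assuming it lives in ProtoUtils' own package (the object is private[connect], so callers must sit under org.apache.spark.sql.connect) and that the package name below is the right one:

package org.apache.spark.sql.connect.common

import com.google.protobuf.Message

object AbbreviateSketch {
  // Illustrative only: render a proto message for logging, truncating every string
  // field to 256 characters instead of the 1024-character default.
  def loggable(message: Message): String =
    ProtoUtils.abbreviate(message, maxStringSize = 256).toString
}
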
ExecuteThreadRunner.scala

@@ -78,7 +78,6 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
           // and different exceptions like InterruptedException, ClosedByInterruptException etc.
           // could be thrown.
           if (interrupted) {
-            // Turn the interrupt into OPERATION_CANCELED error.
             throw new SparkSQLException("OPERATION_CANCELED", Map.empty)
           } else {
             // Rethrown the original error.
@@ -92,7 +91,9 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
         "execute",
         executeHolder.responseObserver,
         executeHolder.sessionHolder.userId,
-        executeHolder.sessionHolder.sessionId)
+        executeHolder.sessionHolder.sessionId,
+        Some(executeHolder.eventsManager),
+        interrupted)
     }
   }

@@ -148,9 +149,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
     val planner = new SparkConnectPlanner(executeHolder.sessionHolder)
     planner.process(
       command = command,
-      userId = request.getUserContext.getUserId,
-      sessionId = request.getSessionId,
-      responseObserver = responseObserver)
+      responseObserver = responseObserver,
+      executeHolder = executeHolder)
     responseObserver.onCompleted()
   }
 
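
For context, these two hunks thread the ExecuteHolder (and its eventsManager) into both the error handler and the command planner, so lifecycle events can be posted from one place instead of passing userId and sessionId separately. Only two members of the events manager are visible in this diff; a hypothetical, reduced view of that surface (not the actual ExecuteEventsManager API) looks like:

import org.apache.spark.sql.catalyst.QueryPlanningTracker

// Hypothetical reduction for illustration; the real events manager is wider than this.
trait EventsManagerSketch {
  // Tracker handed to Dataset.ofRows so planning-phase timings are captured per execution.
  def createQueryPlanningTracker: QueryPlanningTracker
  // Posted once the plan or command has produced all of its results.
  def postFinished(): Unit
}
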
SparkConnectPlanExecution.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connect.execution
 
 import scala.collection.JavaConverters._
+import scala.util.{Failure, Success}
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
@@ -54,13 +55,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       throw new IllegalStateException(
         s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.")
     }
 
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(sessionHolder)
+    val tracker = executeHolder.eventsManager.createQueryPlanningTracker
     val dataframe =
-      Dataset.ofRows(sessionHolder.session, planner.transformRelation(request.getPlan.getRoot))
+      Dataset.ofRows(
+        sessionHolder.session,
+        planner.transformRelation(request.getPlan.getRoot),
+        tracker)
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
-    processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
+    processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(
       MetricGenerator.createMetricsResponse(request.getSessionId, dataframe))
     if (dataframe.queryExecution.observedMetrics.nonEmpty) {
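
For context, the tracker created above is passed into Dataset.ofRows so the planning phases of this execution are timed and can be reported through the events manager. A small, self-contained sketch of what a catalyst QueryPlanningTracker records, assuming the standard measurePhase/phases API (this is not the ExecuteEventsManager code path itself):

import org.apache.spark.sql.catalyst.QueryPlanningTracker

object TrackerSketch {
  def main(args: Array[String]): Unit = {
    val tracker = new QueryPlanningTracker()
    // Time a phase the way the analyzer would; the body is a stand-in for real work.
    tracker.measurePhase(QueryPlanningTracker.ANALYSIS) { Thread.sleep(5) }
    // phases maps phase names to start/end timestamps, which can then be attached
    // to whatever listener events get posted for the execution.
    tracker.phases.foreach { case (phase, summary) =>
      println(s"$phase: ${summary.durationMs} ms")
    }
  }
}
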
@@ -87,10 +90,11 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     batches.map(b => b -> batches.rowCountInLastBatch)
   }
 
-  private def processAsArrowBatches(
-      sessionId: String,
+  def processAsArrowBatches(
       dataframe: DataFrame,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      executePlan: ExecuteHolder): Unit = {
+    val sessionId = executePlan.sessionHolder.sessionId
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
     val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
@@ -120,6 +124,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
     dataframe.queryExecution.executedPlan match {
       case LocalTableScanExec(_, rows) =>
+        executePlan.eventsManager.postFinished()
         converter(rows.iterator).foreach { case (bytes, count) =>
           sendBatch(bytes, count)
         }
@@ -156,13 +161,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           resultFunc = () => ())
 
         // Collect errors and propagate them to the main thread.
-        future.onComplete { result =>
-          result.failed.foreach { throwable =>
+        future.onComplete {
+          case Success(_) =>
+            executePlan.eventsManager.postFinished()
+          case Failure(throwable) =>
             signal.synchronized {
               error = Some(throwable)
               signal.notify()
             }
-          }
         }(ThreadUtils.sameThread)
 
         // The main thread will wait until 0-th partition is available,
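
For context, the completion handler changes from "propagate failures only" to a full Success/Failure match, so the success branch can post the finished event while failures are still handed back to the thread blocked on signal. A self-contained sketch of that pattern, using the global execution context in place of ThreadUtils.sameThread:

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object OnCompleteSketch {
  def main(args: Array[String]): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global
    val signal = new Object
    var error: Option[Throwable] = None

    val future = Future { 42 } // stand-in for the submitted Spark job
    future.onComplete {
      case Success(_) =>
        println("postFinished()") // where the events manager would be notified
      case Failure(throwable) =>
        signal.synchronized { // hand the failure back to the thread waiting on `signal`
          error = Some(throwable)
          signal.notify()
        }
    }
    Thread.sleep(100) // crude wait for the callback; the real code blocks on `signal`
  }
}
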