From 956f5c4915f98b33921c78531973265fa51a8c71 Mon Sep 17 00:00:00 2001 From: Lingkai Kong Date: Fri, 11 Aug 2023 11:01:28 -0400 Subject: [PATCH 1/4] Add ProducedRowCount to SparkListenerConnectOperationFinished --- .../execution/SparkConnectPlanExecution.scala | 8 ++++--- .../service/ExecuteEventsManager.scala | 23 +++++++++++++++++-- .../planner/SparkConnectServiceSuite.scala | 5 ++-- .../service/ExecuteEventsManagerSuite.scala | 13 +++++++++++ 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 131ddf76fa44..a443387587ed 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -110,6 +110,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) errorOnDuplicatedFieldNames = false) var numSent = 0 + var totalNumRows: Long = 0 def sendBatch(bytes: Array[Byte], count: Long): Unit = { val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId) val batch = proto.ExecutePlanResponse.ArrowBatch @@ -120,11 +121,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) response.setArrowBatch(batch) responseObserver.onNext(response.build()) numSent += 1 + totalNumRows += count } dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => - executePlan.eventsManager.postFinished() + executePlan.eventsManager.postFinished(Some(totalNumRows)) converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } @@ -163,7 +165,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) // Collect errors and propagate them to the main thread. 
.andThen { case Success(_) => - executePlan.eventsManager.postFinished() + executePlan.eventsManager.postFinished(Some(totalNumRows)) case Failure(throwable) => signal.synchronized { error = Some(throwable) @@ -201,7 +203,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) } ThreadUtils.awaitReady(future, Duration.Inf) } else { - executePlan.eventsManager.postFinished() + executePlan.eventsManager.postFinished(Some(totalNumRows)) } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala index 5e831aaa98f2..c1293e30c1ea 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -75,6 +75,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { private var canceled = Option.empty[Boolean] + private var producedRowCount = Option.empty[Long] + /** * @return * Last event posted by the Connect request @@ -95,6 +97,13 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { */ private[connect] def hasError: Option[Boolean] = error + /** + * @return + * How many rows the Connect request has produced @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished + */ + private[connect] def getProducedRowCount: Option[Long] = producedRowCount + /** * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted. */ @@ -193,12 +202,21 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { /** * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished. */ - def postFinished(): Unit = { + def postFinished(totalNumRowsOpt: Option[Long] = None): Unit = { assertStatus( List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution), ExecuteStatus.Finished) + producedRowCount = totalNumRowsOpt + listenerBus - .post(SparkListenerConnectOperationFinished(jobTag, operationId, clock.getTimeMillis())) + .post( + SparkListenerConnectOperationFinished( + jobTag, + operationId, + clock.getTimeMillis(), + producedRowCount + ) + ) } /** @@ -402,6 +420,7 @@ case class SparkListenerConnectOperationFinished( jobTag: String, operationId: String, eventTime: Long, + producedRowCount: Option[Long] = None, extraTags: Map[String, String] = Map.empty) extends SparkListenerEvent diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 285f3103b190..01d97a5be013 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -190,7 +190,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with done = true } }) - verifyEvents.onCompleted() + verifyEvents.onCompleted(Some(100)) // The current implementation is expected to be blocking. This is here to make sure it is. 
assert(done) @@ -788,8 +788,9 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with assert(executeHolder.eventsManager.hasCanceled.isEmpty) assert(executeHolder.eventsManager.hasError.isDefined) } - def onCompleted(): Unit = { + def onCompleted(producedRowCount: Option[Long] = None): Unit = { assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) } def onCanceled(): Unit = { assert(executeHolder.eventsManager.hasCanceled.contains(true)) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index e7cc80071427..1a3fea390faa 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -141,6 +141,19 @@ class ExecuteEventsManagerSuite DEFAULT_CLOCK.getTimeMillis())) } + test("SPARK-44776: post finished with row number") { + val events = setupEvents(ExecuteStatus.Started) + events.postFinished(Some(100)) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFinished( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + Some(100) + )) + } + test("SPARK-43923: post closed") { val events = setupEvents(ExecuteStatus.Finished) events.postClosed() From 2e3252d61fc4ea046efaff134178b031ca73a20a Mon Sep 17 00:00:00 2001 From: Lingkai Kong Date: Mon, 14 Aug 2023 16:09:53 -0400 Subject: [PATCH 2/4] Add produced row count for SQL command and add tests --- .../execution/SparkConnectPlanExecution.scala | 2 +- .../connect/planner/SparkConnectPlanner.scala | 2 +- .../planner/SparkConnectServiceSuite.scala | 144 ++++++++++++++++-- 3 files changed, 130 insertions(+), 18 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index a443387587ed..b03f4e3a0285 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -126,10 +126,10 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => - executePlan.eventsManager.postFinished(Some(totalNumRows)) converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } + executePlan.eventsManager.postFinished(Some(totalNumRows)) case _ => SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { val rows = dataframe.queryExecution.executedPlan.execute() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 45f962f79202..992e0319cf7f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2480,7 +2480,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .putAllArgs(getSqlCommand.getArgsMap) .addAllPosArgs(getSqlCommand.getPosArgsList))) } - executeHolder.eventsManager.postFinished() + executeHolder.eventsManager.postFinished(Some(rows.size)) // Exactly one SQL Command Result Batch responseObserver.onNext( ExecutePlanResponse diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 01d97a5be013..400d8a3ebf1c 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.CreateDataFrameViewCommand import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.dsl.MockRemoteSession @@ -49,13 +50,18 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, Sessi import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** * Testing Connect Service implementation. */ -class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with Logging { +class SparkConnectServiceSuite + extends SharedSparkSession + with MockitoSugar + with Logging + with SparkConnectPlanTest{ private def sparkSessionHolder = SessionHolder.forTesting(spark) private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093") @@ -238,6 +244,79 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } + test("SPARK-44776: LocalTableScanExec") { + withEvents { verifyEvents => + // TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21 + assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + + val rows = (0L to 5L).map { i => + new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar))) + } + + val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType))) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + + val localRelation = createLocalRelationProto(schema, inputRows) + val plan = proto.Plan + .newBuilder() + .setRoot( + localRelation + ) + .build() + + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) + .build() + + // Execute plan. 
+ @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted(Some(6)) + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + + // 1 Partitions + Metrics + assert(responses.size == 3) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) + } + } + test("SPARK-44657: Arrow batches respect max batch size limit") { // Set 10 KiB as the batch size limit val batchSize = 10 * 1024 @@ -301,13 +380,20 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with gridTest("SPARK-43923: commands send events")( Seq( - proto.Command + ( + proto.Command .newBuilder() .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), - proto.Command + Some(0L) + ), + ( + proto.Command .newBuilder() - .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()), - proto.Command + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()), + Some(1L) + ), + ( + proto.Command .newBuilder() .setWriteOperation( proto.WriteOperation @@ -316,7 +402,10 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) .setPath(Utils.createTempDir().getAbsolutePath) .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), - proto.Command + None + ), + ( + proto.Command .newBuilder() .setWriteOperationV2( proto.WriteOperationV2 @@ -325,7 +414,10 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) .setTableName("testcat.testtable") .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), - proto.Command + None + ), + ( + proto.Command .newBuilder() .setCreateDataframeView( CreateDataFrameViewCommand @@ -333,10 +425,14 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .setName("testview") .setInput( proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), - proto.Command + None + ), + (proto.Command .newBuilder() .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), - proto.Command + None), + ( + proto.Command .newBuilder() .setExtension( protobuf.Any.pack( @@ -344,7 +440,10 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setCustomField("SPARK-43923") .build())), - proto.Command + None + ), + ( + proto.Command .newBuilder() .setWriteStreamOperationStart( proto.WriteStreamOperationStart @@ -365,7 +464,10 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath) .setPath("test-path") .build()), - proto.Command + None + ), + ( + proto.Command .newBuilder() .setStreamingQueryCommand( proto.StreamingQueryCommand @@ -377,12 +479,18 @@ class 
SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .setRunId(DEFAULT_UUID.toString) .build()) .setStop(true)), - proto.Command + None + ), + ( + proto.Command .newBuilder() .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand .newBuilder() .setListListeners(true)), - proto.Command + None + ), + ( + proto.Command .newBuilder() .setRegisterFunction( proto.CommonInlineUserDefinedFunction .newBuilder() .setFunctionName("function") .setPythonUdf( proto.PythonUDF .newBuilder() .setEvalType(100) .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) .setCommand(ByteString.copyFrom("command".getBytes())) .setPythonVer("3.10") - .build())))) { command => + .build())), + None + ) + ) + ) { case (command, producedNumRows) => val sessionId = UUID.randomUUID.toString() withCommandTest(sessionId) { verifyEvents => val instance = new SparkConnectService(false) @@ -435,7 +547,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with done = true } }) - verifyEvents.onCompleted() + verifyEvents.onCompleted(producedNumRows) // The current implementation is expected to be blocking. // This is here to make sure it is. assert(done) From 8dc4c91ce4c0f72785fe85588654e0833696d012 Mon Sep 17 00:00:00 2001 From: Lingkai Kong Date: Tue, 15 Aug 2023 10:25:57 -0400 Subject: [PATCH 3/4] Post finished event after all batches are sent and reformat code --- .../execution/SparkConnectPlanExecution.scala | 2 +- .../service/ExecuteEventsManager.scala | 18 +- .../planner/SparkConnectServiceSuite.scala | 203 ++++++++---------- .../service/ExecuteEventsManagerSuite.scala | 3 +- 4 files changed, 108 insertions(+), 118 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index b03f4e3a0285..d23a0fc54152 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -165,7 +165,6 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) // Collect errors and propagate them to the main thread. 
.andThen { case Success(_) => - executePlan.eventsManager.postFinished(Some(totalNumRows)) case Failure(throwable) => signal.synchronized { error = Some(throwable) @@ -202,6 +201,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) currentPartitionId += 1 } ThreadUtils.awaitReady(future, Duration.Inf) + executePlan.eventsManager.postFinished(Some(totalNumRows)) } else { executePlan.eventsManager.postFinished(Some(totalNumRows)) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala index c1293e30c1ea..9c80e64b3fc5 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -99,8 +99,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { /** * @return - * How many rows the Connect request has produced @link - * org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished + * How many rows the Connect request has produced @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished */ private[connect] def getProducedRowCount: Option[Long] = producedRowCount @@ -201,12 +201,15 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { /** * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished. + * @param producedRowsCountOpt + * Number of rows that are returned to the user. None is expected when the operation does not + * return any rows. */ - def postFinished(totalNumRowsOpt: Option[Long] = None): Unit = { + def postFinished(producedRowsCountOpt: Option[Long] = None): Unit = { assertStatus( List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution), ExecuteStatus.Finished) - producedRowCount = totalNumRowsOpt + producedRowCount = producedRowsCountOpt listenerBus .post( @@ -214,9 +217,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { jobTag, operationId, clock.getTimeMillis(), - producedRowCount - ) - ) + producedRowCount)) } /** @@ -415,6 +416,9 @@ case class SparkListenerConnectOperationFailed( * The time in ms when the event was generated. * @param extraTags: * Additional metadata during the request. + * @param producedRowCount: + * Number of rows that are returned to the user. None is expected when the operation does not + * return any rows. */ case class SparkListenerConnectOperationFinished( jobTag: String, diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 400d8a3ebf1c..74649e15e9eb 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -58,10 +58,10 @@ import org.apache.spark.util.Utils * Testing Connect Service implementation. 
*/ class SparkConnectServiceSuite - extends SharedSparkSession + extends SharedSparkSession with MockitoSugar with Logging - with SparkConnectPlanTest{ + with SparkConnectPlanTest { private def sparkSessionHolder = SessionHolder.forTesting(spark) private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093") @@ -268,9 +268,7 @@ class SparkConnectServiceSuite val localRelation = createLocalRelationProto(schema, inputRows) val plan = proto.Plan .newBuilder() - .setRoot( - localRelation - ) + .setRoot(localRelation) .build() val request = proto.ExecutePlanRequest @@ -382,132 +380,121 @@ class SparkConnectServiceSuite Seq( ( proto.Command - .newBuilder() - .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), - Some(0L) - ), + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), + Some(0L)), ( proto.Command - .newBuilder() - .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()), - Some(1L) - ), + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()), + Some(1L)), ( proto.Command - .newBuilder() - .setWriteOperation( - proto.WriteOperation - .newBuilder() - .setInput( - proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) - .setPath(Utils.createTempDir().getAbsolutePath) - .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), - None - ), + .newBuilder() + .setWriteOperation( + proto.WriteOperation + .newBuilder() + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) + .setPath(Utils.createTempDir().getAbsolutePath) + .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), + None), ( proto.Command - .newBuilder() - .setWriteOperationV2( - proto.WriteOperationV2 - .newBuilder() - .setInput(proto.Relation.newBuilder.setRange( - proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) - .setTableName("testcat.testtable") - .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), - None - ), + .newBuilder() + .setWriteOperationV2( + proto.WriteOperationV2 + .newBuilder() + .setInput(proto.Relation.newBuilder.setRange( + proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) + .setTableName("testcat.testtable") + .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), + None), ( proto.Command - .newBuilder() - .setCreateDataframeView( - CreateDataFrameViewCommand - .newBuilder() - .setName("testview") - .setInput( - proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), - None - ), - (proto.Command - .newBuilder() - .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), + .newBuilder() + .setCreateDataframeView( + CreateDataFrameViewCommand + .newBuilder() + .setName("testview") + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), None), ( proto.Command - .newBuilder() - .setExtension( - protobuf.Any.pack( - proto.ExamplePluginCommand - .newBuilder() - .setCustomField("SPARK-43923") - .build())), - None - ), + .newBuilder() + .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), + None), ( proto.Command - .newBuilder() - .setWriteStreamOperationStart( - proto.WriteStreamOperationStart - .newBuilder() - .setInput( - proto.Relation + .newBuilder() + .setExtension( + protobuf.Any.pack( + proto.ExamplePluginCommand .newBuilder() - .setRead(proto.Read + .setCustomField("SPARK-43923") + .build())), + None), + ( + proto.Command + .newBuilder() + .setWriteStreamOperationStart( + 
proto.WriteStreamOperationStart + .newBuilder() + .setInput( + proto.Relation .newBuilder() - .setIsStreaming(true) - .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .setRead(proto.Read + .newBuilder() + .setIsStreaming(true) + .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .build()) .build()) - .build()) - .setOutputMode("Append") - .setAvailableNow(true) - .setQueryName("test") - .setFormat("memory") - .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath) - .setPath("test-path") - .build()), - None - ), + .setOutputMode("Append") + .setAvailableNow(true) + .setQueryName("test") + .setFormat("memory") + .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath) + .setPath("test-path") + .build()), + None), ( proto.Command - .newBuilder() - .setStreamingQueryCommand( - proto.StreamingQueryCommand - .newBuilder() - .setQueryId( - proto.StreamingQueryInstanceId - .newBuilder() - .setId(DEFAULT_UUID.toString) - .setRunId(DEFAULT_UUID.toString) - .build()) - .setStop(true)), - None - ), + .newBuilder() + .setStreamingQueryCommand( + proto.StreamingQueryCommand + .newBuilder() + .setQueryId( + proto.StreamingQueryInstanceId + .newBuilder() + .setId(DEFAULT_UUID.toString) + .setRunId(DEFAULT_UUID.toString) + .build()) + .setStop(true)), + None), ( proto.Command - .newBuilder() - .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand .newBuilder() - .setListListeners(true)), - None - ), + .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand + .newBuilder() + .setListListeners(true)), + None), ( proto.Command - .newBuilder() - .setRegisterFunction( - proto.CommonInlineUserDefinedFunction - .newBuilder() - .setFunctionName("function") - .setPythonUdf( - proto.PythonUDF - .newBuilder() - .setEvalType(100) - .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) - .setCommand(ByteString.copyFrom("command".getBytes())) - .setPythonVer("3.10") - .build())), - None - ) - ) - ) { case (command, producedNumRows) => + .newBuilder() + .setRegisterFunction( + proto.CommonInlineUserDefinedFunction + .newBuilder() + .setFunctionName("function") + .setPythonUdf( + proto.PythonUDF + .newBuilder() + .setEvalType(100) + .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) + .setCommand(ByteString.copyFrom("command".getBytes())) + .setPythonVer("3.10") + .build())), + None))) { case (command, producedNumRows) => val sessionId = UUID.randomUUID.toString() withCommandTest(sessionId) { verifyEvents => val instance = new SparkConnectService(false) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index 1a3fea390faa..7950f9c5474f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -150,8 +150,7 @@ class ExecuteEventsManagerSuite events.executeHolder.jobTag, DEFAULT_QUERY_ID, DEFAULT_CLOCK.getTimeMillis(), - Some(100) - )) + Some(100))) } test("SPARK-43923: post closed") { From 6a7cee97fa282e78ff2810c62ad9989e4cc874dd Mon Sep 17 00:00:00 2001 From: Lingkai Kong Date: Tue, 15 Aug 2023 12:22:07 -0400 Subject: [PATCH 4/4] small update --- .../sql/connect/execution/SparkConnectPlanExecution.scala | 2 +- 
.../spark/sql/connect/service/ExecuteEventsManager.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index d23a0fc54152..00fec4378c57 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -164,7 +164,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) resultFunc = () => ()) // Collect errors and propagate them to the main thread. .andThen { - case Success(_) => + case Success(_) => // do nothing case Failure(throwable) => signal.synchronized { error = Some(throwable) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala index 9c80e64b3fc5..5b9267a96793 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -414,11 +414,11 @@ case class SparkListenerConnectOperationFailed( * 36 characters UUID assigned by Connect during a request. * @param eventTime: * The time in ms when the event was generated. - * @param extraTags: - * Additional metadata during the request. * @param producedRowCount: * Number of rows that are returned to the user. None is expected when the operation does not * return any rows. + * @param extraTags: + * Additional metadata during the request. */ case class SparkListenerConnectOperationFinished( jobTag: String,
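Note on consuming the new field: once producedRowCount is on SparkListenerConnectOperationFinished, any SparkListener registered on the listener bus can observe per-operation row counts. The following is a minimal sketch, not part of the patch series; it assumes an active SparkSession named `spark`, and the listener class name is made up for illustration:

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished

// Hypothetical consumer of the event extended in this series.
class ProducedRowCountListener extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: SparkListenerConnectOperationFinished =>
      // producedRowCount is None when the operation returns no rows
      // (e.g. writes), per the scaladoc added in PATCH 3/4.
      val rows = e.producedRowCount.getOrElse(0L)
      println(s"Connect operation ${e.operationId} produced $rows row(s)")
    case _ => // some other listener event; ignore
  }
}

// Registration:
// spark.sparkContext.addSparkListener(new ProducedRowCountListener)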
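For result-producing queries, a client can already derive the same total by summing rowCount over the ArrowBatch responses; the listener event additionally covers commands, whose counts never reach the client as Arrow batches. A rough client-side sketch against the generated proto bindings (the `responses` iterator is illustrative, not an API from this patch):

import org.apache.spark.connect.proto

// Sum the per-batch row counts over an ExecutePlanResponse stream.
def countProducedRows(responses: Iterator[proto.ExecutePlanResponse]): Long =
  responses
    .filter(_.hasArrowBatch)          // schema-only and metrics responses carry no rows
    .map(_.getArrowBatch.getRowCount) // set per batch by sendBatch in PATCH 1/4
    .sum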