From 0e4fb0407800a8891b6b3edc47ebda1ad550b394 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Tue, 12 Sep 2023 12:03:13 +0200 Subject: [PATCH 1/2] fix --- .../execution/SparkConnectPlanExecution.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 00fec4378c57..c1c02640a4a5 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -110,7 +110,6 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) errorOnDuplicatedFieldNames = false) var numSent = 0 - var totalNumRows: Long = 0 def sendBatch(bytes: Array[Byte], count: Long): Unit = { val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId) val batch = proto.ExecutePlanResponse.ArrowBatch @@ -121,15 +120,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) response.setArrowBatch(batch) responseObserver.onNext(response.build()) numSent += 1 - totalNumRows += count } dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => + executePlan.eventsManager.postFinished(Some(rows.length)) converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } - executePlan.eventsManager.postFinished(Some(totalNumRows)) case _ => SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { val rows = dataframe.queryExecution.executedPlan.execute() @@ -142,6 +140,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) val signal = new Object val partitions = new Array[Array[Batch]](numPartitions) + val numFinishedPartitions = 0 + var totalNumRows: Long = 0 var error: Option[Throwable] = None // This callback is executed by the DAGScheduler thread. @@ -150,6 +150,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) val resultHandler = (partitionId: Int, partition: Array[Batch]) => { signal.synchronized { partitions(partitionId) = partition + totalNumRows += partition._2 + numFinishedPartitions += 1 + if (numFinishedParttions == numPartitions) { + // Execution is finished, when all partitions returned results. + executePlan.eventsManager.postFinished(Some(totalNumRows)) + } signal.notify() } () @@ -201,9 +207,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) currentPartitionId += 1 } ThreadUtils.awaitReady(future, Duration.Inf) - executePlan.eventsManager.postFinished(Some(totalNumRows)) } else { - executePlan.eventsManager.postFinished(Some(totalNumRows)) + executePlan.eventsManager.postFinished(Some(0)) } } } From f12ab2fa7ebd8bdc97fa7a04c5a50d156960cd33 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Tue, 12 Sep 2023 14:42:22 +0200 Subject: [PATCH 2/2] typos --- .../sql/connect/execution/SparkConnectPlanExecution.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index c1c02640a4a5..1ad11490c350 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -140,7 +140,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) val signal = new Object val partitions = new Array[Array[Batch]](numPartitions) - val numFinishedPartitions = 0 + var numFinishedPartitions = 0 var totalNumRows: Long = 0 var error: Option[Throwable] = None @@ -150,9 +150,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) val resultHandler = (partitionId: Int, partition: Array[Batch]) => { signal.synchronized { partitions(partitionId) = partition - totalNumRows += partition._2 + totalNumRows += partition.map(_._2).sum numFinishedPartitions += 1 - if (numFinishedParttions == numPartitions) { + if (numFinishedPartitions == numPartitions) { // Execution is finished, when all partitions returned results. executePlan.eventsManager.postFinished(Some(totalNumRows)) }