diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 00fec4378c57..1ad11490c350 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -110,7 +110,6 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
         errorOnDuplicatedFieldNames = false)
 
     var numSent = 0
-    var totalNumRows: Long = 0
     def sendBatch(bytes: Array[Byte], count: Long): Unit = {
       val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
       val batch = proto.ExecutePlanResponse.ArrowBatch
@@ -121,15 +120,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       response.setArrowBatch(batch)
       responseObserver.onNext(response.build())
       numSent += 1
-      totalNumRows += count
     }
 
     dataframe.queryExecution.executedPlan match {
       case LocalTableScanExec(_, rows) =>
+        executePlan.eventsManager.postFinished(Some(rows.length))
         converter(rows.iterator).foreach { case (bytes, count) =>
           sendBatch(bytes, count)
         }
-        executePlan.eventsManager.postFinished(Some(totalNumRows))
       case _ =>
         SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
           val rows = dataframe.queryExecution.executedPlan.execute()
@@ -142,6 +140,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
           val signal = new Object
           val partitions = new Array[Array[Batch]](numPartitions)
+          var numFinishedPartitions = 0
+          var totalNumRows: Long = 0
           var error: Option[Throwable] = None
 
           // This callback is executed by the DAGScheduler thread.
@@ -150,6 +150,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
             signal.synchronized {
               partitions(partitionId) = partition
+              totalNumRows += partition.map(_._2).sum
+              numFinishedPartitions += 1
+              if (numFinishedPartitions == numPartitions) {
+                // Execution is finished when all partitions have returned results.
+                executePlan.eventsManager.postFinished(Some(totalNumRows))
+              }
               signal.notify()
             }
             ()
@@ -201,9 +207,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
              currentPartitionId += 1
            }
            ThreadUtils.awaitReady(future, Duration.Inf)
-            executePlan.eventsManager.postFinished(Some(totalNumRows))
          } else {
-            executePlan.eventsManager.postFinished(Some(totalNumRows))
+            executePlan.eventsManager.postFinished(Some(0))
          }
        }
    }
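
The patch moves row-count bookkeeping off the batch-sending path: for `LocalTableScanExec` the total is known up front (`rows.length`), in the distributed case each partition's count is accumulated inside the `resultHandler` callback and `postFinished` fires under the same lock as soon as the last partition reports, and the zero-partition branch posts `Some(0)`. Below is a minimal, self-contained sketch of this completion-counting pattern, not the Spark Connect API itself: the simplified `Batch` alias, the thread pool standing in for the DAGScheduler callback thread, and the `println` standing in for `executePlan.eventsManager.postFinished` are all illustrative assumptions.

```scala
import java.util.concurrent.{Executors, TimeUnit}

object PartitionCompletionSketch {
  // Simplified stand-in for the per-partition Arrow batches: (bytes, rowCount) pairs.
  type Batch = (Array[Byte], Long)

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val signal = new Object
    val partitions = new Array[Array[Batch]](numPartitions)
    var numFinishedPartitions = 0
    var totalNumRows: Long = 0

    // Invoked once per partition, potentially from different threads. All mutable
    // state is touched only while holding the lock, mirroring the
    // signal.synchronized block in the patch.
    def resultHandler(partitionId: Int, partition: Array[Batch]): Unit = {
      signal.synchronized {
        partitions(partitionId) = partition
        totalNumRows += partition.map(_._2).sum
        numFinishedPartitions += 1
        if (numFinishedPartitions == numPartitions) {
          // The last partition has arrived; the total row count is now final.
          println(s"postFinished(Some($totalNumRows))")
        }
        signal.notify()
      }
    }

    val pool = Executors.newFixedThreadPool(numPartitions)
    (0 until numPartitions).foreach { pid =>
      // One synthetic batch of (pid + 1) * 10 rows per partition.
      pool.execute(() => resultHandler(pid, Array((Array.emptyByteArray, (pid + 1) * 10L))))
    }
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
    // Prints postFinished(Some(100)) exactly once, regardless of arrival order.
  }
}
```

Because the per-partition sum and the finished check happen inside the same `synchronized` block, the final count cannot race with a partition that is still being recorded, which is what lets the event move from the main thread into the callback.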