From cfb012bb5f1c78088a1706f6e4289bfee3467ba8 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 2 Nov 2022 19:54:46 +0800 Subject: [PATCH 1/4] arrow 101 ffix fix lint address comments nit address comments --- .../main/protobuf/spark/connect/base.proto | 17 +-- .../service/SparkConnectStreamHandler.scala | 112 ++++++++++++++++-- python/pyspark/sql/connect/client.py | 23 +++- python/pyspark/sql/connect/proto/base_pb2.py | 36 +++--- python/pyspark/sql/connect/proto/base_pb2.pyi | 34 +++--- .../sql/tests/connect/test_connect_basic.py | 11 ++ .../sql/execution/arrow/ArrowConverters.scala | 93 ++++++++++++++- 7 files changed, 271 insertions(+), 55 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index b376515bf1af0..bc8905c8dea94 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -69,7 +69,7 @@ message Response { // Result type oneof result_type { - ArrowBatch batch = 2; + ArrowBatch arrow_batch = 2; JSONBatch json_batch = 3; } @@ -79,19 +79,20 @@ message Response { // Batch results of metrics. message ArrowBatch { - int64 row_count = 1; - int64 uncompressed_bytes = 2; - int64 compressed_bytes = 3; - bytes data = 4; - bytes schema = 5; + int64 batch_id = 1; + int64 row_count = 2; + int64 uncompressed_bytes = 3; + int64 compressed_bytes = 4; + bytes data = 5; } // Message type when the result is returned as JSON. This is essentially a bulk wrapper // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format // of `{col -> row}`. message JSONBatch { - int64 row_count = 1; - bytes data = 2; + int64 batch_id = 1; + int64 row_count = 2; + bytes data = 3; } message Metrics { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 58fc6237867c0..52ebd999b1f20 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.connect.service +import java.util.concurrent.Future + import scala.collection.JavaConverters._ +import scala.util.Try import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -29,8 +32,11 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} +import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.ThreadUtils class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { @@ -48,19 +54,26 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } } - def handlePlan(session: SparkSession, request: proto.Request): Unit = { + def handlePlan(session: SparkSession, request: Request): Unit = { // Extract the plan from the request and 
convert it to a logical plan val planner = new SparkConnectPlanner(request.getPlan.getRoot, session) - val rows = - Dataset.ofRows(session, planner.transform()) - processRows(request.getClientId, rows) + val dataframe = Dataset.ofRows(session, planner.transform()) + // check whether all data types are supported + if (Try { + ArrowUtils.toArrowSchema(dataframe.schema, session.sessionState.conf.sessionLocalTimeZone) + }.isSuccess) { + processRowsAsArrowBatches(request.getClientId, dataframe) + } else { + processRowsAsJsonBatches(request.getClientId, dataframe) + } } - def processRows(clientId: String, rows: DataFrame): Unit = { + def processRowsAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { // Only process up to 10MB of data. val sb = new StringBuilder + var batchId = 0L var rowCount = 0 - rows.toJSON + dataframe.toJSON .collect() .foreach(row => { @@ -83,12 +96,14 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte val response = proto.Response.newBuilder().setClientId(clientId) val batch = proto.Response.JSONBatch .newBuilder() + .setBatchId(batchId) .setData(ByteString.copyFromUtf8(sb.toString())) .setRowCount(rowCount) .build() response.setJsonBatch(batch) responseObserver.onNext(response.build()) sb.clear() + batchId += 1 sb.append(row) rowCount = 1 } else { @@ -114,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte responseObserver.onNext(response.build()) } - responseObserver.onNext(sendMetricsToResponse(clientId, rows)) + responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) responseObserver.onCompleted() } + def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { + val spark = dataframe.sparkSession + val schema = dataframe.schema + // TODO: control the batch size instead of max records + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + + SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { + val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow") + val tasks = collection.mutable.ArrayBuffer.empty[Future[_]] + val rows = dataframe.queryExecution.executedPlan.execute() + + if (rows.getNumPartitions > 0) { + val batches = rows.mapPartitionsInternal { iter => + ArrowConverters + .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId) + } + + val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray + + val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => { + if (taskResult.exists(_._1.nonEmpty)) { + // only send non-empty partitions + val task = pool.submit(new Runnable { + override def run(): Unit = { + var batchId = partitionId.toLong << 33 + taskResult.foreach { case (bytes, count, size) => + val response = proto.Response.newBuilder().setClientId(clientId) + val batch = proto.Response.ArrowBatch + .newBuilder() + .setBatchId(batchId) + .setRowCount(count) + .setUncompressedBytes(size) + .setCompressedBytes(bytes.length) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) + batchId += 1 + } + } + }) + tasks.synchronized { + tasks.append(task) + } + } + val i = 0 // Unit + } + + spark.sparkContext.runJob(batches, processPartition, resultHandler) + } + + // make sure at least 1 batch will be sent + if (tasks.isEmpty) { + val task = pool.submit(new Runnable { + override def 
run(): Unit = { + val (bytes, count, size) = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) + val response = proto.Response.newBuilder().setClientId(clientId) + val batch = proto.Response.ArrowBatch + .newBuilder() + .setBatchId(0L) + .setRowCount(count) + .setUncompressedBytes(size) + .setCompressedBytes(bytes.length) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) + } + }) + tasks.append(task) + } + + tasks.foreach(_.get()) + pool.shutdown() + + responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) + responseObserver.onCompleted() + } + } + def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = { // Send a last batch with the metrics Response diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 2eba9ac11f525..09ab4bccb912b 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -374,14 +374,14 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) return AnalyzeResult.fromProto(resp) - def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]: + def _process_batch(self, b: pb2.Response) -> Optional[tuple[int, pandas.DataFrame]]: import pandas as pd - if b.batch is not None and len(b.batch.data) > 0: - with pa.ipc.open_stream(b.batch.data) as rd: - return rd.read_pandas() + if b.arrow_batch is not None and len(b.arrow_batch.data) > 0: + with pa.ipc.open_stream(b.arrow_batch.data) as rd: + return (b.arrow_batch.batch_id, rd.read_pandas()) elif b.json_batch is not None and len(b.json_batch.data) > 0: - return pd.read_json(io.BytesIO(b.json_batch.data), lines=True) + return (b.json_batch.batch_id, pd.read_json(io.BytesIO(b.json_batch.data), lines=True)) return None def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]: @@ -399,7 +399,18 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra result_dfs.append(pb) if len(result_dfs) > 0: - df = pd.concat(result_dfs) + # sort by batch id + result_dfs.sort(key=lambda t: t[0]) + # concat the pandas dataframes + df = pd.concat([t[1] for t in result_dfs]) + del result_dfs + + # pd.concat generates non-consecutive index like: + # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64') + # set it to RangeIndex to be consistent with pyspark + n = len(df) + df = df.set_index(pd.RangeIndex(start=0, stop=n, step=1)) + # Attach the metrics to the DataFrame attributes. 
if m is not None: df.attrs["metrics"] = self._build_metrics(m) diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index eb9ecc9157f2c..ee3be242e2bc0 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + 
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf1\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xb2\x01\n\nArrowBatch\x12\x19\n\x08\x62\x61tch_id\x18\x01 \x01(\x03R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x02 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x03 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x04 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x05 \x01(\x0cR\x04\x64\x61ta\x1aW\n\tJSONBatch\x12\x19\n\x08\x62\x61tch_id\x18\x01 \x01(\x03R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x02 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -53,21 +53,21 @@ _REQUEST_USERCONTEXT._serialized_start = 429 _REQUEST_USERCONTEXT._serialized_end = 551 _RESPONSE._serialized_start = 554 - _RESPONSE._serialized_end = 1522 - _RESPONSE_ARROWBATCH._serialized_start = 783 - _RESPONSE_ARROWBATCH._serialized_end = 958 - _RESPONSE_JSONBATCH._serialized_start = 960 - _RESPONSE_JSONBATCH._serialized_end = 1020 - _RESPONSE_METRICS._serialized_start = 1023 - 
_RESPONSE_METRICS._serialized_end = 1507 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507 - _ANALYZERESPONSE._serialized_start = 1525 - _ANALYZERESPONSE._serialized_end = 1659 - _SPARKCONNECTSERVICE._serialized_start = 1662 - _SPARKCONNECTSERVICE._serialized_end = 1824 + _RESPONSE._serialized_end = 1563 + _RESPONSE_ARROWBATCH._serialized_start = 794 + _RESPONSE_ARROWBATCH._serialized_end = 972 + _RESPONSE_JSONBATCH._serialized_start = 974 + _RESPONSE_JSONBATCH._serialized_end = 1061 + _RESPONSE_METRICS._serialized_start = 1064 + _RESPONSE_METRICS._serialized_end = 1548 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1148 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1458 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1346 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1458 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1460 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1548 + _ANALYZERESPONSE._serialized_start = 1566 + _ANALYZERESPONSE._serialized_end = 1700 + _SPARKCONNECTSERVICE._serialized_start = 1703 + _SPARKCONNECTSERVICE._serialized_end = 1865 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 5ffd7701b440d..d3bf65758fab5 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -177,36 +177,36 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + BATCH_ID_FIELD_NUMBER: builtins.int ROW_COUNT_FIELD_NUMBER: builtins.int UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int COMPRESSED_BYTES_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int - SCHEMA_FIELD_NUMBER: builtins.int + batch_id: builtins.int row_count: builtins.int uncompressed_bytes: builtins.int compressed_bytes: builtins.int data: builtins.bytes - schema: builtins.bytes def __init__( self, *, + batch_id: builtins.int = ..., row_count: builtins.int = ..., uncompressed_bytes: builtins.int = ..., compressed_bytes: builtins.int = ..., data: builtins.bytes = ..., - schema: builtins.bytes = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ + "batch_id", + b"batch_id", "compressed_bytes", b"compressed_bytes", "data", b"data", "row_count", b"row_count", - "schema", - b"schema", "uncompressed_bytes", b"uncompressed_bytes", ], @@ -220,18 +220,24 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + BATCH_ID_FIELD_NUMBER: builtins.int ROW_COUNT_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int + batch_id: builtins.int row_count: builtins.int data: builtins.bytes def __init__( self, *, + batch_id: builtins.int = ..., row_count: builtins.int = ..., data: builtins.bytes = ..., ) -> None: ... def ClearField( - self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] + self, + field_name: typing_extensions.Literal[ + "batch_id", b"batch_id", "data", b"data", "row_count", b"row_count" + ], ) -> None: ... 
class Metrics(google.protobuf.message.Message): @@ -339,12 +345,12 @@ class Response(google.protobuf.message.Message): ) -> None: ... CLIENT_ID_FIELD_NUMBER: builtins.int - BATCH_FIELD_NUMBER: builtins.int + ARROW_BATCH_FIELD_NUMBER: builtins.int JSON_BATCH_FIELD_NUMBER: builtins.int METRICS_FIELD_NUMBER: builtins.int client_id: builtins.str @property - def batch(self) -> global___Response.ArrowBatch: ... + def arrow_batch(self) -> global___Response.ArrowBatch: ... @property def json_batch(self) -> global___Response.JSONBatch: ... @property @@ -356,15 +362,15 @@ class Response(google.protobuf.message.Message): self, *, client_id: builtins.str = ..., - batch: global___Response.ArrowBatch | None = ..., + arrow_batch: global___Response.ArrowBatch | None = ..., json_batch: global___Response.JSONBatch | None = ..., metrics: global___Response.Metrics | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "batch", - b"batch", + "arrow_batch", + b"arrow_batch", "json_batch", b"json_batch", "metrics", @@ -376,8 +382,8 @@ class Response(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ - "batch", - b"batch", + "arrow_batch", + b"arrow_batch", "client_id", b"client_id", "json_batch", @@ -390,7 +396,7 @@ class Response(google.protobuf.message.Message): ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["result_type", b"result_type"] - ) -> typing_extensions.Literal["batch", "json_batch"] | None: ... + ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ... global___Response = Response diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index a0f046907f73e..866d821215b40 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -197,6 +197,17 @@ def test_range(self): .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas()) ) + def test_empty_dataset(self): + self.assertTrue( + self.connect.sql("SELECT 1 AS X LIMIT 0") + .toPandas() + .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas()) + ) + pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas() + self.assertEqual(0, len(pdf)) # empty dataset + self.assertEqual(1, len(pdf.columns)) # one column + self.assertEqual("X", pdf.columns[0]) + def test_simple_datasource_read(self) -> None: writeDf = self.df_text tmpPath = tempfile.mkdtemp() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bded158645cce..896e80439bf11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.{ByteBufferOutputStream, Utils} +import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} /** @@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging { } } + private[sql] def toArrowBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Int, + timeZoneId: String): Iterator[(Array[Byte], 
Long, Long)] = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "toArrowBatchIterator", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) + val arrowWriter = ArrowWriter.create(root) + + Option(TaskContext.get).foreach { + _.addTaskCompletionListener[Unit] { _ => + root.close() + allocator.close() + } + } + + new Iterator[(Array[Byte], Long, Long)] { + + override def hasNext: Boolean = rowIter.hasNext || { + root.close() + allocator.close() + false + } + + override def next(): (Array[Byte], Long, Long) = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + var rowCount = 0L + var estimatedSize = SizeEstimator.estimate(arrowSchema) + + SizeEstimator.estimate(IpcOption.DEFAULT) + Utils.tryWithSafeFinally { + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowCount += 1 + estimatedSize += SizeEstimator.estimate(row) + } + arrowWriter.finish() + val batch = unloader.getRecordBatch() + + MessageSerializer.serialize(writeChannel, arrowSchema) + MessageSerializer.serialize(writeChannel, batch) + ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) + + batch.close() + } { + arrowWriter.reset() + } + + (out.toByteArray, rowCount, estimatedSize) + } + } + } + + private[sql] def createEmptyArrowBatch( + schema: StructType, + timeZoneId: String): (Array[Byte], Long, Long) = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "createEmptyArrowBatch", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) + val arrowWriter = ArrowWriter.create(root) + + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + val estimatedSize = SizeEstimator.estimate(arrowSchema) + + SizeEstimator.estimate(IpcOption.DEFAULT) + Utils.tryWithSafeFinally { + arrowWriter.finish() + val batch = unloader.getRecordBatch() // empty batch + + MessageSerializer.serialize(writeChannel, arrowSchema) + MessageSerializer.serialize(writeChannel, batch) + ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) + + batch.close() + } { + arrowWriter.reset() + } + + (out.toByteArray, 0L, estimatedSize) + } + /** * Maps iterator from serialized ArrowRecordBatches to InternalRows. */ From a83505ca1d0fe4e04e5830c8056efb7f950c92a4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 8 Nov 2022 18:28:12 +0800 Subject: [PATCH 2/4] address comments --- .../main/protobuf/spark/connect/base.proto | 18 +-- .../service/SparkConnectStreamHandler.scala | 112 +++++++++--------- python/pyspark/sql/connect/client.py | 16 ++- python/pyspark/sql/connect/proto/base_pb2.py | 34 +++--- python/pyspark/sql/connect/proto/base_pb2.pyi | 17 ++- 5 files changed, 112 insertions(+), 85 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index bc8905c8dea94..740d278a62635 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -79,20 +79,22 @@ message Response { // Batch results of metrics. 
message ArrowBatch { - int64 batch_id = 1; - int64 row_count = 2; - int64 uncompressed_bytes = 3; - int64 compressed_bytes = 4; - bytes data = 5; + int32 partition_id = 1; + int32 batch_id = 2; + int64 row_count = 3; + int64 uncompressed_bytes = 4; + int64 compressed_bytes = 5; + bytes data = 6; } // Message type when the result is returned as JSON. This is essentially a bulk wrapper // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format // of `{col -> row}`. message JSONBatch { - int64 batch_id = 1; - int64 row_count = 2; - bytes data = 3; + int32 partition_id = 1; + int32 batch_id = 2; + int64 row_count = 3; + bytes data = 4; } message Metrics { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 52ebd999b1f20..076c6683277e1 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.Future - import scala.collection.JavaConverters._ import scala.util.Try @@ -36,7 +34,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.ThreadUtils class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { @@ -71,7 +68,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte def processRowsAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { // Only process up to 10MB of data. 
val sb = new StringBuilder - var batchId = 0L + var batchId = 0 var rowCount = 0 dataframe.toJSON .collect() @@ -96,6 +93,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte val response = proto.Response.newBuilder().setClientId(clientId) val batch = proto.Response.JSONBatch .newBuilder() + .setPartitionId(-1) .setBatchId(batchId) .setData(ByteString.copyFromUtf8(sb.toString())) .setRowCount(rowCount) @@ -141,73 +139,81 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { - val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow") - val tasks = collection.mutable.ArrayBuffer.empty[Future[_]] val rows = dataframe.queryExecution.executedPlan.execute() + val numPartitions = rows.getNumPartitions + var numSent = 0 - if (rows.getNumPartitions > 0) { + if (numPartitions > 0) { val batches = rows.mapPartitionsInternal { iter => ArrowConverters .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId) } + val signal = new Object + val queue = collection.mutable.Queue.empty[(Int, Array[(Array[Byte], Long, Long)])] + val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray - val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => { - if (taskResult.exists(_._1.nonEmpty)) { - // only send non-empty partitions - val task = pool.submit(new Runnable { - override def run(): Unit = { - var batchId = partitionId.toLong << 33 - taskResult.foreach { case (bytes, count, size) => - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.ArrowBatch - .newBuilder() - .setBatchId(batchId) - .setRowCount(count) - .setUncompressedBytes(size) - .setCompressedBytes(bytes.length) - .setData(ByteString.copyFrom(bytes)) - .build() - response.setArrowBatch(batch) - responseObserver.onNext(response.build()) - batchId += 1 - } - } - }) - tasks.synchronized { - tasks.append(task) - } + val resultHandler = (partitionId: Int, partition: Array[(Array[Byte], Long, Long)]) => { + signal.synchronized { + queue.enqueue((partitionId, partition)) + signal.notify() } val i = 0 // Unit } spark.sparkContext.runJob(batches, processPartition, resultHandler) - } - // make sure at least 1 batch will be sent - if (tasks.isEmpty) { - val task = pool.submit(new Runnable { - override def run(): Unit = { - val (bytes, count, size) = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.ArrowBatch - .newBuilder() - .setBatchId(0L) - .setRowCount(count) - .setUncompressedBytes(size) - .setCompressedBytes(bytes.length) - .setData(ByteString.copyFrom(bytes)) - .build() - response.setArrowBatch(batch) - responseObserver.onNext(response.build()) + var numHandled = 0 + while (numHandled < numPartitions) { + val (partitionId, partition) = signal.synchronized { + while (queue.isEmpty) { + signal.wait() + } + queue.dequeue() } - }) - tasks.append(task) + + // only send non-empty partitions + if (partition.exists(_._1.nonEmpty)) { + var batchId = 0 + partition.foreach { case (bytes, count, size) => + val response = proto.Response.newBuilder().setClientId(clientId) + val batch = proto.Response.ArrowBatch + .newBuilder() + .setPartitionId(partitionId) + .setBatchId(batchId) + .setRowCount(count) + .setUncompressedBytes(size) + 
.setCompressedBytes(bytes.length) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) + batchId += 1 + } + numSent += 1 + } + + numHandled += 1 + } } - tasks.foreach(_.get()) - pool.shutdown() + // make sure at least 1 batch will be sent + if (numSent == 0) { + val (bytes, count, size) = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) + val response = proto.Response.newBuilder().setClientId(clientId) + val batch = proto.Response.ArrowBatch + .newBuilder() + .setPartitionId(-1) + .setBatchId(0) + .setRowCount(count) + .setUncompressedBytes(size) + .setCompressedBytes(bytes.length) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) + } responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) responseObserver.onCompleted() diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 09ab4bccb912b..0340dfcb19df3 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -374,14 +374,18 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) return AnalyzeResult.fromProto(resp) - def _process_batch(self, b: pb2.Response) -> Optional[tuple[int, pandas.DataFrame]]: + def _process_batch(self, b: pb2.Response) -> Optional[tuple[int, int, pandas.DataFrame]]: import pandas as pd if b.arrow_batch is not None and len(b.arrow_batch.data) > 0: with pa.ipc.open_stream(b.arrow_batch.data) as rd: - return (b.arrow_batch.batch_id, rd.read_pandas()) + return (b.arrow_batch.partition_id, b.arrow_batch.batch_id, rd.read_pandas()) elif b.json_batch is not None and len(b.json_batch.data) > 0: - return (b.json_batch.batch_id, pd.read_json(io.BytesIO(b.json_batch.data), lines=True)) + return ( + b.json_batch.partition_id, + b.json_batch.batch_id, + pd.read_json(io.BytesIO(b.json_batch.data), lines=True), + ) return None def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]: @@ -399,10 +403,10 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra result_dfs.append(pb) if len(result_dfs) > 0: - # sort by batch id - result_dfs.sort(key=lambda t: t[0]) + # sort by (partition_id, batch_id) + result_dfs.sort(key=lambda t: (t[0], t[1])) # concat the pandas dataframes - df = pd.concat([t[1] for t in result_dfs]) + df = pd.concat([t[2] for t in result_dfs]) del result_dfs # pd.concat generates non-consecutive index like: diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index ee3be242e2bc0..55fc153f974f9 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 
\x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf1\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xb2\x01\n\nArrowBatch\x12\x19\n\x08\x62\x61tch_id\x18\x01 \x01(\x03R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x02 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x03 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x04 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x05 \x01(\x0cR\x04\x64\x61ta\x1aW\n\tJSONBatch\x12\x19\n\x08\x62\x61tch_id\x18\x01 \x01(\x03R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x02 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xb7\x08\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 
\x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xd5\x01\n\nArrowBatch\x12!\n\x0cpartition_id\x18\x01 \x01(\x05R\x0bpartitionId\x12\x19\n\x08\x62\x61tch_id\x18\x02 \x01(\x05R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x03 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x04 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x05 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x06 \x01(\x0cR\x04\x64\x61ta\x1az\n\tJSONBatch\x12!\n\x0cpartition_id\x18\x01 \x01(\x05R\x0bpartitionId\x12\x19\n\x08\x62\x61tch_id\x18\x02 \x01(\x05R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x03 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -53,21 +53,21 @@ _REQUEST_USERCONTEXT._serialized_start = 429 _REQUEST_USERCONTEXT._serialized_end = 551 _RESPONSE._serialized_start = 554 - _RESPONSE._serialized_end = 1563 + _RESPONSE._serialized_end = 1633 _RESPONSE_ARROWBATCH._serialized_start = 794 - _RESPONSE_ARROWBATCH._serialized_end = 972 - _RESPONSE_JSONBATCH._serialized_start = 974 - _RESPONSE_JSONBATCH._serialized_end = 1061 - _RESPONSE_METRICS._serialized_start = 1064 - _RESPONSE_METRICS._serialized_end = 1548 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1148 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1458 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1346 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1458 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1460 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1548 - _ANALYZERESPONSE._serialized_start = 1566 - _ANALYZERESPONSE._serialized_end = 1700 - _SPARKCONNECTSERVICE._serialized_start = 1703 - _SPARKCONNECTSERVICE._serialized_end = 1865 + _RESPONSE_ARROWBATCH._serialized_end = 1007 + _RESPONSE_JSONBATCH._serialized_start = 1009 + _RESPONSE_JSONBATCH._serialized_end = 1131 + _RESPONSE_METRICS._serialized_start = 1134 + _RESPONSE_METRICS._serialized_end = 1618 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1218 + 
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1528 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1416 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1528 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1530 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1618 + _ANALYZERESPONSE._serialized_start = 1636 + _ANALYZERESPONSE._serialized_end = 1770 + _SPARKCONNECTSERVICE._serialized_start = 1773 + _SPARKCONNECTSERVICE._serialized_end = 1935 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index d3bf65758fab5..866e0e04bf11f 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -177,11 +177,13 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + PARTITION_ID_FIELD_NUMBER: builtins.int BATCH_ID_FIELD_NUMBER: builtins.int ROW_COUNT_FIELD_NUMBER: builtins.int UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int COMPRESSED_BYTES_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int + partition_id: builtins.int batch_id: builtins.int row_count: builtins.int uncompressed_bytes: builtins.int @@ -190,6 +192,7 @@ class Response(google.protobuf.message.Message): def __init__( self, *, + partition_id: builtins.int = ..., batch_id: builtins.int = ..., row_count: builtins.int = ..., uncompressed_bytes: builtins.int = ..., @@ -205,6 +208,8 @@ class Response(google.protobuf.message.Message): b"compressed_bytes", "data", b"data", + "partition_id", + b"partition_id", "row_count", b"row_count", "uncompressed_bytes", @@ -220,15 +225,18 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + PARTITION_ID_FIELD_NUMBER: builtins.int BATCH_ID_FIELD_NUMBER: builtins.int ROW_COUNT_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int + partition_id: builtins.int batch_id: builtins.int row_count: builtins.int data: builtins.bytes def __init__( self, *, + partition_id: builtins.int = ..., batch_id: builtins.int = ..., row_count: builtins.int = ..., data: builtins.bytes = ..., @@ -236,7 +244,14 @@ class Response(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ - "batch_id", b"batch_id", "data", b"data", "row_count", b"row_count" + "batch_id", + b"batch_id", + "data", + b"data", + "partition_id", + b"partition_id", + "row_count", + b"row_count", ], ) -> None: ... From 335519bcab88663fc81724fe1f3528ce7d49a14d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 9 Nov 2022 18:58:34 +0800 Subject: [PATCH 3/4] send the batches in order --- .../main/protobuf/spark/connect/base.proto | 16 ++++----- .../service/SparkConnectStreamHandler.scala | 36 ++++++++----------- python/pyspark/sql/connect/client.py | 17 +++------ python/pyspark/sql/connect/proto/base_pb2.py | 34 +++++++++--------- python/pyspark/sql/connect/proto/base_pb2.pyi | 28 +-------------- 5 files changed, 44 insertions(+), 87 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index 740d278a62635..f0e7a102554b4 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -79,22 +79,18 @@ message Response { // Batch results of metrics. 
message ArrowBatch { - int32 partition_id = 1; - int32 batch_id = 2; - int64 row_count = 3; - int64 uncompressed_bytes = 4; - int64 compressed_bytes = 5; - bytes data = 6; + int64 row_count = 1; + int64 uncompressed_bytes = 2; + int64 compressed_bytes = 3; + bytes data = 4; } // Message type when the result is returned as JSON. This is essentially a bulk wrapper // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format // of `{col -> row}`. message JSONBatch { - int32 partition_id = 1; - int32 batch_id = 2; - int64 row_count = 3; - bytes data = 4; + int64 row_count = 1; + bytes data = 2; } message Metrics { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 076c6683277e1..22d23393ec0db 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -68,7 +68,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte def processRowsAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { // Only process up to 10MB of data. val sb = new StringBuilder - var batchId = 0 var rowCount = 0 dataframe.toJSON .collect() @@ -93,15 +92,12 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte val response = proto.Response.newBuilder().setClientId(clientId) val batch = proto.Response.JSONBatch .newBuilder() - .setPartitionId(-1) - .setBatchId(batchId) .setData(ByteString.copyFromUtf8(sb.toString())) .setRowCount(rowCount) .build() response.setJsonBatch(batch) responseObserver.onNext(response.build()) sb.clear() - batchId += 1 sb.append(row) rowCount = 1 } else { @@ -144,19 +140,21 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte var numSent = 0 if (numPartitions > 0) { + type Batch = (Array[Byte], Long, Long) + val batches = rows.mapPartitionsInternal { iter => ArrowConverters .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId) } val signal = new Object - val queue = collection.mutable.Queue.empty[(Int, Array[(Array[Byte], Long, Long)])] + val partitions = Array.fill[Array[Batch]](numPartitions)(null) - val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray + val processPartition = (iter: Iterator[Batch]) => iter.toArray - val resultHandler = (partitionId: Int, partition: Array[(Array[Byte], Long, Long)]) => { + val resultHandler = (partitionId: Int, partition: Array[Batch]) => { signal.synchronized { - queue.enqueue((partitionId, partition)) + partitions(partitionId) = partition signal.notify() } val i = 0 // Unit @@ -164,24 +162,23 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte spark.sparkContext.runJob(batches, processPartition, resultHandler) - var numHandled = 0 - while (numHandled < numPartitions) { - val (partitionId, partition) = signal.synchronized { - while (queue.isEmpty) { + var currentPartitionId = 0 + while (currentPartitionId < numPartitions) { + val partition = signal.synchronized { + while (partitions(currentPartitionId) == null) { signal.wait() } - queue.dequeue() + val partition = partitions(currentPartitionId) + partitions(currentPartitionId) = null + partition } // only send non-empty partitions - if (partition.exists(_._1.nonEmpty)) { - var 
batchId = 0 + if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) { partition.foreach { case (bytes, count, size) => val response = proto.Response.newBuilder().setClientId(clientId) val batch = proto.Response.ArrowBatch .newBuilder() - .setPartitionId(partitionId) - .setBatchId(batchId) .setRowCount(count) .setUncompressedBytes(size) .setCompressedBytes(bytes.length) @@ -189,12 +186,11 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte .build() response.setArrowBatch(batch) responseObserver.onNext(response.build()) - batchId += 1 } numSent += 1 } - numHandled += 1 + currentPartitionId += 1 } } @@ -204,8 +200,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte val response = proto.Response.newBuilder().setClientId(clientId) val batch = proto.Response.ArrowBatch .newBuilder() - .setPartitionId(-1) - .setBatchId(0) .setRowCount(count) .setUncompressedBytes(size) .setCompressedBytes(bytes.length) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 0340dfcb19df3..cc5988e7cf885 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -374,18 +374,14 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) return AnalyzeResult.fromProto(resp) - def _process_batch(self, b: pb2.Response) -> Optional[tuple[int, int, pandas.DataFrame]]: + def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]: import pandas as pd if b.arrow_batch is not None and len(b.arrow_batch.data) > 0: with pa.ipc.open_stream(b.arrow_batch.data) as rd: - return (b.arrow_batch.partition_id, b.arrow_batch.batch_id, rd.read_pandas()) + return rd.read_pandas() elif b.json_batch is not None and len(b.json_batch.data) > 0: - return ( - b.json_batch.partition_id, - b.json_batch.batch_id, - pd.read_json(io.BytesIO(b.json_batch.data), lines=True), - ) + return pd.read_json(io.BytesIO(b.json_batch.data), lines=True) return None def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]: @@ -403,17 +399,14 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra result_dfs.append(pb) if len(result_dfs) > 0: - # sort by (partition_id, batch_id) - result_dfs.sort(key=lambda t: (t[0], t[1])) - # concat the pandas dataframes - df = pd.concat([t[2] for t in result_dfs]) + df = pd.concat(result_dfs) del result_dfs # pd.concat generates non-consecutive index like: # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64') # set it to RangeIndex to be consistent with pyspark n = len(df) - df = df.set_index(pd.RangeIndex(start=0, stop=n, step=1)) + df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True) # Attach the metrics to the DataFrame attributes. 
if m is not None: diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 55fc153f974f9..dbe93cbf87f91 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xb7\x08\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xd5\x01\n\nArrowBatch\x12!\n\x0cpartition_id\x18\x01 \x01(\x05R\x0bpartitionId\x12\x19\n\x08\x62\x61tch_id\x18\x02 \x01(\x05R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x03 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x04 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x05 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x06 \x01(\x0cR\x04\x64\x61ta\x1az\n\tJSONBatch\x12!\n\x0cpartition_id\x18\x01 \x01(\x05R\x0bpartitionId\x12\x19\n\x08\x62\x61tch_id\x18\x02 \x01(\x05R\x07\x62\x61tchId\x12\x1b\n\trow_count\x18\x03 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + 
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xbb\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\x97\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -53,21 +53,21 @@ _REQUEST_USERCONTEXT._serialized_start = 429 _REQUEST_USERCONTEXT._serialized_end = 551 _RESPONSE._serialized_start = 554 - _RESPONSE._serialized_end = 1633 + _RESPONSE._serialized_end = 1509 _RESPONSE_ARROWBATCH._serialized_start = 794 - _RESPONSE_ARROWBATCH._serialized_end = 1007 - _RESPONSE_JSONBATCH._serialized_start = 1009 - _RESPONSE_JSONBATCH._serialized_end = 1131 - _RESPONSE_METRICS._serialized_start = 1134 - _RESPONSE_METRICS._serialized_end = 1618 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1218 - 
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1528 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1416 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1528 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1530 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1618 - _ANALYZERESPONSE._serialized_start = 1636 - _ANALYZERESPONSE._serialized_end = 1770 - _SPARKCONNECTSERVICE._serialized_start = 1773 - _SPARKCONNECTSERVICE._serialized_end = 1935 + _RESPONSE_ARROWBATCH._serialized_end = 945 + _RESPONSE_JSONBATCH._serialized_start = 947 + _RESPONSE_JSONBATCH._serialized_end = 1007 + _RESPONSE_METRICS._serialized_start = 1010 + _RESPONSE_METRICS._serialized_end = 1494 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1094 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1404 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1292 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1404 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1406 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1494 + _ANALYZERESPONSE._serialized_start = 1512 + _ANALYZERESPONSE._serialized_end = 1646 + _SPARKCONNECTSERVICE._serialized_start = 1649 + _SPARKCONNECTSERVICE._serialized_end = 1811 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 866e0e04bf11f..97ac0dfe8b905 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -177,14 +177,10 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - PARTITION_ID_FIELD_NUMBER: builtins.int - BATCH_ID_FIELD_NUMBER: builtins.int ROW_COUNT_FIELD_NUMBER: builtins.int UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int COMPRESSED_BYTES_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int - partition_id: builtins.int - batch_id: builtins.int row_count: builtins.int uncompressed_bytes: builtins.int compressed_bytes: builtins.int @@ -192,8 +188,6 @@ class Response(google.protobuf.message.Message): def __init__( self, *, - partition_id: builtins.int = ..., - batch_id: builtins.int = ..., row_count: builtins.int = ..., uncompressed_bytes: builtins.int = ..., compressed_bytes: builtins.int = ..., @@ -202,14 +196,10 @@ class Response(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ - "batch_id", - b"batch_id", "compressed_bytes", b"compressed_bytes", "data", b"data", - "partition_id", - b"partition_id", "row_count", b"row_count", "uncompressed_bytes", @@ -225,34 +215,18 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - PARTITION_ID_FIELD_NUMBER: builtins.int - BATCH_ID_FIELD_NUMBER: builtins.int ROW_COUNT_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int - partition_id: builtins.int - batch_id: builtins.int row_count: builtins.int data: builtins.bytes def __init__( self, *, - partition_id: builtins.int = ..., - batch_id: builtins.int = ..., row_count: builtins.int = ..., data: builtins.bytes = ..., ) -> None: ... def ClearField( - self, - field_name: typing_extensions.Literal[ - "batch_id", - b"batch_id", - "data", - b"data", - "partition_id", - b"partition_id", - "row_count", - b"row_count", - ], + self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] ) -> None: ... 
class Metrics(google.protobuf.message.Message): From e53d5383ee66b770fb0703dd02a6e32ba143dc20 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 10 Nov 2022 18:59:43 +0800 Subject: [PATCH 4/4] address comments --- .../main/protobuf/spark/connect/base.proto | 4 +- .../service/SparkConnectStreamHandler.scala | 67 +++++++++---------- python/pyspark/sql/connect/client.py | 1 - python/pyspark/sql/connect/proto/base_pb2.py | 36 +++++----- python/pyspark/sql/connect/proto/base_pb2.pyi | 18 +---- .../sql/tests/connect/test_connect_basic.py | 1 + .../sql/execution/arrow/ArrowConverters.scala | 19 ++---- 7 files changed, 58 insertions(+), 88 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index f0e7a102554b4..5f59ada38b658 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -80,9 +80,7 @@ message Response { // Batch results of metrics. message ArrowBatch { int64 row_count = 1; - int64 uncompressed_bytes = 2; - int64 compressed_bytes = 3; - bytes data = 4; + bytes data = 2; } // Message type when the result is returned as JSON. This is essentially a bulk wrapper diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 22d23393ec0db..3b734616b2138 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.service import scala.collection.JavaConverters._ -import scala.util.Try import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -33,7 +32,6 @@ import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.util.ArrowUtils class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { @@ -55,17 +53,16 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(request.getPlan.getRoot, session) val dataframe = Dataset.ofRows(session, planner.transform()) - // check whether all data types are supported - if (Try { - ArrowUtils.toArrowSchema(dataframe.schema, session.sessionState.conf.sessionLocalTimeZone) - }.isSuccess) { - processRowsAsArrowBatches(request.getClientId, dataframe) - } else { - processRowsAsJsonBatches(request.getClientId, dataframe) + try { + processAsArrowBatches(request.getClientId, dataframe) + } catch { + case e: Exception => + logWarning(e.getMessage) + processAsJsonBatches(request.getClientId, dataframe) } } - def processRowsAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { + def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { // Only process up to 10MB of data. 
val sb = new StringBuilder var rowCount = 0 @@ -127,7 +124,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte responseObserver.onCompleted() } - def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { + def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { val spark = dataframe.sparkSession val schema = dataframe.schema // TODO: control the batch size instead of max records @@ -140,7 +137,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte var numSent = 0 if (numPartitions > 0) { - type Batch = (Array[Byte], Long, Long) + type Batch = (Array[Byte], Long) val batches = rows.mapPartitionsInternal { iter => ArrowConverters @@ -148,45 +145,43 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } val signal = new Object - val partitions = Array.fill[Array[Batch]](numPartitions)(null) + val partitions = collection.mutable.Map.empty[Int, Array[Batch]] val processPartition = (iter: Iterator[Batch]) => iter.toArray + // This callback is executed by the DAGScheduler thread. + // After fetching a partition, it inserts the partition into the Map, and then + // wakes up the main thread. val resultHandler = (partitionId: Int, partition: Array[Batch]) => { signal.synchronized { partitions(partitionId) = partition signal.notify() } - val i = 0 // Unit + () } spark.sparkContext.runJob(batches, processPartition, resultHandler) + // The main thread will wait until the 0-th partition is available, + // then send it to the client and wait for the next partition. var currentPartitionId = 0 while (currentPartitionId < numPartitions) { val partition = signal.synchronized { - while (partitions(currentPartitionId) == null) { + while (!partitions.contains(currentPartitionId)) { signal.wait() } - val partition = partitions(currentPartitionId) - partitions(currentPartitionId) = null - partition + partitions.remove(currentPartitionId).get } - // only send non-empty partitions - if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) { - partition.foreach { case (bytes, count, size) => - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.ArrowBatch - .newBuilder() - .setRowCount(count) - .setUncompressedBytes(size) - .setCompressedBytes(bytes.length) - .setData(ByteString.copyFrom(bytes)) - .build() - response.setArrowBatch(batch) - responseObserver.onNext(response.build()) - } + partition.foreach { case (bytes, count) => + val response = proto.Response.newBuilder().setClientId(clientId) + val batch = proto.Response.ArrowBatch + .newBuilder() + .setRowCount(count) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) numSent += 1 } @@ -194,15 +189,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } } - // make sure at least 1 batch will be sent + // Make sure at least 1 batch will be sent.
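// A minimal, self-contained sketch of the ordered-delivery pattern implemented above, not
// part of this patch: producer threads stand in for the DAGScheduler result handler and
// finish "partitions" in arbitrary order, while the consumer loop (the main thread) still
// emits them strictly by partition id using the same map + wait/notify handshake. The
// object name and payload strings are invented for illustration only.
object OrderedDeliverySketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val signal = new Object
    val partitions = collection.mutable.Map.empty[Int, String]

    // Producers complete in reverse order, with small random delays, and wake the consumer.
    (0 until numPartitions).reverse.foreach { partitionId =>
      new Thread(() => {
        Thread.sleep(scala.util.Random.nextInt(50).toLong)
        signal.synchronized {
          partitions(partitionId) = s"payload-$partitionId"
          signal.notify()
        }
      }).start()
    }

    // The consumer waits for the 0-th partition, emits it, then waits for the next one.
    var currentPartitionId = 0
    while (currentPartitionId < numPartitions) {
      val payload = signal.synchronized {
        while (!partitions.contains(currentPartitionId)) {
          signal.wait()
        }
        partitions.remove(currentPartitionId).get
      }
      println(s"sending $payload") // in the real handler this is responseObserver.onNext(...)
      currentPartitionId += 1
    }
  }
}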
if (numSent == 0) { - val (bytes, count, size) = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) + val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) val response = proto.Response.newBuilder().setClientId(clientId) val batch = proto.Response.ArrowBatch .newBuilder() - .setRowCount(count) - .setUncompressedBytes(size) - .setCompressedBytes(bytes.length) + .setRowCount(0L) .setData(ByteString.copyFrom(bytes)) .build() response.setArrowBatch(batch) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index cc5988e7cf885..27075ff3cb027 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -400,7 +400,6 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra if len(result_dfs) > 0: df = pd.concat(result_dfs) - del result_dfs # pd.concat generates non-consecutive index like: # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64') diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index dbe93cbf87f91..1f577089d1a29 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xbb\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\x97\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 
\x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 
\x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -53,21 +53,21 @@ _REQUEST_USERCONTEXT._serialized_start = 429 _REQUEST_USERCONTEXT._serialized_end = 551 _RESPONSE._serialized_start = 554 - _RESPONSE._serialized_end = 1509 - _RESPONSE_ARROWBATCH._serialized_start = 794 - _RESPONSE_ARROWBATCH._serialized_end = 945 - _RESPONSE_JSONBATCH._serialized_start = 947 - _RESPONSE_JSONBATCH._serialized_end = 1007 - _RESPONSE_METRICS._serialized_start = 1010 - _RESPONSE_METRICS._serialized_end = 1494 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1094 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1404 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1292 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1404 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1406 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1494 - _ANALYZERESPONSE._serialized_start = 1512 - _ANALYZERESPONSE._serialized_end = 1646 - _SPARKCONNECTSERVICE._serialized_start = 1649 - _SPARKCONNECTSERVICE._serialized_end = 1811 + _RESPONSE._serialized_end = 1418 + _RESPONSE_ARROWBATCH._serialized_start = 793 + _RESPONSE_ARROWBATCH._serialized_end = 854 + _RESPONSE_JSONBATCH._serialized_start = 856 + _RESPONSE_JSONBATCH._serialized_end = 916 + _RESPONSE_METRICS._serialized_start = 919 + _RESPONSE_METRICS._serialized_end = 1403 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1003 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1313 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1201 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1313 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1315 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1403 + _ANALYZERESPONSE._serialized_start = 1421 + _ANALYZERESPONSE._serialized_end = 1555 + _SPARKCONNECTSERVICE._serialized_start = 1558 + _SPARKCONNECTSERVICE._serialized_end = 1720 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 97ac0dfe8b905..bf6d080d9fd97 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -178,33 +178,17 @@ class Response(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor ROW_COUNT_FIELD_NUMBER: builtins.int - UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int - COMPRESSED_BYTES_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int row_count: builtins.int - uncompressed_bytes: builtins.int - compressed_bytes: builtins.int data: builtins.bytes def __init__( self, *, row_count: builtins.int = ..., - uncompressed_bytes: builtins.int = ..., - compressed_bytes: builtins.int = ..., data: builtins.bytes = ..., ) -> None: ... def ClearField( - self, - field_name: typing_extensions.Literal[ - "compressed_bytes", - b"compressed_bytes", - "data", - b"data", - "row_count", - b"row_count", - "uncompressed_bytes", - b"uncompressed_bytes", - ], + self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] ) -> None: ... 
class JSONBatch(google.protobuf.message.Message): diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 866d821215b40..38c244bd74bfa 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -198,6 +198,7 @@ def test_range(self): ) def test_empty_dataset(self): + # SPARK-41005: Test arrow based collection with empty dataset. self.assertTrue( self.connect.sql("SELECT 1 AS X LIMIT 0") .toPandas() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 896e80439bf11..a2dce31bc6d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** @@ -132,7 +132,7 @@ private[sql] object ArrowConverters extends Logging { rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, - timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = { + timeZoneId: String): Iterator[(Array[Byte], Long)] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( "toArrowBatchIterator", 0, Long.MaxValue) @@ -148,7 +148,7 @@ private[sql] object ArrowConverters extends Logging { } } - new Iterator[(Array[Byte], Long, Long)] { + new Iterator[(Array[Byte], Long)] { override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -156,19 +156,16 @@ private[sql] object ArrowConverters extends Logging { false } - override def next(): (Array[Byte], Long, Long) = { + override def next(): (Array[Byte], Long) = { val out = new ByteArrayOutputStream() val writeChannel = new WriteChannel(Channels.newChannel(out)) var rowCount = 0L - var estimatedSize = SizeEstimator.estimate(arrowSchema) + - SizeEstimator.estimate(IpcOption.DEFAULT) Utils.tryWithSafeFinally { while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { val row = rowIter.next() arrowWriter.write(row) rowCount += 1 - estimatedSize += SizeEstimator.estimate(row) } arrowWriter.finish() val batch = unloader.getRecordBatch() @@ -182,14 +179,14 @@ private[sql] object ArrowConverters extends Logging { arrowWriter.reset() } - (out.toByteArray, rowCount, estimatedSize) + (out.toByteArray, rowCount) } } } private[sql] def createEmptyArrowBatch( schema: StructType, - timeZoneId: String): (Array[Byte], Long, Long) = { + timeZoneId: String): Array[Byte] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( "createEmptyArrowBatch", 0, Long.MaxValue) @@ -201,8 +198,6 @@ private[sql] object ArrowConverters extends Logging { val out = new ByteArrayOutputStream() val writeChannel = new WriteChannel(Channels.newChannel(out)) - val estimatedSize = SizeEstimator.estimate(arrowSchema) + - SizeEstimator.estimate(IpcOption.DEFAULT) Utils.tryWithSafeFinally { arrowWriter.finish() val batch = 
unloader.getRecordBatch() // empty batch @@ -216,7 +211,7 @@ private[sql] object ArrowConverters extends Logging { arrowWriter.reset() } - (out.toByteArray, 0L, estimatedSize) + out.toByteArray } /**