diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index dff1734335e22..b376515bf1af0 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -22,6 +22,7 @@ package spark.connect; import "google/protobuf/any.proto"; import "spark/connect/commands.proto"; import "spark/connect/relations.proto"; +import "spark/connect/types.proto"; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; @@ -116,11 +117,10 @@ message Response { // reason about the performance. message AnalyzeResponse { string client_id = 1; - repeated string column_names = 2; - repeated string column_types = 3; + DataType schema = 2; // The extended explain string as produced by Spark. - string explain_string = 4; + string explain_string = 3; } // Main interface for the SparkConnect service. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala index da3adce43ba98..0ee90b5e8fbbb 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._ import org.apache.spark.connect.proto import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType} /** * This object offers methods to convert to/from connect proto to catalyst types. 
@@ -50,11 +50,28 @@ object DataTypeProtoConverter { proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build() case StringType => proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build() + case LongType => + proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build() + case struct: StructType => + toConnectProtoStructType(struct) case _ => throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") } } + def toConnectProtoStructType(schema: StructType): proto.DataType = { + val struct = proto.DataType.Struct.newBuilder() + for (structField <- schema.fields) { + struct.addFields( + proto.DataType.StructField + .newBuilder() + .setName(structField.name) + .setType(toConnectProtoType(structField.dataType)) + .setNullable(structField.nullable)) + } + proto.DataType.newBuilder().setStruct(struct).build() + } + def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = { mode match { case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 20776a29edab4..5841017e5bb71 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder import io.grpc.{Server, Status} @@ -35,7 +33,7 @@ import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, Spark import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT -import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} import org.apache.spark.sql.execution.ExtendedMode /** @@ -89,29 +87,16 @@ class SparkConnectService(debug: Boolean) request: Request, responseObserver: StreamObserver[AnalyzeResponse]): Unit = { try { + if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) { + responseObserver.onError( + new UnsupportedOperationException( + s"${request.getPlan.getOpTypeCase} not supported for analysis.")) + } val session = SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session - - val logicalPlan = request.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.ROOT => - new SparkConnectPlanner(request.getPlan.getRoot, session).transform() - case _ => - responseObserver.onError( - new UnsupportedOperationException( - s"${request.getPlan.getOpTypeCase} not supported for analysis.")) - return - } - val ds = Dataset.ofRows(session, logicalPlan) - val explainString = ds.queryExecution.explainString(ExtendedMode) - - val resp = proto.AnalyzeResponse - .newBuilder() - .setExplainString(explainString) - .setClientId(request.getClientId) - - resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava) - resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava) - responseObserver.onNext(resp.build()) + val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session) + 
response.setClientId(request.getClientId) + responseObserver.onNext(response.build()) responseObserver.onCompleted() } catch { case e: Throwable => @@ -120,6 +105,20 @@ class SparkConnectService(debug: Boolean) Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) } } + + def handleAnalyzePlanRequest( + relation: proto.Relation, + session: SparkSession): proto.AnalyzeResponse.Builder = { + val logicalPlan = new SparkConnectPlanner(relation, session).transform() + + val ds = Dataset.ofRows(session, logicalPlan) + val explainString = ds.queryExecution.explainString(ExtendedMode) + + val response = proto.AnalyzeResponse + .newBuilder() + .setExplainString(explainString) + response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema)) + } } /** diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala new file mode 100644 index 0000000000000..4be8d1705b9ed --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.planner + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Testing Connect Service implementation. 
+ */ +class SparkConnectServiceSuite extends SharedSparkSession { + + test("Test schema in analyze response") { + withTable("test") { + spark.sql(""" + | CREATE TABLE test (col1 INT, col2 STRING) + | USING parquet + |""".stripMargin) + + val instance = new SparkConnectService(false) + val relation = proto.Relation + .newBuilder() + .setRead( + proto.Read + .newBuilder() + .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build()) + .build()) + .build() + + val response = instance.handleAnalyzePlanRequest(relation, spark) + + assert(response.getSchema.hasStruct) + val schema = response.getSchema.getStruct + assert(schema.getFieldsCount == 2) + assert( + schema.getFields(0).getName == "col1" + && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32) + assert( + schema.getFields(1).getName == "col2" + && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING) + } + } +} diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 0ae075521c63f..f4b6d2ec302d9 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -33,6 +33,7 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.plan import SQL +from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType from typing import Optional, Any, Union @@ -91,14 +92,13 @@ def metrics(self) -> typing.List[MetricValue]: class AnalyzeResult: - def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str): - self.cols = cols - self.types = types + def __init__(self, schema: pb2.DataType, explain: str): + self.schema = schema self.explain_string = explain @classmethod def fromProto(cls, pb: typing.Any) -> "AnalyzeResult": - return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string) + return AnalyzeResult(pb.schema, pb.explain_string) class RemoteSparkSession(object): @@ -151,7 +151,44 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]: req.plan.CopyFrom(plan) return self._execute_and_fetch(req) - def analyze(self, plan: pb2.Plan) -> AnalyzeResult: + def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType: + if schema.HasField("struct"): + structFields = [] + for proto_field in schema.struct.fields: + structFields.append( + StructField( + proto_field.name, + self._proto_schema_to_pyspark_schema(proto_field.type), + proto_field.nullable, + ) + ) + return StructType(structFields) + elif schema.HasField("i64"): + return LongType() + elif schema.HasField("string"): + return StringType() + else: + raise Exception("Only support long, string, struct conversion") + + def schema(self, plan: pb2.Plan) -> StructType: + proto_schema = self._analyze(plan).schema + # Server side should populate the struct field which is the schema. 
+ assert proto_schema.HasField("struct") + structFields = [] + for proto_field in proto_schema.struct.fields: + structFields.append( + StructField( + proto_field.name, + self._proto_schema_to_pyspark_schema(proto_field.type), + proto_field.nullable, + ) + ) + return StructType(structFields) + + def explain_string(self, plan: pb2.Plan) -> str: + return self._analyze(plan).explain_string + + def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: req = pb2.Request() req.user_context.user_id = self._user_id req.plan.CopyFrom(plan) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 2b7e3d520391d..bf9ed83615b69 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -34,6 +34,7 @@ Expression, LiteralExpression, ) +from pyspark.sql.types import StructType if TYPE_CHECKING: from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString @@ -96,7 +97,7 @@ class DataFrame(object): of the DataFrame with the changes applied. """ - def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None): + def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None): """Creates a new data frame""" self._schema = schema self._plan: Optional[plan.LogicalPlan] = None @@ -315,11 +316,32 @@ def toPandas(self) -> Optional["pandas.DataFrame"]: query = self._plan.to_proto(self._session) return self._session._to_pandas(query) + def schema(self) -> StructType: + """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`StructType` + """ + if self._schema is None: + if self._plan is not None: + query = self._plan.to_proto(self._session) + if self._session is None: + raise Exception("Cannot analyze without RemoteSparkSession.") + self._schema = self._session.schema(query) + return self._schema + else: + raise Exception("Empty plan.") + else: + return self._schema + def explain(self) -> str: if self._plan is not None: query = self._plan.to_proto(self._session) if self._session is None: raise Exception("Cannot analyze without RemoteSparkSession.") - return self._session.analyze(query).explain_string + return self._session.explain_string(query) else: return "" diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 408872dbb66e6..eb9ecc9157f2c 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -31,10 +31,11 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2 +from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 
\x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 
\x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -45,28 +46,28 @@ DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" - _PLAN._serialized_start = 131 - _PLAN._serialized_end = 247 - _REQUEST._serialized_start = 250 - _REQUEST._serialized_end = 524 - _REQUEST_USERCONTEXT._serialized_start = 402 - _REQUEST_USERCONTEXT._serialized_end = 524 - _RESPONSE._serialized_start = 527 - _RESPONSE._serialized_end = 1495 - _RESPONSE_ARROWBATCH._serialized_start = 756 - _RESPONSE_ARROWBATCH._serialized_end = 931 - _RESPONSE_JSONBATCH._serialized_start = 933 - _RESPONSE_JSONBATCH._serialized_end = 993 - _RESPONSE_METRICS._serialized_start = 996 - _RESPONSE_METRICS._serialized_end = 1480 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1392 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1480 - _ANALYZERESPONSE._serialized_start = 1498 - _ANALYZERESPONSE._serialized_end = 1653 - _SPARKCONNECTSERVICE._serialized_start = 1656 - _SPARKCONNECTSERVICE._serialized_end = 1818 + _PLAN._serialized_start = 158 + _PLAN._serialized_end = 274 + _REQUEST._serialized_start = 277 + _REQUEST._serialized_end = 551 + _REQUEST_USERCONTEXT._serialized_start = 
429 + _REQUEST_USERCONTEXT._serialized_end = 551 + _RESPONSE._serialized_start = 554 + _RESPONSE._serialized_end = 1522 + _RESPONSE_ARROWBATCH._serialized_start = 783 + _RESPONSE_ARROWBATCH._serialized_end = 958 + _RESPONSE_JSONBATCH._serialized_start = 960 + _RESPONSE_JSONBATCH._serialized_end = 1020 + _RESPONSE_METRICS._serialized_start = 1023 + _RESPONSE_METRICS._serialized_end = 1507 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507 + _ANALYZERESPONSE._serialized_start = 1525 + _ANALYZERESPONSE._serialized_end = 1659 + _SPARKCONNECTSERVICE._serialized_start = 1662 + _SPARKCONNECTSERVICE._serialized_end = 1824 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index bb3a6578cf711..5ffd7701b440d 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -41,6 +41,7 @@ import google.protobuf.internal.containers import google.protobuf.message import pyspark.sql.connect.proto.commands_pb2 import pyspark.sql.connect.proto.relations_pb2 +import pyspark.sql.connect.proto.types_pb2 import sys if sys.version_info >= (3, 8): @@ -401,39 +402,27 @@ class AnalyzeResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor CLIENT_ID_FIELD_NUMBER: builtins.int - COLUMN_NAMES_FIELD_NUMBER: builtins.int - COLUMN_TYPES_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int EXPLAIN_STRING_FIELD_NUMBER: builtins.int client_id: builtins.str @property - def column_names( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - @property - def column_types( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... explain_string: builtins.str """The extended explain string as produced by Spark.""" def __init__( self, *, client_id: builtins.str = ..., - column_names: collections.abc.Iterable[builtins.str] | None = ..., - column_types: collections.abc.Iterable[builtins.str] | None = ..., + schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., explain_string: builtins.str = ..., ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["schema", b"schema"] + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "client_id", - b"client_id", - "column_names", - b"column_names", - "column_types", - b"column_types", - "explain_string", - b"explain_string", + "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema" ], ) -> None: ... 
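For reference, a minimal sketch of how the new `schema` field is meant to round-trip, independent of a running server: it hand-builds the `spark.connect.DataType` proto that `toConnectProtoStructType` would produce for a table with columns `(id BIGINT, name STRING)` and converts it back to a PySpark `StructType` using the same long/string/struct recursion as `_proto_schema_to_pyspark_schema` above. The `proto_to_pyspark` helper is a standalone illustrative copy written for this sketch, not the client API itself.

```python
# Illustrative sketch only: build the proto DataType that AnalyzePlan would
# return for columns (id BIGINT, name STRING) and convert it back to a
# pyspark StructType, mirroring _proto_schema_to_pyspark_schema in client.py.
from pyspark.sql.connect.proto import types_pb2 as pb2_types
from pyspark.sql.types import DataType, LongType, StringType, StructField, StructType


def proto_to_pyspark(schema: pb2_types.DataType) -> DataType:
    # Same long/string/struct subset the client currently supports.
    if schema.HasField("struct"):
        return StructType(
            [
                StructField(f.name, proto_to_pyspark(f.type), f.nullable)
                for f in schema.struct.fields
            ]
        )
    elif schema.HasField("i64"):
        return LongType()
    elif schema.HasField("string"):
        return StringType()
    raise ValueError("Only long, string and struct are handled in this sketch.")


proto_schema = pb2_types.DataType()

id_field = proto_schema.struct.fields.add()
id_field.name = "id"
id_field.type.i64.SetInParent()  # marks the i64 kind as set, like setI64(getDefaultInstance) in Scala
id_field.nullable = True

name_field = proto_schema.struct.fields.add()
name_field.name = "name"
name_field.type.string.SetInParent()
name_field.nullable = True

assert proto_to_pyspark(proto_schema) == StructType(
    [StructField("id", LongType(), True), StructField("name", StringType(), True)]
)
```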
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f6988a1d1200d..459b05cc37aad 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -22,6 +22,7 @@ import pandas from pyspark.sql import SparkSession, Row +from pyspark.sql.types import StructType, StructField, LongType, StringType from pyspark.sql.connect.client import RemoteSparkSession from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit @@ -97,6 +98,15 @@ def test_simple_explain_string(self): result = df.explain() self.assertGreater(len(result), 0) + def test_schema(self): + schema = self.connect.read.table(self.tbl_name).schema() + self.assertEqual( + StructType( + [StructField("id", LongType(), True), StructField("name", StringType(), True)] + ), + schema, + ) + def test_simple_binary_expressions(self): """Test complex expression""" df = self.connect.read.table(self.tbl_name)
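For completeness, a hedged usage sketch of the client surface exercised by the test above: `DataFrame.schema()` now issues an AnalyzePlan RPC and converts the returned proto `DataType` into a `pyspark.sql.types.StructType`, while `explain()` is served from the `explain_string` field (tag 3 after this change). The `RemoteSparkSession` constructor arguments and the table name below are assumptions for illustration; a reachable Spark Connect server and an existing table are required.

```python
# Usage sketch (assumes a Spark Connect server is reachable and a table named
# "test_table" with columns (id BIGINT, name STRING) exists; the constructor
# arguments are illustrative, not an exact API contract).
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.types import StructType

spark = RemoteSparkSession(user_id="test_user")
df = spark.read.table("test_table")

# schema() performs the AnalyzePlan round trip once and caches the result on
# the DataFrame; subsequent calls reuse the cached StructType.
schema = df.schema()
assert isinstance(schema, StructType)
print([f.name for f in schema.fields])

# explain() now goes through RemoteSparkSession.explain_string().
print(df.explain())
```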