6 changes: 3 additions & 3 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -22,6 +22,7 @@ package spark.connect;
import "google/protobuf/any.proto";
import "spark/connect/commands.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
@@ -116,11 +117,10 @@ message Response {
// reason about the performance.
message AnalyzeResponse {
string client_id = 1;
repeated string column_names = 2;
repeated string column_types = 3;
DataType schema = 2;

// The extended explain string as produced by Spark.
string explain_string = 4;
string explain_string = 3;
}

// Main interface for the SparkConnect service.
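For reference, a minimal sketch (not part of this diff) of what the reshaped AnalyzeResponse carries for a two-column table like the `col1 INT, col2 STRING` case used in the test suite below. It assumes the generated protos are importable as `pyspark.sql.connect.proto`; adjust that path to your build.

```python
# Sketch only: hand-builds a new-style AnalyzeResponse to show its shape.
import pyspark.sql.connect.proto as pb2  # assumed location of the generated protos

resp = pb2.AnalyzeResponse(
    client_id="client-1",
    schema=pb2.DataType(
        struct=pb2.DataType.Struct(
            fields=[
                pb2.DataType.StructField(
                    name="col1", type=pb2.DataType(i32=pb2.DataType.I32()), nullable=True
                ),
                pb2.DataType.StructField(
                    name="col2", type=pb2.DataType(string=pb2.DataType.String()), nullable=True
                ),
            ]
        )
    ),
    explain_string="== Parsed Logical Plan ==\n...",
)

# The schema now arrives as one structured DataType instead of two parallel
# string lists (column_names / column_types).
print(resp.schema.struct.fields[0].name)  # col1
```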
connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._

import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}

/**
* This object offers methods to convert to/from connect proto to catalyst types.
@@ -50,11 +50,28 @@ object DataTypeProtoConverter {
proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
case StringType =>
proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
case LongType =>
Reviewer comment (Contributor): nit: for readability, move this to the other int type?

proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
case struct: StructType =>
toConnectProtoStructType(struct)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
}

def toConnectProtoStructType(schema: StructType): proto.DataType = {
val struct = proto.DataType.Struct.newBuilder()
for (structField <- schema.fields) {
struct.addFields(
proto.DataType.StructField
.newBuilder()
.setName(structField.name)
.setType(toConnectProtoType(structField.dataType))
.setNullable(structField.nullable))
}
proto.DataType.newBuilder().setStruct(struct).build()
}

def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
mode match {
case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.connect.service

import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
import io.grpc.{Server, Status}
@@ -35,7 +33,7 @@ import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, Spark
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
import org.apache.spark.sql.execution.ExtendedMode

/**
@@ -89,29 +87,16 @@ class SparkConnectService(debug: Boolean)
request: Request,
responseObserver: StreamObserver[AnalyzeResponse]): Unit = {
try {
if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) {
responseObserver.onError(
new UnsupportedOperationException(
s"${request.getPlan.getOpTypeCase} not supported for analysis."))
}
val session =
SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session

val logicalPlan = request.getPlan.getOpTypeCase match {
case proto.Plan.OpTypeCase.ROOT =>
new SparkConnectPlanner(request.getPlan.getRoot, session).transform()
case _ =>
responseObserver.onError(
new UnsupportedOperationException(
s"${request.getPlan.getOpTypeCase} not supported for analysis."))
return
}
val ds = Dataset.ofRows(session, logicalPlan)
val explainString = ds.queryExecution.explainString(ExtendedMode)

val resp = proto.AnalyzeResponse
.newBuilder()
.setExplainString(explainString)
.setClientId(request.getClientId)

resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava)
resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava)
responseObserver.onNext(resp.build())
val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session)
response.setClientId(request.getClientId)
responseObserver.onNext(response.build())
responseObserver.onCompleted()
} catch {
case e: Throwable =>
@@ -120,6 +105,20 @@
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
}
}

def handleAnalyzePlanRequest(
relation: proto.Relation,
session: SparkSession): proto.AnalyzeResponse.Builder = {
val logicalPlan = new SparkConnectPlanner(relation, session).transform()

val ds = Dataset.ofRows(session, logicalPlan)
val explainString = ds.queryExecution.explainString(ExtendedMode)

val response = proto.AnalyzeResponse
.newBuilder()
.setExplainString(explainString)
response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema))
}
}

/**
connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.test.SharedSparkSession

/**
* Testing Connect Service implementation.
*/
class SparkConnectServiceSuite extends SharedSparkSession {

test("Test schema in analyze response") {
withTable("test") {
spark.sql("""
| CREATE TABLE test (col1 INT, col2 STRING)
| USING parquet
|""".stripMargin)

val instance = new SparkConnectService(false)
val relation = proto.Relation
.newBuilder()
.setRead(
proto.Read
.newBuilder()
.setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build())
.build())
.build()

val response = instance.handleAnalyzePlanRequest(relation, spark)

assert(response.getSchema.hasStruct)
val schema = response.getSchema.getStruct
assert(schema.getFieldsCount == 2)
assert(
schema.getFields(0).getName == "col1"
&& schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32)
assert(
schema.getFields(1).getName == "col2"
&& schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
}
}
}
47 changes: 42 additions & 5 deletions python/pyspark/sql/connect/client.py
@@ -33,6 +33,7 @@
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.plan import SQL
from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType

from typing import Optional, Any, Union

@@ -91,14 +92,13 @@ def metrics(self) -> typing.List[MetricValue]:


class AnalyzeResult:
def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str):
self.cols = cols
self.types = types
def __init__(self, schema: pb2.DataType, explain: str):
self.schema = schema
self.explain_string = explain

@classmethod
def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string)
return AnalyzeResult(pb.schema, pb.explain_string)


class RemoteSparkSession(object):
@@ -151,7 +151,44 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
req.plan.CopyFrom(plan)
return self._execute_and_fetch(req)

def analyze(self, plan: pb2.Plan) -> AnalyzeResult:
def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
if schema.HasField("struct"):
structFields = []
for proto_field in schema.struct.fields:
structFields.append(
StructField(
proto_field.name,
self._proto_schema_to_pyspark_schema(proto_field.type),
proto_field.nullable,
)
)
return StructType(structFields)
elif schema.HasField("i64"):
return LongType()
elif schema.HasField("string"):
return StringType()
else:
raise Exception("Only support long, string, struct conversion")

def schema(self, plan: pb2.Plan) -> StructType:
proto_schema = self._analyze(plan).schema
# Server side should populate the struct field which is the schema.
assert proto_schema.HasField("struct")
structFields = []
for proto_field in proto_schema.struct.fields:
structFields.append(
StructField(
proto_field.name,
self._proto_schema_to_pyspark_schema(proto_field.type),
proto_field.nullable,
)
)
return StructType(structFields)

def explain_string(self, plan: pb2.Plan) -> str:
return self._analyze(plan).explain_string

def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
req = pb2.Request()
req.user_context.user_id = self._user_id
req.plan.CopyFrom(plan)
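A standalone sketch (not part of this diff) of the proto-to-PySpark conversion that `_proto_schema_to_pyspark_schema` performs above, written as a free function so it can be run without opening a connection. The generated-proto import path is an assumption.

```python
# Sketch only: mirrors the struct/i64/string handling added to client.py above.
import pyspark.sql.connect.proto as pb2  # assumed location of the generated protos
from pyspark.sql.types import DataType, LongType, StringType, StructField, StructType


def proto_to_pyspark(dt: pb2.DataType) -> DataType:
    # Structs recurse into their fields; scalar kinds map to PySpark types.
    if dt.HasField("struct"):
        return StructType(
            [
                StructField(f.name, proto_to_pyspark(f.type), f.nullable)
                for f in dt.struct.fields
            ]
        )
    elif dt.HasField("i64"):
        return LongType()
    elif dt.HasField("string"):
        return StringType()
    raise Exception("Only long, string and struct are handled in this sketch")


schema = pb2.DataType(
    struct=pb2.DataType.Struct(
        fields=[
            pb2.DataType.StructField(
                name="id", type=pb2.DataType(i64=pb2.DataType.I64()), nullable=True
            ),
            pb2.DataType.StructField(
                name="name", type=pb2.DataType(string=pb2.DataType.String()), nullable=True
            ),
        ]
    )
)
print(proto_to_pyspark(schema))  # StructType with fields id: long, name: string
```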
26 changes: 24 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -34,6 +34,7 @@
Expression,
LiteralExpression,
)
from pyspark.sql.types import StructType

if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
@@ -96,7 +97,7 @@ class DataFrame(object):
of the DataFrame with the changes applied.
"""

def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None):
def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None):
"""Creates a new data frame"""
self._schema = schema
self._plan: Optional[plan.LogicalPlan] = None
@@ -315,11 +316,32 @@ def toPandas(self) -> Optional["pandas.DataFrame"]:
query = self._plan.to_proto(self._session)
return self._session._to_pandas(query)

def schema(self) -> StructType:
"""Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`.

.. versionadded:: 3.4.0

Returns
-------
:class:`StructType`
"""
if self._schema is None:
if self._plan is not None:
query = self._plan.to_proto(self._session)
if self._session is None:
raise Exception("Cannot analyze without RemoteSparkSession.")
self._schema = self._session.schema(query)
return self._schema
else:
raise Exception("Empty plan.")
else:
return self._schema

def explain(self) -> str:
if self._plan is not None:
query = self._plan.to_proto(self._session)
if self._session is None:
raise Exception("Cannot analyze without RemoteSparkSession.")
return self._session.analyze(query).explain_string
return self._session.explain_string(query)
else:
return ""