diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 3a9157010e9b..e218b8ff9267 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -242,7 +242,7 @@ jobs:
- name: Install Python packages (Python 3.8)
if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-'))
run: |
- python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting
+ python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting grpcio protobuf
python3.8 -m pip list
# Run the tests.
- name: Run tests
@@ -333,6 +333,8 @@ jobs:
pyspark-pandas
- >-
pyspark-pandas-slow
+ - >-
+ pyspark-sql-connect
env:
MODULES_TO_TEST: ${{ matrix.modules }}
HADOOP_PROFILE: ${{ inputs.hadoop }}
@@ -576,7 +578,7 @@ jobs:
# See also https://issues.apache.org/jira/browse/SPARK-38279.
python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1'
python3.9 -m pip install ipython_genutils # See SPARK-38517
- python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8'
+ python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' grpcio protobuf
python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421
apt-get update -y
apt-get install -y ruby ruby-dev
diff --git a/assembly/pom.xml b/assembly/pom.xml
index f37edcd7e49f..218bf3679504 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -74,6 +74,11 @@
     <artifactId>spark-repl_${scala.binary.version}</artifactId>
     <version>${project.version}</version>
   </dependency>
+   <dependency>
+     <groupId>org.apache.spark</groupId>
+     <artifactId>spark-connect_${scala.binary.version}</artifactId>
+     <version>${project.version}</version>
+   </dependency>
diff --git a/connect/pom.xml b/connect/pom.xml
new file mode 100644
--- /dev/null
+++ b/connect/pom.xml
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.12</artifactId>
+    <version>3.4.0-SNAPSHOT</version>
+    <relativePath>../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-connect_2.12</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Connect</name>
+  <url>https://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>connect</sbt.project.name>
+    <protobuf.version>3.21.1</protobuf.version>
+    <guava.version>31.0.1-jre</guava.version>
+    <io.grpc.version>1.47.0</io.grpc.version>
+    <tomcat.annotations.api.version>6.0.53</tomcat.annotations.api.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <version>${guava.version}</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>failureaccess</artifactId>
+      <version>1.0.1</version>
+    </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-netty-shaded</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-protobuf</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-services</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-stub</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.tomcat</groupId>
+      <artifactId>annotations-api</artifactId>
+      <version>${tomcat.annotations.api.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <extensions>
+      <extension>
+        <groupId>kr.motd.maven</groupId>
+        <artifactId>os-maven-plugin</artifactId>
+        <version>1.6.2</version>
+      </extension>
+    </extensions>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>add-sources</id>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>src/main/scala-${scala.binary.version}</source>
+              </sources>
+            </configuration>
+          </execution>
+          <execution>
+            <id>add-scala-test-sources</id>
+            <phase>generate-test-sources</phase>
+            <goals>
+              <goal>add-test-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>src/test/gen-java</source>
+              </sources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+
+      <plugin>
+        <groupId>org.xolstice.maven.plugins</groupId>
+        <artifactId>protobuf-maven-plugin</artifactId>
+        <version>0.6.1</version>
+        <configuration>
+          <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+          <pluginId>grpc-java</pluginId>
+          <pluginArtifact>io.grpc:protoc-gen-grpc-java:${io.grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
+          <protoSourceRoot>src/main/protobuf</protoSourceRoot>
+        </configuration>
+        <executions>
+          <execution>
+            <goals>
+              <goal>compile</goal>
+              <goal>compile-custom</goal>
+              <goal>test-compile</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <artifactSet>
+            <includes>
+              <include>com.google.guava:*</include>
+              <include>io.grpc:*:</include>
+              <include>com.google.protobuf:*</include>
+            </includes>
+          </artifactSet>
+          <relocations>
+            <relocation>
+              <pattern>com.google.common</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.guava</shadedPattern>
+              <includes>
+                <include>com.google.common.**</include>
+              </includes>
+            </relocation>
+            <relocation>
+              <pattern>com.google.thirdparty</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.guava</shadedPattern>
+              <includes>
+                <include>com.google.thirdparty.**</include>
+              </includes>
+            </relocation>
+            <relocation>
+              <pattern>com.google.protobuf</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.protobuf</shadedPattern>
+              <includes>
+                <include>com.google.protobuf.**</include>
+              </includes>
+            </relocation>
+            <relocation>
+              <pattern>io.grpc</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.grpc</shadedPattern>
+            </relocation>
+          </relocations>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/connect/src/main/buf.gen.yaml b/connect/src/main/buf.gen.yaml
new file mode 100644
index 000000000000..01e31d3c8b4c
--- /dev/null
+++ b/connect/src/main/buf.gen.yaml
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
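+# Usage sketch: this configuration is consumed by `buf generate`, typically run from the
+# directory that contains buf.work.yaml; each remote plugin below writes its generated
+# stubs into the listed `out` directory.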
+version: v1
+plugins:
+ - remote: buf.build/protocolbuffers/plugins/cpp:v3.20.0-1
+ out: gen/proto/cpp
+ - remote: buf.build/protocolbuffers/plugins/csharp:v3.20.0-1
+ out: gen/proto/csharp
+ - remote: buf.build/protocolbuffers/plugins/java:v3.20.0-1
+ out: gen/proto/java
+ - remote: buf.build/protocolbuffers/plugins/python:v3.20.0-1
+ out: gen/proto/python
+ - remote: buf.build/grpc/plugins/python:v1.47.0-1
+ out: gen/proto/python
+ - remote: buf.build/protocolbuffers/plugins/go:v1.28.0-1
+ out: gen/proto/go
+ opt:
+ - paths=source_relative
+ - remote: buf.build/grpc/plugins/go:v1.2.0-1
+ out: gen/proto/go
+ opt:
+ - paths=source_relative
+ - require_unimplemented_servers=false
+ - remote: buf.build/grpc/plugins/ruby:v1.47.0-1
+ out: gen/proto/ruby
+ - remote: buf.build/protocolbuffers/plugins/ruby:v21.2.0-1
+ out: gen/proto/ruby
diff --git a/connect/src/main/buf.work.yaml b/connect/src/main/buf.work.yaml
new file mode 100644
index 000000000000..a02dead420cd
--- /dev/null
+++ b/connect/src/main/buf.work.yaml
@@ -0,0 +1,19 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+version: v1
+directories:
+ - protobuf
diff --git a/connect/src/main/protobuf/buf.yaml b/connect/src/main/protobuf/buf.yaml
new file mode 100644
index 000000000000..496e97af3fa0
--- /dev/null
+++ b/connect/src/main/protobuf/buf.yaml
@@ -0,0 +1,23 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+version: v1
+breaking:
+ use:
+ - FILE
+lint:
+ use:
+ - DEFAULT
diff --git a/connect/src/main/protobuf/spark/connect/base.proto b/connect/src/main/protobuf/spark/connect/base.proto
new file mode 100644
index 000000000000..450f60d6aa5d
--- /dev/null
+++ b/connect/src/main/protobuf/spark/connect/base.proto
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/commands.proto";
+import "spark/connect/relations.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+// A [[Plan]] is the structure that carries the runtime information for the execution from the
+// client to the server. A [[Plan]] can either be of the type [[Relation]] which is a reference
+// to the underlying logical plan or it can be of the [[Command]] type that is used to execute
+// commands on the server.
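+//
+// For example (illustrative only), a client reading a table could send a [[Request]] whose
+// plan is `root { read { named_table { parts: "table" } } }` and would then receive a stream
+// of [[Response]] messages, typically ending with one that carries the [[Metrics]].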
+message Plan {
+ oneof op_type {
+ Relation root = 1;
+ Command command = 2;
+ }
+}
+
+// A request to be executed by the service.
+message Request {
+ // The client_id is set by the client to be able to collate streaming responses from
+ // different queries.
+ string client_id = 1;
+ // User context
+ UserContext user_context = 2;
+ // The logical plan to be executed / analyzed.
+ Plan plan = 3;
+
+ // User Context is used to refer to one particular user session that is executing
+ // queries in the backend.
+ message UserContext {
+ string user_id = 1;
+ string user_name = 2;
+ }
+}
+
+// The response of a query. There can be one or more responses for each request. Responses
+// belonging to the same input query carry the same `client_id`.
+message Response {
+ string client_id = 1;
+
+ // Result type
+ oneof result_type {
+ ArrowBatch batch = 2;
+ CSVBatch csv_batch = 3;
+ }
+
+ // Metrics for the query execution. Typically, this field is only present in the last
+ // batch of results and then represents the overall state of the query execution.
+ Metrics metrics = 4;
+
+ // A batch of query results serialized in Arrow format.
+ message ArrowBatch {
+ int64 row_count = 1;
+ int64 uncompressed_bytes = 2;
+ int64 compressed_bytes = 3;
+ bytes data = 4;
+ bytes schema = 5;
+ }
+
+ message CSVBatch {
+ int64 row_count = 1;
+ string data = 2;
+ }
+
+ message Metrics {
+
+ repeated MetricObject metrics = 1;
+
+ message MetricObject {
+ string name = 1;
+ int64 plan_id = 2;
+ int64 parent = 3;
+ map<string, MetricValue> execution_metrics = 4;
+ }
+
+ message MetricValue {
+ string name = 1;
+ int64 value = 2;
+ string metric_type = 3;
+ }
+ }
+}
+
+// Response to performing analysis of the query. Contains relevant metadata to be able to
+// reason about the performance.
+message AnalyzeResponse {
+ string client_id = 1;
+ repeated string column_names = 2;
+ repeated string column_types = 3;
+
+ // The extended explain string as produced by Spark.
+ string explain_string = 4;
+}
+
+// Main interface for the SparkConnect service.
+service SparkConnectService {
+
+ // Executes a request that contains the query and returns a stream of [[Response]].
+ rpc ExecutePlan(Request) returns (stream Response) {}
+
+ // Analyzes a query and returns an [[AnalyzeResponse]] containing metadata about the query.
+ rpc AnalyzePlan(Request) returns (AnalyzeResponse) {}
+}
+
diff --git a/connect/src/main/protobuf/spark/connect/commands.proto b/connect/src/main/protobuf/spark/connect/commands.proto
new file mode 100644
index 000000000000..425857b842e5
--- /dev/null
+++ b/connect/src/main/protobuf/spark/connect/commands.proto
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+import "spark/connect/types.proto";
+
+package spark.connect;
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+// A [[Command]] is an operation that is executed by the server that does not directly consume or
+// produce a relational result.
+message Command {
+ oneof command_type {
+ CreateScalarFunction create_function = 1;
+ }
+}
+
+// Simple message that is used to create a scalar function based on the provided function body.
+//
+// This message is used to register, for example, a Python UDF in the session catalog by providing
+// the serialized method body.
+//
+// TODO(SPARK-40532) It is required to add the interpreter / language version to the command
+// parameters.
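+//
+// For example (illustrative only), registering a Python UDF could be encoded as:
+//   parts: "my_python_udf"
+//   language: FUNCTION_LANGUAGE_PYTHON
+//   serialized_function: "<serialized function body>"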
+message CreateScalarFunction {
+ // Fully qualified name of the function including the catalog / schema names.
+ repeated string parts = 1;
+ FunctionLanguage language = 2;
+ bool temporary = 3;
+ repeated Type argument_types = 4;
+ Type return_type = 5;
+
+ // How the function body is defined:
+ oneof function_definition {
+ // As a raw string serialized:
+ bytes serialized_function = 6;
+ // As a code literal
+ string literal_string = 7;
+ }
+
+ enum FunctionLanguage {
+ FUNCTION_LANGUAGE_UNSPECIFIED = 0;
+ FUNCTION_LANGUAGE_SQL = 1;
+ FUNCTION_LANGUAGE_PYTHON = 2;
+ FUNCTION_LANGUAGE_SCALA = 3;
+ }
+}
diff --git a/connect/src/main/protobuf/spark/connect/expressions.proto b/connect/src/main/protobuf/spark/connect/expressions.proto
new file mode 100644
index 000000000000..6b72a646623c
--- /dev/null
+++ b/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+import "spark/connect/types.proto";
+import "google/protobuf/any.proto";
+
+package spark.connect;
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+// Expression used to refer to fields, functions and similar. This can be used everywhere
+// expressions in SQL appear.
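+//
+// For example (illustrative only), the predicate `id > 10` could be encoded as:
+//   unresolved_function {
+//     parts: "gt"
+//     arguments { unresolved_attribute { parts: "id" } }
+//     arguments { literal { i32: 10 } }
+//   }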
+message Expression {
+
+ oneof expr_type {
+ Literal literal = 1;
+ UnresolvedAttribute unresolved_attribute = 2;
+ UnresolvedFunction unresolved_function = 3;
+ ExpressionString expression_string = 4;
+ }
+
+ message Literal {
+ oneof literal_type {
+ bool boolean = 1;
+ int32 i8 = 2;
+ int32 i16 = 3;
+ int32 i32 = 5;
+ int64 i64 = 7;
+ float fp32 = 10;
+ double fp64 = 11;
+ string string = 12;
+ bytes binary = 13;
+ // Timestamp in units of microseconds since the UNIX epoch.
+ int64 timestamp = 14;
+ // Date in units of days since the UNIX epoch.
+ int32 date = 16;
+ // Time in units of microseconds past midnight
+ int64 time = 17;
+ IntervalYearToMonth interval_year_to_month = 19;
+ IntervalDayToSecond interval_day_to_second = 20;
+ string fixed_char = 21;
+ VarChar var_char = 22;
+ bytes fixed_binary = 23;
+ Decimal decimal = 24;
+ Struct struct = 25;
+ Map map = 26;
+ // Timestamp in units of microseconds since the UNIX epoch.
+ int64 timestamp_tz = 27;
+ bytes uuid = 28;
+ Type null = 29; // a typed null literal
+ List list = 30;
+ Type.List empty_list = 31;
+ Type.Map empty_map = 32;
+ UserDefined user_defined = 33;
+ }
+
+ // whether the literal type should be treated as a nullable type. Applies to
+ // all members of union other than the Typed null (which should directly
+ // declare nullability).
+ bool nullable = 50;
+
+ // optionally points to a type_variation_anchor defined in this plan.
+ // Applies to all members of union other than the Typed null (which should
+ // directly declare the type variation).
+ uint32 type_variation_reference = 51;
+
+ message VarChar {
+ string value = 1;
+ uint32 length = 2;
+ }
+
+ message Decimal {
+ // little-endian twos-complement integer representation of complete value
+ // (ignoring precision) Always 16 bytes in length
+ bytes value = 1;
+ // The maximum number of digits allowed in the value.
+ // the maximum precision is 38.
+ int32 precision = 2;
+ // declared scale of decimal literal
+ int32 scale = 3;
+ }
+
+ message Map {
+ message KeyValue {
+ Literal key = 1;
+ Literal value = 2;
+ }
+
+ repeated KeyValue key_values = 1;
+ }
+
+ message IntervalYearToMonth {
+ int32 years = 1;
+ int32 months = 2;
+ }
+
+ message IntervalDayToSecond {
+ int32 days = 1;
+ int32 seconds = 2;
+ int32 microseconds = 3;
+ }
+
+ message Struct {
+ // A possibly heterogeneously typed list of literals
+ repeated Literal fields = 1;
+ }
+
+ message List {
+ // A homogeneously typed list of literals
+ repeated Literal values = 1;
+ }
+
+ message UserDefined {
+ // points to a type_anchor defined in this plan
+ uint32 type_reference = 1;
+
+ // the value of the literal, serialized using some type-specific
+ // protobuf message
+ google.protobuf.Any value = 2;
+ }
+ }
+
+ // An unresolved attribute that is not explicitly bound to a specific column, but the column
+ // is resolved during analysis by name.
+ message UnresolvedAttribute {
+ repeated string parts = 1;
+ }
+
+ // An unresolved function is not explicitly bound to one explicit function, but the function
+ // is resolved during analysis following Spark's name resolution rules.
+ message UnresolvedFunction {
+ repeated string parts = 1;
+ repeated Expression arguments = 2;
+ }
+
+ // Expression as string.
+ message ExpressionString {
+ string expression = 1;
+ }
+
+}
diff --git a/connect/src/main/protobuf/spark/connect/relations.proto b/connect/src/main/protobuf/spark/connect/relations.proto
new file mode 100644
index 000000000000..adbe178da992
--- /dev/null
+++ b/connect/src/main/protobuf/spark/connect/relations.proto
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+// The main [[Relation]] type. Fundamentally, a relation is a typed container
+// that has exactly one explicit relation type set.
+//
+// When adding new relation types, they have to be registered here.
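+//
+// For example (illustrative only), `SELECT name FROM people` could be expressed as:
+//   project {
+//     input { read { named_table { parts: "people" } } }
+//     expressions { unresolved_attribute { parts: "name" } }
+//   }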
+message Relation {
+ RelationCommon common = 1;
+ oneof rel_type {
+ Read read = 2;
+ Project project = 3;
+ Filter filter = 4;
+ Join join = 5;
+ Union union = 6;
+ Sort sort = 7;
+ Fetch fetch = 8;
+ Aggregate aggregate = 9;
+ SQL sql = 10;
+
+ Unknown unknown = 999;
+ }
+}
+
+// Used for testing purposes only.
+message Unknown {}
+
+// Common metadata of all relations.
+message RelationCommon {
+ string source_info = 1;
+ string alias = 2;
+}
+
+// Relation that uses a SQL query to generate the output.
+message SQL {
+ string query = 1;
+}
+
+// Relation that reads from a file / table or other data source. Does not have additional
+// inputs.
+message Read {
+ oneof read_type {
+ NamedTable named_table = 1;
+ }
+
+ message NamedTable {
+ repeated string parts = 1;
+ }
+}
+
+// Projection of a bag of expressions for a given input relation.
+//
+// The input relation must be specified.
+// The projected expression can be an arbitrary expression.
+message Project {
+ Relation input = 1;
+ repeated Expression expressions = 3;
+}
+
+// Relation that applies a boolean expression `condition` on each row of `input` to produce
+// the output result.
+message Filter {
+ Relation input = 1;
+ Expression condition = 2;
+}
+
+// Relation of type [[Join]].
+//
+// `left` and `right` must be present.
+message Join {
+ Relation left = 1;
+ Relation right = 2;
+ Expression on = 3;
+ JoinType how = 4;
+
+ enum JoinType {
+ JOIN_TYPE_UNSPECIFIED = 0;
+ JOIN_TYPE_INNER = 1;
+ JOIN_TYPE_OUTER = 2;
+ JOIN_TYPE_LEFT_OUTER = 3;
+ JOIN_TYPE_RIGHT_OUTER = 4;
+ JOIN_TYPE_ANTI = 5;
+ }
+}
+
+// Relation of type [[Union]]; at least one input must be set.
+message Union {
+ repeated Relation inputs = 1;
+ UnionType union_type = 2;
+
+ enum UnionType {
+ UNION_TYPE_UNSPECIFIED = 0;
+ UNION_TYPE_DISTINCT = 1;
+ UNION_TYPE_ALL = 2;
+ }
+}
+
+// Relation of type [[Fetch]] that is used to read `limit` / `offset` rows from the input relation.
+message Fetch {
+ Relation input = 1;
+ int32 limit = 2;
+ int32 offset = 3;
+}
+
+// Relation of type [[Aggregate]].
+message Aggregate {
+ Relation input = 1;
+
+ // Grouping sets are used in rollups
+ repeated GroupingSet grouping_sets = 2;
+
+ // Measures
+ repeated Measure measures = 3;
+
+ message GroupingSet {
+ repeated Expression aggregate_expressions = 1;
+ }
+
+ message Measure {
+ AggregateFunction function = 1;
+ // Conditional filter for SUM(x FILTER WHERE x < 10)
+ Expression filter = 2;
+ }
+
+ message AggregateFunction {
+ string name = 1;
+ repeated Expression arguments = 2;
+ }
+}
+
+// Relation of type [[Sort]].
+message Sort {
+ Relation input = 1;
+ repeated SortField sort_fields = 2;
+
+ message SortField {
+ Expression expression = 1;
+ SortDirection direction = 2;
+ SortNulls nulls = 3;
+ }
+
+ enum SortDirection {
+ SORT_DIRECTION_UNSPECIFIED = 0;
+ SORT_DIRECTION_ASCENDING = 1;
+ SORT_DIRECTION_DESCENDING = 2;
+ }
+
+ enum SortNulls {
+ SORT_NULLS_UNSPECIFIED = 0;
+ SORT_NULLS_FIRST = 1;
+ SORT_NULLS_LAST = 2;
+ }
+}
diff --git a/connect/src/main/protobuf/spark/connect/types.proto b/connect/src/main/protobuf/spark/connect/types.proto
new file mode 100644
index 000000000000..c46afa2afc65
--- /dev/null
+++ b/connect/src/main/protobuf/spark/connect/types.proto
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+// This message describes the logical [[Type]] of something. It does not carry the value
+// itself but only describes it.
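+//
+// For example (illustrative only), a nullable string type could be encoded as:
+//   string { nullability: NULLABILITY_NULLABLE }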
+message Type {
+ oneof kind {
+ Boolean bool = 1;
+ I8 i8 = 2;
+ I16 i16 = 3;
+ I32 i32 = 5;
+ I64 i64 = 7;
+ FP32 fp32 = 10;
+ FP64 fp64 = 11;
+ String string = 12;
+ Binary binary = 13;
+ Timestamp timestamp = 14;
+ Date date = 16;
+ Time time = 17;
+ IntervalYear interval_year = 19;
+ IntervalDay interval_day = 20;
+ TimestampTZ timestamp_tz = 29;
+ UUID uuid = 32;
+
+ FixedChar fixed_char = 21;
+ VarChar varchar = 22;
+ FixedBinary fixed_binary = 23;
+ Decimal decimal = 24;
+
+ Struct struct = 25;
+ List list = 27;
+ Map map = 28;
+
+ uint32 user_defined_type_reference = 31;
+ }
+
+ enum Nullability {
+ NULLABILITY_UNSPECIFIED = 0;
+ NULLABILITY_NULLABLE = 1;
+ NULLABILITY_REQUIRED = 2;
+ }
+
+ message Boolean {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message I8 {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message I16 {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message I32 {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message I64 {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message FP32 {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message FP64 {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message String {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message Binary {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message Timestamp {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message Date {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message Time {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message TimestampTZ {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message IntervalYear {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message IntervalDay {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ message UUID {
+ uint32 type_variation_reference = 1;
+ Nullability nullability = 2;
+ }
+
+ // Start compound types.
+ message FixedChar {
+ int32 length = 1;
+ uint32 type_variation_reference = 2;
+ Nullability nullability = 3;
+ }
+
+ message VarChar {
+ int32 length = 1;
+ uint32 type_variation_reference = 2;
+ Nullability nullability = 3;
+ }
+
+ message FixedBinary {
+ int32 length = 1;
+ uint32 type_variation_reference = 2;
+ Nullability nullability = 3;
+ }
+
+ message Decimal {
+ int32 scale = 1;
+ int32 precision = 2;
+ uint32 type_variation_reference = 3;
+ Nullability nullability = 4;
+ }
+
+ message Struct {
+ repeated Type types = 1;
+ uint32 type_variation_reference = 2;
+ Nullability nullability = 3;
+ }
+
+ message List {
+ Type type = 1;
+ uint32 type_variation_reference = 2;
+ Nullability nullability = 3;
+ }
+
+ message Map {
+ Type key = 1;
+ Type value = 2;
+ uint32 type_variation_reference = 3;
+ Nullability nullability = 4;
+ }
+}
diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
new file mode 100644
index 000000000000..d262947015cb
--- /dev/null
+++ b/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Unstable
+import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+/**
+ * This is the main entry point for Spark Connect.
+ *
+ * To decouple the build of Spark Connect and its dependencies from the core of Spark, we
+ * implement it as a Driver Plugin. To enable Spark Connect, simply make sure that the appropriate
+ * JAR is available in the CLASSPATH and the driver plugin is configured to load this class.
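+ *
+ * For example (a configuration sketch, assuming the connect jar is on the driver classpath),
+ * the plugin can be enabled through the regular Spark plugin mechanism:
+ * {{{
+ *   ./bin/spark-shell --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
+ * }}}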
+ */
+@Unstable
+class SparkConnectPlugin extends SparkPlugin {
+
+ /**
+ * Return the plugin's driver-side component.
+ *
+ * @return The driver-side component.
+ */
+ override def driverPlugin(): DriverPlugin = new DriverPlugin {
+
+ override def init(
+ sc: SparkContext,
+ pluginContext: PluginContext): util.Map[String, String] = {
+ SparkConnectService.start()
+ Map.empty[String, String].asJava
+ }
+
+ override def shutdown(): Unit = {
+ SparkConnectService.stop()
+ }
+ }
+
+ override def executorPlugin(): ExecutorPlugin = null
+}
diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
new file mode 100644
index 000000000000..865c7543609b
--- /dev/null
+++ b/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.command
+
+import scala.collection.JavaConverters._
+
+import com.google.common.collect.{Lists, Maps}
+
+import org.apache.spark.annotation.{Since, Unstable}
+import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.types.StringType
+
+
+@Unstable
+@Since("3.4.0")
+class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) {
+
+ lazy val pythonVersion =
+ sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
+
+ def process(): Unit = {
+ command.getCommandTypeCase match {
+ case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
+ handleCreateScalarFunction(command.getCreateFunction)
+ case _ => throw new UnsupportedOperationException(s"$command not supported.")
+ }
+ }
+
+ /**
+ * This is a helper function that registers a new Python function in the SparkSession.
+ *
+ * Right now this function is very rudimentary and bare-bones just to showcase how it
+ * is possible to remotely serialize a Python function and execute it on the Spark cluster.
+ * If the Python versions on the client and server diverge, executing the serialized
+ * function will most likely fail.
+ *
+ * @param cf
+ */
+ def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
+ val function = SimplePythonFunction(
+ cf.getSerializedFunction.toByteArray,
+ Maps.newHashMap(),
+ Lists.newArrayList(),
+ pythonVersion,
+ "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+ Lists.newArrayList(),
+ null)
+
+ val udf = UserDefinedPythonFunction(
+ cf.getPartsList.asScala.head,
+ function,
+ StringType,
+ PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = false)
+
+ session.udf.registerPython(cf.getPartsList.asScala.head, udf)
+ }
+
+}
diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
new file mode 100644
index 000000000000..ee2e05a1e606
--- /dev/null
+++ b/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.{Since, Unstable}
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.{expressions, plans}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.types._
+
+final case class InvalidPlanInput(
+ private val message: String = "",
+ private val cause: Throwable = None.orNull)
+ extends Exception(message, cause)
+
+@Unstable
+@Since("3.4.0")
+class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
+
+ def transform(): LogicalPlan = {
+ transformRelation(plan)
+ }
+
+ // The root of the query plan is a relation and we apply the transformations to it.
+ private def transformRelation(rel: proto.Relation): LogicalPlan = {
+ val common = if (rel.hasCommon) {
+ Some(rel.getCommon)
+ } else {
+ None
+ }
+
+ rel.getRelTypeCase match {
+ case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead, common)
+ case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject, common)
+ case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
+ case proto.Relation.RelTypeCase.FETCH => transformFetch(rel.getFetch)
+ case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
+ case proto.Relation.RelTypeCase.UNION => transformUnion(rel.getUnion)
+ case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
+ case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
+ case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
+ case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
+ throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
+ case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
+ }
+ }
+
+ private def transformSql(sql: proto.SQL): LogicalPlan = {
+ session.sessionState.sqlParser.parsePlan(sql.getQuery)
+ }
+
+ private def transformReadRel(
+ rel: proto.Read,
+ common: Option[proto.RelationCommon]): LogicalPlan = {
+ val baseRelation = rel.getReadTypeCase match {
+ case proto.Read.ReadTypeCase.NAMED_TABLE =>
+ val child = UnresolvedRelation(rel.getNamedTable.getPartsList.asScala.toSeq)
+ if (common.nonEmpty && common.get.getAlias.nonEmpty) {
+ SubqueryAlias(identifier = common.get.getAlias, child = child)
+ } else {
+ child
+ }
+ case _ => throw InvalidPlanInput()
+ }
+ baseRelation
+ }
+
+ private def transformFilter(rel: proto.Filter): LogicalPlan = {
+ assert(rel.hasInput)
+ val baseRel = transformRelation(rel.getInput)
+ logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel)
+ }
+
+ private def transformProject(
+ rel: proto.Project,
+ common: Option[proto.RelationCommon]): LogicalPlan = {
+ val baseRel = transformRelation(rel.getInput)
+ val projection = if (rel.getExpressionsCount == 0) {
+ Seq(UnresolvedStar(Option.empty))
+ } else {
+ rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
+ }
+ val project = logical.Project(projectList = projection.toSeq, child = baseRel)
+ if (common.nonEmpty && common.get.getAlias.nonEmpty) {
+ logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
+ } else {
+ project
+ }
+ }
+
+ private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
+ UnresolvedAttribute(exp.getUnresolvedAttribute.getPartsList.asScala.toSeq)
+ }
+
+ private def transformExpression(exp: proto.Expression): Expression = {
+ exp.getExprTypeCase match {
+ case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
+ case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
+ transformUnresolvedExpression(exp)
+ case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
+ transformScalarFunction(exp.getUnresolvedFunction)
+ case _ => throw InvalidPlanInput()
+ }
+ }
+
+ /**
+ * Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
+ *
+ * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp,
+ * Duration, Period.
+ * @param lit
+ * @return
+ * Expression
+ */
+ private def transformLiteral(lit: proto.Expression.Literal): Expression = {
+ lit.getLiteralTypeCase match {
+ case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => expressions.Literal(lit.getBoolean)
+ case proto.Expression.Literal.LiteralTypeCase.I8 => expressions.Literal(lit.getI8, ByteType)
+ case proto.Expression.Literal.LiteralTypeCase.I16 =>
+ expressions.Literal(lit.getI16, ShortType)
+ case proto.Expression.Literal.LiteralTypeCase.I32 => expressions.Literal(lit.getI32)
+ case proto.Expression.Literal.LiteralTypeCase.I64 => expressions.Literal(lit.getI64)
+ case proto.Expression.Literal.LiteralTypeCase.FP32 =>
+ expressions.Literal(lit.getFp32, FloatType)
+ case proto.Expression.Literal.LiteralTypeCase.FP64 =>
+ expressions.Literal(lit.getFp64, DoubleType)
+ case proto.Expression.Literal.LiteralTypeCase.STRING => expressions.Literal(lit.getString)
+ case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+ expressions.Literal(lit.getBinary, BinaryType)
+ // Microseconds since unix epoch.
+ case proto.Expression.Literal.LiteralTypeCase.TIME =>
+ expressions.Literal(lit.getTime, TimestampType)
+ // Days since UNIX epoch.
+ case proto.Expression.Literal.LiteralTypeCase.DATE =>
+ expressions.Literal(lit.getDate, DateType)
+ case _ => throw InvalidPlanInput(
+ s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
+ s"(${lit.getLiteralTypeCase.name})")
+ }
+ }
+
+ private def transformFetch(limit: proto.Fetch): LogicalPlan = {
+ logical.Limit(
+ child = transformRelation(limit.getInput),
+ limitExpr = expressions.Literal(limit.getLimit, IntegerType))
+ }
+
+ private def lookupFunction(name: String, args: Seq[Expression]): Expression = {
+ UnresolvedFunction(Seq(name), args, isDistinct = false)
+ }
+
+ /**
+ * Translates a scalar function from proto to the Catalyst expression.
+ *
+ * TODO(SPARK-40546) We need to homogenize the function names for binary operators.
+ *
+ * @param fun Proto representation of the function call.
+ * @return
+ */
+ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
+ val funName = fun.getPartsList.asScala.mkString(".")
+ funName match {
+ case "gt" =>
+ assert(fun.getArgumentsCount == 2, "`gt` function must have two arguments.")
+ expressions.GreaterThan(
+ transformExpression(fun.getArguments(0)),
+ transformExpression(fun.getArguments(1)))
+ case "eq" =>
+ assert(fun.getArgumentsCount == 2, "`eq` function must have two arguments.")
+ expressions.EqualTo(
+ transformExpression(fun.getArguments(0)),
+ transformExpression(fun.getArguments(1)))
+ case _ =>
+ lookupFunction(funName, fun.getArgumentsList.asScala.map(transformExpression).toSeq)
+ }
+ }
+
+ private def transformUnion(u: proto.Union): LogicalPlan = {
+ assert(u.getInputsCount == 2, "Union must have 2 inputs")
+ val plan = logical.Union(transformRelation(u.getInputs(0)), transformRelation(u.getInputs(1)))
+
+ u.getUnionType match {
+ case proto.Union.UnionType.UNION_TYPE_DISTINCT => logical.Distinct(plan)
+ case proto.Union.UnionType.UNION_TYPE_ALL => plan
+ case _ =>
+ throw InvalidPlanInput(s"Unsupported set operation ${u.getUnionTypeValue}")
+ }
+ }
+
+ private def transformJoin(rel: proto.Join): LogicalPlan = {
+ assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+ logical.Join(
+ left = transformRelation(rel.getLeft),
+ right = transformRelation(rel.getRight),
+ // TODO(SPARK-40534) Support additional join types and configuration.
+ joinType = plans.Inner,
+ condition = Some(transformExpression(rel.getOn)),
+ hint = logical.JoinHint.NONE)
+ }
+
+ private def transformSort(rel: proto.Sort): LogicalPlan = {
+ assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.")
+ logical.Sort(
+ child = transformRelation(rel.getInput),
+ global = true,
+ order = rel.getSortFieldsList.asScala.map(transformSortOrderExpression).toSeq)
+ }
+
+ private def transformSortOrderExpression(so: proto.Sort.SortField): expressions.SortOrder = {
+ expressions.SortOrder(
+ child = transformUnresolvedExpression(so.getExpression),
+ direction = so.getDirection match {
+ case proto.Sort.SortDirection.SORT_DIRECTION_DESCENDING => expressions.Descending
+ case _ => expressions.Ascending
+ },
+ nullOrdering = so.getNulls match {
+ case proto.Sort.SortNulls.SORT_NULLS_LAST => expressions.NullsLast
+ case _ => expressions.NullsFirst
+ },
+ sameOrderExpressions = Seq.empty)
+ }
+
+ private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
+ assert(rel.hasInput)
+ assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported")
+
+ val groupingSet = rel.getGroupingSetsList.asScala.take(1)
+ val ge = groupingSet
+ .flatMap(f => f.getAggregateExpressionsList.asScala)
+ .map(transformExpression)
+ .map {
+ case x @ UnresolvedAttribute(_) => x
+ case x => UnresolvedAlias(x)
+ }
+
+ logical.Aggregate(
+ child = transformRelation(rel.getInput),
+ groupingExpressions = ge.toSeq,
+ aggregateExpressions =
+ (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq)
+ }
+
+ private def transformAggregateExpression(
+ exp: proto.Aggregate.Measure): expressions.NamedExpression = {
+ val fun = exp.getFunction.getName
+ UnresolvedAlias(
+ UnresolvedFunction(
+ name = fun,
+ arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq,
+ isDistinct = false))
+ }
+
+}
diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
new file mode 100644
index 000000000000..3357ad26f6c9
--- /dev/null
+++ b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+import com.google.common.base.Ticker
+import com.google.common.cache.CacheBuilder
+import io.grpc.{Server, Status}
+import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder
+import io.grpc.protobuf.services.ProtoReflectionService
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.annotation.{Since, Unstable}
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, SparkConnectServiceGrpc}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.execution.ExtendedMode
+
+/**
+ * The SparkConnectService implementation.
+ *
+ * This class implements the service stub from the generated code of GRPC.
+ *
+ * @param debug
+ * delegates debug behavior to the handlers.
+ */
+@Unstable
+@Since("3.4.0")
+class SparkConnectService(
+ debug: Boolean)
+ extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
+ with Logging {
+
+ /**
+ * This is the main entry method for Spark Connect and all calls to execute a plan.
+ *
+ * The plan execution is delegated to the [[SparkConnectStreamHandler]]. All error handling
+ * should be directly implemented in the deferred implementation. But this method catches
+ * generic errors.
+ *
+ * @param request
+ * @param responseObserver
+ */
+ override def executePlan(request: Request, responseObserver: StreamObserver[Response]): Unit = {
+ try {
+ new SparkConnectStreamHandler(responseObserver).handle(request)
+ } catch {
+ case e: Throwable =>
+ log.error("Error executing plan.", e)
+ responseObserver.onError(
+ Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
+ }
+ }
+
+ /**
+ * Analyze a plan to provide metadata and debugging information.
+ *
+ * This method is called to generate the explain plan for a SparkConnect plan. In its simplest
+ * implementation, the plan that is generated by the [[SparkConnectPlanner]] is used to build a
+ * [[Dataset]] and derive the explain string from the query execution details.
+ *
+ * Errors during planning are returned via the [[StreamObserver]] interface.
+ *
+ * @param request
+ * @param responseObserver
+ */
+ override def analyzePlan(
+ request: Request,
+ responseObserver: StreamObserver[AnalyzeResponse]): Unit = {
+ try {
+ val session =
+ SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
+
+ val logicalPlan = request.getPlan.getOpTypeCase match {
+ case proto.Plan.OpTypeCase.ROOT =>
+ new SparkConnectPlanner(request.getPlan.getRoot, session).transform()
+ case _ =>
+ responseObserver.onError(
+ new UnsupportedOperationException(
+ s"${request.getPlan.getOpTypeCase} not supported for analysis."))
+ return
+ }
+ val ds = Dataset.ofRows(session, logicalPlan)
+ val explainString = ds.queryExecution.explainString(ExtendedMode)
+
+ val resp = proto.AnalyzeResponse
+ .newBuilder()
+ .setExplainString(explainString)
+ .setClientId(request.getClientId)
+
+ resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava)
+ resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava)
+ responseObserver.onNext(resp.build())
+ responseObserver.onCompleted()
+ } catch {
+ case e: Throwable =>
+ log.error("Error analyzing plan.", e)
+ responseObserver.onError(
+ Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
+ }
+ }
+}
+
+/**
+ * Object used for referring to SparkSessions in the SessionCache.
+ *
+ * @param userId
+ * @param session
+ */
+@Unstable
+@Since("3.4.0")
+private[connect] case class SessionHolder(userId: String, session: SparkSession)
+
+/**
+ * Static instance of the SparkConnectService.
+ *
+ * Used to start the overall SparkConnect service and provides global state to manage the
+ * different SparkSessions of the different users connecting to the cluster.
+ */
+@Unstable
+@Since("3.4.0")
+object SparkConnectService {
+
+ private val CACHE_SIZE = 100
+
+ private val CACHE_TIMEOUT_SECONDS = 3600
+
+ // Type alias for the SessionCacheKey. Right now this is a String, but the alias makes it
+ // easy to switch to a different or more complex type later.
+ private type SessionCacheKey = String;
+
+ private var server: Server = _
+
+ private val userSessionMapping =
+ cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, SessionHolder]()
+
+ // Simple builder for creating the cache of Sessions.
+ private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): CacheBuilder[Object, Object] = {
+ var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker())
+ if (cacheSize >= 0) {
+ cacheBuilder = cacheBuilder.maximumSize(cacheSize)
+ }
+ if (timeoutSeconds >= 0) {
+ cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS)
+ }
+ cacheBuilder
+ }
+
+ /**
+ * Based on the `key` find or create a new SparkSession.
+ */
+ private[connect] def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = {
+ userSessionMapping.get(
+ key,
+ () => {
+ SessionHolder(key, newIsolatedSession())
+ })
+ }
+
+ private def newIsolatedSession(): SparkSession = {
+ SparkSession.active.newSession()
+ }
+
+ /**
+ * Starts the gRPC service.
+ *
+ * TODO(SPARK-40536) Make port number configurable.
+ */
+ def startGRPCService(): Unit = {
+ val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
+ val port = 15002
+ val sb = NettyServerBuilder
+ .forPort(port)
+ .addService(new SparkConnectService(debugMode))
+
+ // If debug mode is configured, load the ProtoReflection service so that tools like
+ // grpcurl can introspect the API for debugging.
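+ // For example (assuming the default port below and a plaintext listener):
+ //   grpcurl -plaintext localhost:15002 list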
+ if (debugMode) {
+ sb.addService(ProtoReflectionService.newInstance())
+ }
+ server = sb.build
+ server.start()
+ }
+
+ // Starts the service
+ def start(): Unit = {
+ startGRPCService()
+ }
+
+ def stop(): Unit = {
+ if (server != null) {
+ server.shutdownNow()
+ }
+ }
+}
+
diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
new file mode 100644
index 000000000000..52b807f63bb0
--- /dev/null
+++ b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.ByteString
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.annotation.{Since, Unstable}
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{Request, Response}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.util.ArrowUtils
+
+
+@Unstable
+@Since("3.4.0")
+class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
+
+ def handle(v: Request): Unit = {
+ val session =
+ SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session
+ v.getPlan.getOpTypeCase match {
+ case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
+ case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
+ case _ =>
+ throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
+ }
+ }
+
+ def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+ // Extract the plan from the request and convert it to a logical plan
+ val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
+ val rows =
+ Dataset.ofRows(session, planner.transform())
+ processRows(request.getClientId, rows)
+ }
+
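+ // Collects the query result and returns it to the client as a single pipe-delimited CSV
+ // batch, followed by one final response that only carries the execution metrics.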
+ private def processRows(clientId: String, rows: DataFrame) = {
+ val timeZoneId = SQLConf.get.sessionLocalTimeZone
+ val schema =
+ ByteString.copyFrom(ArrowUtils.toArrowSchema(rows.schema, timeZoneId).toByteArray)
+
+ val textSchema = rows.schema.fields.map(f => f.name).mkString("|")
+ val data = rows.collect().map(x => x.toSeq.mkString("|")).mkString("\n")
+ val bbb = proto.Response.CSVBatch.newBuilder
+ .setRowCount(-1)
+ .setData(textSchema ++ "\n" ++ data)
+ .build()
+ val response = proto.Response.newBuilder().setClientId(clientId).setCsvBatch(bbb).build()
+
+ // Send all the data
+ responseObserver.onNext(response)
+ responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+ responseObserver.onCompleted()
+ }
+
+ def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = {
+ // Send a last batch with the metrics
+ Response
+ .newBuilder()
+ .setClientId(clientId)
+ .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan))
+ .build()
+ }
+
+ def handleCommand(session: SparkSession, request: Request): Unit = {
+ val command = request.getPlan.getCommand
+ val planner = new SparkConnectCommandPlanner(session, command)
+ planner.process()
+ responseObserver.onCompleted()
+ }
+}
+
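+// Helper that converts the metrics of an executed [[SparkPlan]] tree into the proto
+// [[Response.Metrics]] representation, descending into adaptive query stages as well.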
+object MetricGenerator extends AdaptiveSparkPlanHelper {
+ def buildMetrics(p: SparkPlan): Response.Metrics = {
+ val b = Response.Metrics.newBuilder
+ b.addAllMetrics(transformPlan(p, p.id).asJava)
+ b.build()
+ }
+
+ def transformChildren(p: SparkPlan): Seq[Response.Metrics.MetricObject] = {
+ allChildren(p).flatMap(c => transformPlan(c, p.id))
+ }
+
+ def allChildren(p: SparkPlan): Seq[SparkPlan] = p match {
+ case a: AdaptiveSparkPlanExec => Seq(a.executedPlan)
+ case s: QueryStageExec => Seq(s.plan)
+ case _ => p.children
+ }
+
+ def transformPlan(p: SparkPlan, parentId: Int): Seq[Response.Metrics.MetricObject] = {
+ val mv = p.metrics.map(m =>
+ m._1 -> Response.Metrics.MetricValue.newBuilder
+ .setName(m._2.name.getOrElse(""))
+ .setValue(m._2.value)
+ .setMetricType(m._2.metricType)
+ .build())
+ val mo = Response.Metrics.MetricObject
+ .newBuilder()
+ .setName(p.nodeName)
+ .setPlanId(p.id)
+ .putAllExecutionMetrics(mv.asJava)
+ .build()
+ Seq(mo) ++ transformChildren(p)
+ }
+
+}
diff --git a/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
new file mode 100644
index 000000000000..3c18994ba7f8
--- /dev/null
+++ b/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkFunSuite, TestUtils}
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * Testing trait for SparkConnect tests with some helper methods to make it easier to create new
+ * test cases.
+ */
+trait SparkConnectPlanTest {
+ def transform(rel: proto.Relation): LogicalPlan = {
+ new SparkConnectPlanner(rel, None.orNull).transform()
+ }
+
+ def readRel: proto.Relation =
+ proto.Relation
+ .newBuilder()
+ .setRead(
+ proto.Read
+ .newBuilder()
+ .setNamedTable(proto.Read.NamedTable.newBuilder().addParts("table"))
+ .build())
+ .build()
+}
+
+trait SparkConnectSessionTest {
+ protected var spark: SparkSession
+
+}
+
+/**
+ * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to
+ * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are
+ * generated.
+ */
+class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
+
+ protected var spark: SparkSession = null
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ TestUtils.configTestLog4j2("INFO")
+ }
+
+ test("Simple Limit") {
+ assertThrows[IndexOutOfBoundsException] {
+ new SparkConnectPlanner(
+ proto.Relation.newBuilder.setFetch(proto.Fetch.newBuilder.setLimit(10).build()).build(),
+ None.orNull)
+ .transform()
+ }
+ }
+
+ test("InvalidInputs") {
+ // No Relation Set
+ intercept[IndexOutOfBoundsException](
+ new SparkConnectPlanner(proto.Relation.newBuilder().build(), None.orNull).transform())
+
+ intercept[InvalidPlanInput](
+ new SparkConnectPlanner(
+ proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build(),
+ None.orNull).transform())
+
+ }
+
+ test("Simple Read") {
+ val read = proto.Read.newBuilder().build()
+ // Invalid read without Table name.
+ intercept[InvalidPlanInput](transform(proto.Relation.newBuilder.setRead(read).build()))
+ val readWithTable = read.toBuilder
+ .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build())
+ .build()
+ val res = transform(proto.Relation.newBuilder.setRead(readWithTable).build())
+ assert(res !== null)
+ assert(res.nodeName == "UnresolvedRelation")
+ }
+
+ test("Simple Sort") {
+ val sort = proto.Sort.newBuilder
+ .addAllSortFields(Seq(proto.Sort.SortField.newBuilder().build()).asJava)
+ .build()
+ // No input relation set.
+ intercept[IndexOutOfBoundsException](
+ transform(proto.Relation.newBuilder().setSort(sort).build()))
+
+ val f = proto.Sort.SortField
+ .newBuilder()
+ .setNulls(proto.Sort.SortNulls.SORT_NULLS_LAST)
+ .setDirection(proto.Sort.SortDirection.SORT_DIRECTION_DESCENDING)
+ .setExpression(proto.Expression.newBuilder
+ .setUnresolvedAttribute(
+ proto.Expression.UnresolvedAttribute.newBuilder.addAllParts(Seq("col").asJava).build())
+ .build())
+ .build()
+
+ val res = transform(
+ proto.Relation.newBuilder
+ .setSort(proto.Sort.newBuilder.addAllSortFields(Seq(f).asJava).setInput(readRel))
+ .build())
+ assert(res.nodeName == "Sort")
+ }
+
+ test("Simple Union") {
+ intercept[AssertionError](
+ transform(proto.Relation.newBuilder.setUnion(proto.Union.newBuilder.build()).build))
+ val union = proto.Relation.newBuilder
+ .setUnion(proto.Union.newBuilder.addAllInputs(Seq(readRel, readRel).asJava).build())
+ .build()
+ val msg = intercept[InvalidPlanInput] {
+ transform(union)
+ }
+ assert(msg.getMessage.contains("Unsupported set operation"))
+
+ val res = transform(
+ proto.Relation.newBuilder
+ .setUnion(
+ proto.Union.newBuilder
+ .addAllInputs(Seq(readRel, readRel).asJava)
+ .setUnionType(proto.Union.UnionType.UNION_TYPE_ALL)
+ .build())
+ .build())
+ assert(res.nodeName == "Union")
+ }
+
+ test("Simple Join") {
+
+ val incompleteJoin =
+ proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build()
+ intercept[AssertionError](transform(incompleteJoin))
+
+ // Cartesian Product not supported.
+ intercept[InvalidPlanInput] {
+ val simpleJoin = proto.Relation.newBuilder
+ .setJoin(proto.Join.newBuilder.setLeft(readRel).setRight(readRel))
+ .build()
+ transform(simpleJoin)
+ }
+
+ // Construct a simple Join.
+ val unresolvedAttribute = proto.Expression
+ .newBuilder()
+ .setUnresolvedAttribute(
+ proto.Expression.UnresolvedAttribute.newBuilder().addAllParts(Seq("left").asJava).build())
+ .build()
+
+ val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
+ proto.Expression.UnresolvedFunction.newBuilder
+ .addAllParts(Seq("eq").asJava)
+ .addArguments(unresolvedAttribute)
+ .addArguments(unresolvedAttribute)
+ .build())
+
+ val simpleJoin = proto.Relation.newBuilder
+ .setJoin(
+ proto.Join.newBuilder.setLeft(readRel).setRight(readRel).setOn(joinCondition).build())
+ .build()
+
+ val res = transform(simpleJoin)
+ assert(res != null)
+ assert(res.nodeName == "Join")
+
+ }
+
+ test("Simple Projection") {
+ val project = proto.Project.newBuilder
+ .setInput(readRel)
+ .addExpressions(
+ proto.Expression.newBuilder
+ .setLiteral(proto.Expression.Literal.newBuilder.setI32(32))
+ .build())
+ .build()
+
+ val res = transform(proto.Relation.newBuilder.setProject(project).build())
+ assert(res.nodeName == "Project")
+
+ }
+
+ test("Simple Aggregation") {
+ val unresolvedAttribute = proto.Expression
+ .newBuilder()
+ .setUnresolvedAttribute(
+ proto.Expression.UnresolvedAttribute.newBuilder().addAllParts(Seq("left").asJava).build())
+ .build()
+
+ val agg = proto.Aggregate.newBuilder
+ .setInput(readRel)
+ .addAllMeasures(
+ Seq(
+ proto.Aggregate.Measure.newBuilder
+ .setFunction(proto.Aggregate.AggregateFunction.newBuilder
+ .setName("sum")
+ .addArguments(unresolvedAttribute))
+ .build()).asJava)
+ .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder
+ .addAggregateExpressions(unresolvedAttribute)
+ .build())
+ .build()
+
+ val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
+ assert(res.nodeName == "Aggregate")
+ }
+
+}
diff --git a/dev/.rat-excludes b/dev/.rat-excludes
index e1cc000c064f..23806bcd3069 100644
--- a/dev/.rat-excludes
+++ b/dev/.rat-excludes
@@ -139,3 +139,6 @@ exported_table/*
ansible-for-test-node/*
node_modules
spark-events-broken/*
+# Spark Connect related files with custom license
+any.proto
+empty.proto
diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile
index c6555e0463db..e20319eceefa 100644
--- a/dev/create-release/spark-rm/Dockerfile
+++ b/dev/create-release/spark-rm/Dockerfile
@@ -42,7 +42,7 @@ ARG APT_INSTALL="apt-get install --no-install-recommends -y"
# We should use the latest Sphinx version once this is fixed.
# TODO(SPARK-35375): Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375.
-ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17"
+ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.48.1 protobuf==4.21.5"
ARG GEM_PKGS="bundler:2.2.9"
# Install extra needed repos and refresh.
diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3
index bb9a8cad170a..4b04a4997d82 100644
--- a/dev/deps/spark-deps-hadoop-2-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-2-hive-2.3
@@ -6,7 +6,9 @@ ST4/4.0.4//ST4-4.0.4.jar
activation/1.1.1//activation-1.1.1.jar
aircompressor/0.21//aircompressor-0.21.jar
algebra_2.12/2.0.1//algebra_2.12-2.0.1.jar
+animal-sniffer-annotations/1.19//animal-sniffer-annotations-1.19.jar
annotations/17.0.0//annotations-17.0.0.jar
+annotations/4.1.1.4//annotations-4.1.1.4.jar
antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar
antlr4-runtime/4.8//antlr4-runtime-4.8.jar
aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar
@@ -63,10 +65,20 @@ datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar
datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar
derby/10.14.2.0//derby-10.14.2.0.jar
dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar
+error_prone_annotations/2.10.0//error_prone_annotations-2.10.0.jar
+failureaccess/1.0.1//failureaccess-1.0.1.jar
flatbuffers-java/1.12.0//flatbuffers-java-1.12.0.jar
gcs-connector/hadoop2-2.2.7/shaded/gcs-connector-hadoop2-2.2.7-shaded.jar
generex/1.0.2//generex-1.0.2.jar
gmetric4j/1.0.10//gmetric4j-1.0.10.jar
+grpc-api/1.47.0//grpc-api-1.47.0.jar
+grpc-context/1.47.0//grpc-context-1.47.0.jar
+grpc-core/1.47.0//grpc-core-1.47.0.jar
+grpc-netty-shaded/1.47.0//grpc-netty-shaded-1.47.0.jar
+grpc-protobuf-lite/1.47.0//grpc-protobuf-lite-1.47.0.jar
+grpc-protobuf/1.47.0//grpc-protobuf-1.47.0.jar
+grpc-services/1.47.0//grpc-services-1.47.0.jar
+grpc-stub/1.47.0//grpc-stub-1.47.0.jar
gson/2.2.4//gson-2.2.4.jar
guava/14.0.1//guava-14.0.1.jar
guice-servlet/3.0//guice-servlet-3.0.jar
@@ -112,6 +124,7 @@ httpclient/4.5.13//httpclient-4.5.13.jar
httpcore/4.4.14//httpcore-4.4.14.jar
istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar
ivy/2.5.0//ivy-2.5.0.jar
+j2objc-annotations/1.3//j2objc-annotations-1.3.jar
jackson-annotations/2.13.4//jackson-annotations-2.13.4.jar
jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar
jackson-core/2.13.4//jackson-core-2.13.4.jar
@@ -231,7 +244,10 @@ parquet-encoding/1.12.3//parquet-encoding-1.12.3.jar
parquet-format-structures/1.12.3//parquet-format-structures-1.12.3.jar
parquet-hadoop/1.12.3//parquet-hadoop-1.12.3.jar
parquet-jackson/1.12.3//parquet-jackson-1.12.3.jar
+perfmark-api/0.25.0//perfmark-api-0.25.0.jar
pickle/1.2//pickle-1.2.jar
+proto-google-common-protos/2.0.1//proto-google-common-protos-2.0.1.jar
+protobuf-java-util/3.19.2//protobuf-java-util-3.19.2.jar
protobuf-java/2.5.0//protobuf-java-2.5.0.jar
py4j/0.10.9.7//py4j-0.10.9.7.jar
remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 555579e6446f..be322eb78674 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -10,7 +10,9 @@ aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar
aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar
aliyun-java-sdk-ram/3.1.0//aliyun-java-sdk-ram-3.1.0.jar
aliyun-sdk-oss/3.13.0//aliyun-sdk-oss-3.13.0.jar
+animal-sniffer-annotations/1.19//animal-sniffer-annotations-1.19.jar
annotations/17.0.0//annotations-17.0.0.jar
+annotations/4.1.1.4//annotations-4.1.1.4.jar
antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar
antlr4-runtime/4.8//antlr4-runtime-4.8.jar
aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar
@@ -60,10 +62,20 @@ datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar
datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar
derby/10.14.2.0//derby-10.14.2.0.jar
dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar
+error_prone_annotations/2.10.0//error_prone_annotations-2.10.0.jar
+failureaccess/1.0.1//failureaccess-1.0.1.jar
flatbuffers-java/1.12.0//flatbuffers-java-1.12.0.jar
gcs-connector/hadoop3-2.2.7/shaded/gcs-connector-hadoop3-2.2.7-shaded.jar
generex/1.0.2//generex-1.0.2.jar
gmetric4j/1.0.10//gmetric4j-1.0.10.jar
+grpc-api/1.47.0//grpc-api-1.47.0.jar
+grpc-context/1.47.0//grpc-context-1.47.0.jar
+grpc-core/1.47.0//grpc-core-1.47.0.jar
+grpc-netty-shaded/1.47.0//grpc-netty-shaded-1.47.0.jar
+grpc-protobuf-lite/1.47.0//grpc-protobuf-lite-1.47.0.jar
+grpc-protobuf/1.47.0//grpc-protobuf-1.47.0.jar
+grpc-services/1.47.0//grpc-services-1.47.0.jar
+grpc-stub/1.47.0//grpc-stub-1.47.0.jar
gson/2.2.4//gson-2.2.4.jar
guava/14.0.1//guava-14.0.1.jar
hadoop-aliyun/3.3.4//hadoop-aliyun-3.3.4.jar
@@ -100,6 +112,7 @@ httpcore/4.4.14//httpcore-4.4.14.jar
ini4j/0.5.4//ini4j-0.5.4.jar
istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar
ivy/2.5.0//ivy-2.5.0.jar
+j2objc-annotations/1.3//j2objc-annotations-1.3.jar
jackson-annotations/2.13.4//jackson-annotations-2.13.4.jar
jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar
jackson-core/2.13.4//jackson-core-2.13.4.jar
@@ -218,7 +231,10 @@ parquet-encoding/1.12.3//parquet-encoding-1.12.3.jar
parquet-format-structures/1.12.3//parquet-format-structures-1.12.3.jar
parquet-hadoop/1.12.3//parquet-hadoop-1.12.3.jar
parquet-jackson/1.12.3//parquet-jackson-1.12.3.jar
+perfmark-api/0.25.0//perfmark-api-0.25.0.jar
pickle/1.2//pickle-1.2.jar
+proto-google-common-protos/2.0.1//proto-google-common-protos-2.0.1.jar
+protobuf-java-util/3.19.2//protobuf-java-util-3.19.2.jar
protobuf-java/2.5.0//protobuf-java-2.5.0.jar
py4j/0.10.9.7//py4j-0.10.9.7.jar
remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar
diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile
index 2bf0be3822db..1fc885edb520 100644
--- a/dev/infra/Dockerfile
+++ b/dev/infra/Dockerfile
@@ -65,3 +65,6 @@ RUN Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='ht
# See more in SPARK-39735
ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library"
+
+# Add Python deps for Spark Connect.
+RUN python3.9 -m pip install grpcio protobuf
diff --git a/dev/requirements.txt b/dev/requirements.txt
index 7771b97a7320..7803e4f0736b 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -45,3 +45,7 @@ PyGithub
# pandas API on Spark Code formatter.
black==22.6.0
+
+# Spark Connect
+grpcio==1.48.1
+protobuf==4.21.5
\ No newline at end of file
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 4cfb0639ed3a..fc34b4c61b0a 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -251,6 +251,7 @@ def build_spark_sbt(extra_profiles):
sbt_goals = [
"Test/package", # Build test jars as some tests depend on them
"streaming-kinesis-asl-assembly/assembly",
+ "connect/assembly", # Build Spark Connect assembly
]
profiles_and_goals = build_profiles + sbt_goals
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2b9d52693794..422d353ec044 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -472,6 +472,24 @@ def __hash__(self):
],
)
+pyspark_sql_connect = Module(
+ name="pyspark-sql-connect",
+ dependencies=[pyspark_core, hive, avro],
+ source_file_regexes=["python/pyspark/sql/connect"],
+ python_test_goals=[
+ # doctests
+ # No doctests yet.
+ # unittests
+ "pyspark.sql.tests.connect.test_column_expressions",
+ "pyspark.sql.tests.connect.test_plan_only",
+ "pyspark.sql.tests.connect.test_select_ops",
+ "pyspark.sql.tests.connect.test_spark_connect",
+ ],
+ excluded_python_implementations=[
+ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
+ # they aren't available there
+ ],
+)
pyspark_resource = Module(
name="pyspark-resource",
diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py
index 94928fa8730c..800b7e0d9325 100755
--- a/dev/sparktestsupport/utils.py
+++ b/dev/sparktestsupport/utils.py
@@ -109,22 +109,22 @@ def determine_modules_to_test(changed_modules, deduplicated=True):
>>> [x.name for x in determine_modules_to_test([modules.sql])]
... # doctest: +NORMALIZE_WHITESPACE
['sql', 'avro', 'docker-integration-tests', 'hive', 'mllib', 'sql-kafka-0-10', 'examples',
- 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr',
+ 'hive-thriftserver', 'pyspark-sql', 'pyspark-sql-connect', 'repl', 'sparkr',
'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-ml']
>>> sorted([x.name for x in determine_modules_to_test(
... [modules.sparkr, modules.sql], deduplicated=False)])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', 'mllib',
'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-sql',
- 'repl', 'sparkr', 'sql', 'sql-kafka-0-10']
+ 'pyspark-sql-connect', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10']
>>> sorted([x.name for x in determine_modules_to_test(
... [modules.sql, modules.core], deduplicated=False)])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'catalyst', 'core', 'docker-integration-tests', 'examples', 'graphx', 'hive',
'hive-thriftserver', 'mllib', 'mllib-local', 'pyspark-core', 'pyspark-ml', 'pyspark-mllib',
'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-resource', 'pyspark-sql',
- 'pyspark-streaming', 'repl', 'root', 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming',
- 'streaming-kafka-0-10', 'streaming-kinesis-asl']
+ 'pyspark-sql-connect', 'pyspark-streaming', 'repl', 'root', 'sparkr', 'sql',
+ 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', 'streaming-kinesis-asl']
"""
modules_to_test = set()
for module in changed_modules:
diff --git a/dev/tox.ini b/dev/tox.ini
index 464b9b959fa1..f44cbe54ddf6 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -51,4 +51,5 @@ exclude =
python/pyspark/worker.pyi,
python/pyspark/java_gateway.pyi,
dev/ansible-for-test-node/*,
+ python/pyspark/sql/connect/proto/*,
max-line-length = 100
diff --git a/pom.xml b/pom.xml
index 9f382692e536..5cac5e30bde8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -100,6 +100,7 @@
connector/kafka-0-10-assemblyconnector/kafka-0-10-sqlconnector/avro
+ connect
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 0c887e0e70ed..e7e24954eb16 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -39,6 +39,8 @@ import sbtassembly.AssemblyPlugin.autoImport._
import spray.revolver.RevolverPlugin._
+import sbtprotoc.ProtocPlugin.autoImport._
+
object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
@@ -50,12 +52,14 @@ object BuildCommons {
val streamingProjects@Seq(streaming, streamingKafka010) =
Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _))
+ val connect = ProjectRef(buildLocation, "connect")
+
val allProjects@Seq(
core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _*
) = Seq(
"core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
"tags", "sketch", "kvstore"
- ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
+ ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect)
val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
sparkGangliaLgpl, streamingKinesisAsl,
@@ -79,6 +83,11 @@ object BuildCommons {
val testTempDir = s"$sparkHome/target/tmp"
val javaVersion = settingKey[String]("source and target JVM version for javac and scalac")
+
+ // Google Protobuf version used to generate the protobuf code.
+ val protoVersion = "3.21.1"
+ // gRPC version used for Spark Connect.
+ val grpcVersion = "1.47.0"
}
object SparkBuild extends PomBuild {
@@ -357,7 +366,11 @@ object SparkBuild extends PomBuild {
// To prevent intermittent compilation failures, see also SPARK-33297
// Apparently we can remove this when we use JDK 11.
- Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat
+ Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat,
+
+ // Setting version for the protobuf compiler. This has to be propagated to every sub-project
+ // even if the project is not using it.
+ PB.protocVersion := protoVersion,
)
def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
@@ -377,7 +390,7 @@ object SparkBuild extends PomBuild {
val mimaProjects = allProjects.filterNot { x =>
Seq(
spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn,
- unsafe, tags, tokenProviderKafka010, sqlKafka010
+ unsafe, tags, tokenProviderKafka010, sqlKafka010, connect
).contains(x)
}
@@ -418,6 +431,8 @@ object SparkBuild extends PomBuild {
/* Hive console settings */
enable(Hive.settings)(hive)
+ enable(SparkConnect.settings)(connect)
+
// SPARK-14738 - Remove docker tests from main Spark build
// enable(DockerIntegrationTests.settings)(dockerIntegrationTests)
@@ -593,6 +608,60 @@ object Core {
)
}
+
+object SparkConnect {
+
+ import BuildCommons.protoVersion
+
+ private val shadePrefix = "org.sparkproject.connect"
+ val shadeJar = taskKey[Unit]("Shade the Jars")
+
+ lazy val settings = Seq(
+ // Setting version for the protobuf compiler. This has to be propagated to every sub-project
+ // even if the project is not using it.
+ PB.protocVersion := BuildCommons.protoVersion,
+
+ // For some reason the resolution from the imported Maven build does not work for some
+ // of these dependencies that we need to shade later on.
+ libraryDependencies ++= Seq(
+ "io.grpc" % "protoc-gen-grpc-java" % BuildCommons.gprcVersion asProtocPlugin(),
+ "org.scala-lang" % "scala-library" % "2.12.16" % "provided",
+ "com.google.guava" % "guava" % "31.0.1-jre",
+ "com.google.guava" % "failureaccess" % "1.0.1",
+ "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
+ ),
+
+ dependencyOverrides ++= Seq(
+ "com.google.guava" % "guava" % "31.0.1-jre",
+ "com.google.guava" % "failureaccess" % "1.0.1",
+ "com.google.protobuf" % "protobuf-java" % protoVersion
+ ),
+
+ (Compile / PB.targets) := Seq(
+ PB.gens.java -> (Compile / sourceManaged).value,
+ PB.gens.plugin("grpc-java") -> (Compile / sourceManaged).value
+ ),
+
+ (assembly / test) := false,
+
+ (assembly / logLevel) := Level.Info,
+
+ (assembly / assemblyShadeRules) := Seq(
+ ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@0").inAll,
+ ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll,
+ ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll,
+ ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll,
+ ),
+
+ (assembly / assemblyMergeStrategy) := {
+ case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard
+ // Drop all proto files that are not needed as artifacts of the build.
+ case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard
+ case _ => MergeStrategy.first
+ },
+ )
+}
+
object Unsafe {
lazy val settings = Seq(
// This option is needed to suppress warnings from sun.misc.Unsafe usage
@@ -741,6 +810,8 @@ object ExcludedDependencies {
*/
object OldDeps {
+ import BuildCommons.protoVersion
+
lazy val project = Project("oldDeps", file("dev"))
.settings(oldDepsSettings)
.disablePlugins(com.typesafe.sbt.pom.PomReaderPlugin)
@@ -753,6 +824,9 @@ object OldDeps {
}
def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq(
+ // Setting version for the protobuf compiler. This has to be propagated to every sub-project
+ // even if the project is not using it.
+ PB.protocVersion := protoVersion,
name := "old-deps",
libraryDependencies := allPreviousArtifactKeys.value.flatten
)
@@ -1033,10 +1107,10 @@ object Unidoc {
(ScalaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
- yarn, tags, streamingKafka010, sqlKafka010),
+ yarn, tags, streamingKafka010, sqlKafka010, connect),
(JavaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
- yarn, tags, streamingKafka010, sqlKafka010),
+ yarn, tags, streamingKafka010, sqlKafka010, connect),
(ScalaUnidoc / unidoc / unidocAllClasspaths) := {
ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value)
@@ -1118,14 +1192,25 @@ object CopyDependencies {
throw new IOException("Failed to create jars directory.")
}
+ // For the SparkConnect build, we manually call the assembly target to
+ // produce the shaded Jar which happens automatically in the case of Maven.
+ // Later, when the dependencies are copied, we manually copy the shaded Jar only.
+ val fid = (LocalProject("connect")/assembly).value
+
(Compile / dependencyClasspath).value.map(_.data)
.filter { jar => jar.isFile() }
+ // For the Spark Connect jar, copy the shaded assembly produced above instead of the unshaded SBT artifact.
.foreach { jar =>
val destJar = new File(dest, jar.getName())
if (destJar.isFile()) {
destJar.delete()
}
- Files.copy(jar.toPath(), destJar.toPath())
+ if (jar.getName.contains("spark-connect") &&
+ !SbtPomKeys.profiles.value.contains("noshade-connect")) {
+ Files.copy(fid.toPath, destJar.toPath)
+ } else {
+ Files.copy(jar.toPath(), destJar.toPath())
+ }
}
},
(Compile / packageBin / crossTarget) := destPath.value,
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 9f55e21dc9d2..19f9da7351dc 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -44,3 +44,5 @@ libraryDependencies += "org.ow2.asm" % "asm-commons" % "9.3"
addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3")
addSbtPlugin("com.typesafe.sbt" % "sbt-pom-reader" % "2.2.0")
+
+addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.1")
diff --git a/python/mypy.ini b/python/mypy.ini
index efaa3dc97d3c..4f015641af28 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -23,6 +23,18 @@ show_error_codes = True
warn_unused_ignores = True
warn_redundant_casts = True
+; TODO(SPARK-40537) re-enable mypy support.
+[mypy-pyspark.sql.connect.*]
+disallow_untyped_defs = False
+ignore_missing_imports = True
+ignore_errors = True
+
+; TODO(SPARK-40537) re-enable mypy support.
+[mypy-pyspark.sql.tests.connect.*]
+disallow_untyped_defs = False
+ignore_missing_imports = True
+ignore_errors = True
+
; Allow untyped def in internal modules and tests
[mypy-pyspark.daemon]
@@ -138,3 +150,10 @@ ignore_missing_imports = True
[mypy-tabulate.*]
ignore_missing_imports = True
+
+[mypy-google.protobuf.*]
+ignore_missing_imports = True
+
+; Ignore errors for proto generated code
+[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto]
+ignore_errors = True
diff --git a/python/pyspark/sql/connect/README.md b/python/pyspark/sql/connect/README.md
new file mode 100644
index 000000000000..e79e9aae9dd2
--- /dev/null
+++ b/python/pyspark/sql/connect/README.md
@@ -0,0 +1,38 @@
+
+# [EXPERIMENTAL] Spark Connect
+
+**Spark Connect is a strictly experimental feature that is under heavy development.
+All APIs should be considered volatile and should not be used in production.**
+
+This module contains the implementation of Spark Connect, a logical-plan facade
+for the implementation in Spark. Spark Connect is directly integrated into the Spark
+build. To enable it, you only need to activate the driver plugin for Spark Connect.
+
+
+
+
+## Build
+
+1. Build Spark as usual per the documentation.
+2. Build and package the Spark Connect package
+ ```bash
+ ./build/mvn -Phive package
+ ```
+ or
+ ```shell
+ ./build/sbt -Phive package
+ ```
+
+## Run Spark Shell
+
+```bash
+./bin/spark-shell --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
+```
+
+## Run Tests
+
+```bash
+./run-tests --testnames 'pyspark.sql.tests.connect.test_spark_connect'
+```
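+
+## Example (illustrative)
+
+A minimal sketch of connecting from Python, assuming the Spark Connect plugin is
+running locally on the default port 15002 (the user id and query are arbitrary):
+
+```python
+from pyspark.sql.connect.client import RemoteSparkSession
+
+session = RemoteSparkSession(user_id="example_user")
+pdf = session.sql("SELECT 1 AS value").toPandas()
+print(pdf)
+```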
+
diff --git a/python/pyspark/sql/connect/__init__.py b/python/pyspark/sql/connect/__init__.py
new file mode 100644
index 000000000000..c748f8f6590e
--- /dev/null
+++ b/python/pyspark/sql/connect/__init__.py
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Currently Spark Connect is very experimental and the APIs to interact with
+Spark through this API are can be changed at any time without warning."""
+
+
+from pyspark.sql.connect.data_frame import DataFrame # noqa: F401
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
new file mode 100644
index 000000000000..df42411d42cb
--- /dev/null
+++ b/python/pyspark/sql/connect/client.py
@@ -0,0 +1,180 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import io
+import logging
+import typing
+import uuid
+
+import grpc
+import pandas
+import pandas as pd
+import pyarrow as pa
+
+import pyspark.sql.connect.proto as pb2
+import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
+from pyspark import cloudpickle
+from pyspark.sql.connect.data_frame import DataFrame
+from pyspark.sql.connect.readwriter import DataFrameReader
+from pyspark.sql.connect.plan import SQL
+
+
+NumericType = typing.Union[int, float]
+
+logging.basicConfig(level=logging.INFO)
+
+
+class MetricValue:
+ def __init__(self, name: str, value: NumericType, type: str):
+ self._name = name
+ self._type = type
+ self._value = value
+
+ def __repr__(self) -> str:
+ return f"<{self._name}={self._value} ({self._type})>"
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def value(self) -> NumericType:
+ return self._value
+
+ @property
+ def metric_type(self) -> str:
+ return self._type
+
+
+class PlanMetrics:
+ def __init__(self, name: str, id: str, parent: str, metrics: typing.List[MetricValue]):
+ self._name = name
+ self._id = id
+ self._parent_id = parent
+ self._metrics = metrics
+
+ def __repr__(self) -> str:
+ return f"Plan({self._name})={self._metrics}"
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def plan_id(self) -> str:
+ return self._id
+
+ @property
+ def parent_plan_id(self) -> str:
+ return self._parent_id
+
+ @property
+ def metrics(self) -> typing.List[MetricValue]:
+ return self._metrics
+
+
+class AnalyzeResult:
+ def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str):
+ self.cols = cols
+ self.types = types
+ self.explain_string = explain
+
+ @classmethod
+ def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
+ return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string)
+
+
+class RemoteSparkSession(object):
+ """Conceptually the remote spark session that communicates with the server"""
+
+ def __init__(self, user_id: str, host: typing.Optional[str] = None, port: int = 15002):
+ self._host = "localhost" if host is None else host
+ self._port = port
+ self._user_id = user_id
+ self._channel = grpc.insecure_channel(f"{self._host}:{self._port}")
+ self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
+
+ # Create the reader
+ self.read = DataFrameReader(self)
+
+ def register_udf(self, function, return_type) -> str:
+ """Create a temporary UDF in the session catalog on the other side. We generate a
+ temporary name for it."""
+ name = f"fun_{uuid.uuid4().hex}"
+ fun = pb2.CreateScalarFunction()
+ fun.parts.append(name)
+ fun.serialized_function = cloudpickle.dumps((function, return_type))
+
+ req = pb2.Request()
+ req.user_context.user_id = self._user_id
+ req.plan.command.create_function.CopyFrom(fun)
+
+ self._execute_and_fetch(req)
+ return name
+
+ def _build_metrics(self, metrics: "pb2.Response.Metrics") -> typing.List[PlanMetrics]:
+ return [
+ PlanMetrics(
+ x.name,
+ x.plan_id,
+ x.parent,
+ [MetricValue(k, v.value, v.metric_type) for k, v in x.execution_metrics.items()],
+ )
+ for x in metrics.metrics
+ ]
+
+ def sql(self, sql_string: str) -> "DataFrame":
+ return DataFrame.withPlan(SQL(sql_string), self)
+
+ def collect(self, plan: pb2.Plan) -> pandas.DataFrame:
+ req = pb2.Request()
+ req.user_context.user_id = self._user_id
+ req.plan.CopyFrom(plan)
+ return self._execute_and_fetch(req)
+
+ def analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+ req = pb2.Request()
+ req.user_context.user_id = self._user_id
+ req.plan.CopyFrom(plan)
+
+ resp = self._stub.AnalyzePlan(req)
+ return AnalyzeResult.fromProto(resp)
+
+ def _process_batch(self, b) -> typing.Optional[pandas.DataFrame]:
+ if b.batch is not None and len(b.batch.data) > 0:
+ # Arrow IPC stream batch.
+ with pa.ipc.open_stream(b.batch.data) as rd:
+ return rd.read_pandas()
+ elif b.csv_batch is not None and len(b.csv_batch.data) > 0:
+ # Pipe-delimited CSV batch (see CSVBatch in the proto definition).
+ return pd.read_csv(io.StringIO(b.csv_batch.data), delimiter="|")
+ return None
+
+ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]:
+ m = None
+ result_dfs = []
+
+ for b in self._stub.ExecutePlan(req):
+ if b.metrics is not None:
+ m = b.metrics
+ result_dfs.append(self._process_batch(b))
+
+ if len(result_dfs) > 0:
+ df = pd.concat(result_dfs)
+ # Attach the metrics to the DataFrame attributes.
+ df.attrs["metrics"] = self._build_metrics(m)
+ return df
+ else:
+ return None
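+
+# Illustrative usage sketch (not part of the module): assuming a Spark Connect
+# server is reachable on localhost:15002, the collected pandas DataFrame carries
+# per-plan-node execution metrics in its `attrs`, as attached above:
+#
+#   session = RemoteSparkSession(user_id="example_user")
+#   pdf = session.sql("SELECT 1 AS value").toPandas()
+#   pdf.attrs["metrics"]  # list of PlanMetrics built by _build_metrics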
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
new file mode 100644
index 000000000000..391771f14fc9
--- /dev/null
+++ b/python/pyspark/sql/connect/column.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Union, cast, get_args, TYPE_CHECKING
+
+import pyspark.sql.connect.proto as proto
+
+PrimitiveType = Union[str, int, bool, float]
+ExpressionOrString = Union[str, "Expression"]
+ColumnOrString = Union[str, "ColumnRef"]
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect.client import RemoteSparkSession
+ import pyspark.sql.connect.proto as proto
+
+
+class Expression(object):
+ """
+ Expression base class.
+ """
+
+ def __init__(self) -> None: # type: ignore[name-defined]
+ pass
+
+ def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": # type: ignore
+ ...
+
+ def __str__(self) -> str:
+ ...
+
+
+class LiteralExpression(Expression):
+ """A literal expression.
+
+ The Python types are converted best effort into the relevant proto types. On the Spark Connect
+ server side, the proto types are converted to the Catalyst equivalents."""
+
+ def __init__(self, value: PrimitiveType) -> None: # type: ignore[name-defined]
+ super().__init__()
+ self._value = value
+
+ def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+ """Converts the literal expression to the literal in proto.
+
+ TODO(SPARK-40533) This method always assumes the largest type and can thus
+ create weird interpretations of the literal."""
+ value_type = type(self._value)
+ exp = proto.Expression()
+ if value_type is int:
+ exp.literal.i32 = cast(int, self._value)
+ elif value_type is str:
+ exp.literal.string = cast(str, self._value)
+ elif value_type is float:
+ exp.literal.fp64 = cast(float, self._value)
+ else:
+ raise ValueError(f"Could not convert literal for type {type(self._value)}")
+
+ return exp
+
+ def __str__(self) -> str:
+ return f"Literal({self._value})"
+
+
+def _bin_op(name: str, doc: str = "binary function", reverse=False):
+ def _(self: "ColumnRef", other) -> Expression:
+ if isinstance(other, get_args(PrimitiveType)):
+ other = LiteralExpression(other)
+ if not reverse:
+ return ScalarFunctionExpression(name, self, other)
+ else:
+ return ScalarFunctionExpression(name, other, self)
+
+ return _
+
+
+class ColumnRef(Expression):
+ """Represents a column reference. There is no guarantee that this column
+ actually exists. In the context of this project, we refer by its name and
+ treat it as an unresolved attribute. Attributes that have the same fully
+ qualified name are identical"""
+
+ @classmethod
+ def from_qualified_name(cls, name) -> "ColumnRef":
+ return ColumnRef(*name.split("."))
+
+ def __init__(self, *parts: str) -> None: # type: ignore[name-defined]
+ super().__init__()
+ self._parts: List[str] = list(filter(lambda x: x is not None, list(parts)))
+
+ def name(self) -> str:
+ """Returns the qualified name of the column reference."""
+ return ".".join(self._parts)
+
+ __gt__ = _bin_op("gt")
+ __lt__ = _bin_op("lt")
+ __add__ = _bin_op("plus")
+ __sub__ = _bin_op("minus")
+ __mul__ = _bin_op("multiply")
+ __div__ = _bin_op("divide")
+ __truediv__ = _bin_op("divide")
+ __mod__ = _bin_op("modulo")
+ __radd__ = _bin_op("plus", reverse=True)
+ __rsub__ = _bin_op("minus", reverse=True)
+ __rmul__ = _bin_op("multiply", reverse=True)
+ __rdiv__ = _bin_op("divide", reverse=True)
+ __rtruediv__ = _bin_op("divide", reverse=True)
+ __pow__ = _bin_op("pow")
+ __rpow__ = _bin_op("pow", reverse=True)
+ __ge__ = _bin_op("greterEquals")
+ __le__ = _bin_op("lessEquals")
+
+ def __eq__(self, other) -> Expression: # type: ignore[override]
+ """Returns a binary expression with the current column as the left
+ side and the other expression as the right side.
+ """
+ if isinstance(other, get_args(PrimitiveType)):
+ other = LiteralExpression(other)
+ return ScalarFunctionExpression("eq", self, other)
+
+ def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+ """Returns the Proto representation of the expression."""
+ expr = proto.Expression()
+ expr.unresolved_attribute.parts.extend(self._parts)
+ return expr
+
+ def desc(self):
+ return SortOrder(self, ascending=False)
+
+ def asc(self):
+ return SortOrder(self, ascending=True)
+
+ def __str__(self) -> str:
+ return f"Column({'.'.join(self._parts)})"
+
+
+class SortOrder(Expression):
+ def __init__(self, col: ColumnRef, ascending=True, nullsLast=True) -> None:
+ super().__init__()
+ self.ref = col
+ self.ascending = ascending
+ self.nullsLast = nullsLast
+
+ def __str__(self) -> str:
+ return str(self.ref) + (" ASC" if self.ascending else " DESC")
+
+ def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+ return self.ref.to_plan(session)
+
+
+class ScalarFunctionExpression(Expression):
+ def __init__(
+ self,
+ op: str,
+ *args: Expression,
+ ) -> None:
+ super().__init__()
+ self._args = args
+ self._op = op
+
+ def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+ fun = proto.Expression()
+ fun.unresolved_function.parts.append(self._op)
+ fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
+ return fun
+
+ def __str__(self) -> str:
+ return f"({self._op} ({', '.join([str(x) for x in self._args])}))"
diff --git a/python/pyspark/sql/connect/data_frame.py b/python/pyspark/sql/connect/data_frame.py
new file mode 100644
index 000000000000..b229cad19803
--- /dev/null
+++ b/python/pyspark/sql/connect/data_frame.py
@@ -0,0 +1,236 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import (
+ Any,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+ TYPE_CHECKING,
+)
+
+import pyspark.sql.connect.plan as plan
+from pyspark.sql.connect.column import (
+ ColumnOrString,
+ ColumnRef,
+ Expression,
+ ExpressionOrString,
+ LiteralExpression,
+)
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect.client import RemoteSparkSession
+
+
+ColumnOrName = Union[ColumnRef, str]
+
+
+class GroupingFrame(object):
+
+ MeasuresType = Union[Sequence[Tuple[ExpressionOrString, str]], Dict[str, str]]
+ OptMeasuresType = Optional[MeasuresType]
+
+ def __init__(self, df: "DataFrame", *grouping_cols: Union[ColumnRef, str]) -> None:
+ self._df = df
+ self._grouping_cols = [x if isinstance(x, ColumnRef) else df[x] for x in grouping_cols]
+
+ def agg(self, exprs: OptMeasuresType = None) -> "DataFrame":
+
+ # Normalize the dictionary into a list of tuples.
+ if isinstance(exprs, Dict):
+ measures = list(exprs.items())
+ elif isinstance(exprs, List):
+ measures = exprs
+ else:
+ measures = []
+
+ res = DataFrame.withPlan(
+ plan.Aggregate(
+ child=self._df._plan,
+ grouping_cols=self._grouping_cols,
+ measures=measures,
+ ),
+ session=self._df._session,
+ )
+ return res
+
+ def _map_cols_to_dict(self, fun: str, cols: List[Union[ColumnRef, str]]) -> Dict[str, str]:
+ return {x if isinstance(x, str) else cast(ColumnRef, x).name(): fun for x in cols}
+
+ def min(self, *cols: Union[ColumnRef, str]) -> "DataFrame":
+ expr = self._map_cols_to_dict("min", list(cols))
+ return self.agg(expr)
+
+ def max(self, *cols: Union[ColumnRef, str]) -> "DataFrame":
+ expr = self._map_cols_to_dict("max", list(cols))
+ return self.agg(expr)
+
+ def sum(self, *cols: Union[ColumnRef, str]) -> "DataFrame":
+ expr = self._map_cols_to_dict("sum", list(cols))
+ return self.agg(expr)
+
+ def count(self) -> "DataFrame":
+ return self.agg([(LiteralExpression(1), "count")])
+
+
+class DataFrame(object):
+ """Every DataFrame object essentially is a Relation that is refined using the
+ member functions. Calling a method on a dataframe will essentially return a copy
+ of the DataFrame with the changes applied.
+ """
+
+ def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None):
+ """Creates a new data frame"""
+ self._schema = schema
+ self._plan: Optional[plan.LogicalPlan] = None
+ self._cache: Dict[str, Any] = {}
+ self._session: "RemoteSparkSession" = None
+
+ @classmethod
+ def withPlan(cls, plan: plan.LogicalPlan, session=None) -> "DataFrame":
+ """Main initialization method used to construct a new data frame with a child plan."""
+ new_frame = DataFrame()
+ new_frame._plan = plan
+ new_frame._session = session
+ return new_frame
+
+ def select(self, *cols: ColumnRef) -> "DataFrame":
+ return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session)
+
+ def agg(self, exprs: Dict[str, str]) -> "DataFrame":
+ return self.groupBy().agg(exprs)
+
+ def alias(self, alias):
+ return DataFrame.withPlan(plan.Project(self._plan).withAlias(alias), session=self._session)
+
+ def approxQuantile(self, col, probabilities, relativeError):
+ ...
+
+ def colRegex(self, regex) -> "DataFrame":
+ ...
+
+ @property
+ def columns(self) -> List[str]:
+ """Returns the list of columns of the current data frame."""
+ if self._plan is None:
+ return []
+ if "columns" not in self._cache and self._plan is not None:
+ pdd = self.limit(0).collect()
+ # Cache the column names reported for an empty result.
+ self._cache["columns"] = pdd.columns.values
+ return self._cache["columns"]
+
+ def count(self):
+ """Returns the number of rows in the data frame"""
+ return self.agg([(LiteralExpression(1), "count")]).collect().iloc[0, 0]
+
+ def crossJoin(self, other):
+ ...
+
+ def coalesce(self, num_partitions: int) -> "DataFrame":
+ ...
+
+ def describe(self, cols):
+ ...
+
+ def distinct(self) -> "DataFrame":
+ """Returns all distinct rows."""
+ all_cols = self.columns
+ gf = self.groupBy(*all_cols)
+ return gf.agg()
+
+ def drop(self, *cols: ColumnOrString):
+ all_cols = self.columns
+ dropped = set([c.name() if isinstance(c, ColumnRef) else self[c].name() for c in cols])
+ # Keep only the columns that were not dropped.
+ remaining = [c for c in all_cols if c not in dropped]
+ return self.select(*[self[c] for c in remaining])
+
+ def filter(self, condition: Expression) -> "DataFrame":
+ return DataFrame.withPlan(
+ plan.Filter(child=self._plan, filter=condition), session=self._session
+ )
+
+ def first(self):
+ return self.head(1)
+
+ def groupBy(self, *cols: ColumnOrString):
+ return GroupingFrame(self, *cols)
+
+ def head(self, n: int):
+ return self.limit(n).collect()
+
+ def join(self, other, on, how=None):
+ return DataFrame.withPlan(
+ plan.Join(left=self._plan, right=other._plan, on=on, how=how),
+ session=self._session,
+ )
+
+ def limit(self, n):
+ return DataFrame.withPlan(plan.Limit(child=self._plan, limit=n), session=self._session)
+
+ def sort(self, *cols: ColumnOrName):
+ """Sort by a specific column"""
+ return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session)
+
+ def show(self, n: int, truncate: Optional[Union[bool, int]], vertical: Optional[bool]):
+ ...
+
+ def union(self, other) -> "DataFrame":
+ return self.unionAll(other)
+
+ def unionAll(self, other: "DataFrame") -> "DataFrame":
+ if other._plan is None:
+ raise ValueError("Argument to Union does not contain a valid plan.")
+ return DataFrame.withPlan(plan.UnionAll(self._plan, other._plan), session=self._session)
+
+ def where(self, condition):
+ return self.filter(condition)
+
+ def _get_alias(self):
+ p = self._plan
+ while p is not None:
+ if isinstance(p, plan.Project) and p.alias:
+ return p.alias
+ p = p._child
+
+ def __getattr__(self, name) -> "ColumnRef":
+ return self[name]
+
+ def __getitem__(self, name) -> "ColumnRef":
+ # Check for alias
+ alias = self._get_alias()
+ return ColumnRef(alias, name)
+
+ def _print_plan(self) -> str:
+ if self._plan:
+ return self._plan.print()
+ return ""
+
+ def collect(self):
+ query = self._plan.collect(self._session)
+ return self._session.collect(query)
+
+ def toPandas(self):
+ return self.collect()
+
+ def explain(self) -> str:
+ query = self._plan.collect(self._session)
+ return self._session.analyze(query).explain_string
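+
+# Illustrative sketch: DataFrame methods only compose a logical plan; nothing is
+# sent to the server until collect()/toPandas()/explain() is called. For example,
+# assuming a RemoteSparkSession named `session` and a table named `people`:
+#   df = DataFrame.withPlan(plan.Read("people"), session)
+#   df.filter(df.age > 21).limit(10).toPandas()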
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
new file mode 100644
index 000000000000..9f4e457030c5
--- /dev/null
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+from typing import TYPE_CHECKING
+
+import pyspark.sql.types
+from pyspark.sql.connect.column import (
+ ColumnOrString,
+ ColumnRef,
+ Expression,
+ ExpressionOrString,
+ ScalarFunctionExpression,
+)
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect.client import RemoteSparkSession
+
+
+def _build(name: str, *args: ExpressionOrString) -> ScalarFunctionExpression:
+ """
+ Simple wrapper function that converts the arguments into the appropriate types.
+
+ Parameters
+ ----------
+ name : str
+ Name of the function to be called.
+ args
+ The list of arguments.
+
+ Returns
+ -------
+ :class:`ScalarFunctionExpression`
+ """
+ cols = [x if isinstance(x, Expression) else ColumnRef.from_qualified_name(x) for x in args]
+ return ScalarFunctionExpression(name, *cols)
+
+
+class FunctionBuilder:
+ """This class is used to build arbitrary functions used in expressions"""
+
+ def __getattr__(self, name):
+ def _(*args: ExpressionOrString) -> ScalarFunctionExpression:
+ return _build(name, *args)
+
+ _.__doc__ = f"""Function to apply {name}"""
+ return _
+
+
+functions = FunctionBuilder()
+
+
+class UserDefinedFunction(Expression):
+ """A user defied function is an expresison that has a reference to the actual
+ Python callable attached. During plan generation, the client sends a command to
+ the server to register the UDF before execution. The expression object can be
+ reused and is not attached to a specific execution. If the internal name of
+ the temporary function is set, it is assumed that the registration has already
+ happened."""
+
+ def __init__(self, func, return_type=pyspark.sql.types.StringType(), args=None):
+ super().__init__()
+
+ self._func_ref = func
+ self._return_type = return_type
+ self._args = list(args) if args is not None else []
+ self._func_name = None
+
+ def to_plan(self, session: "RemoteSparkSession") -> Expression:
+ # Needs to materialize the UDF to the server
+ # Only do this once per session
+ func_name = session.register_udf(self._func_ref, self._return_type)
+ # Func name is used for the actual reference
+ return _build(func_name, *self._args).to_plan(session)
+
+ def __str__(self):
+ return f"UserDefinedFunction({self._func_name})"
+
+
+def _create_udf(function, return_type):
+ def wrapper(*cols: "ColumnOrString"):
+ return UserDefinedFunction(func=function, return_type=return_type, args=cols)
+
+ return wrapper
+
+
+def udf(function, return_type=pyspark.sql.types.StringType()):
+ """
+ Returns a callable that represents the column once arguments are applied.
+
+ Parameters
+ ----------
+ function
+ The Python callable to wrap, or a return type when used as a decorator factory.
+ return_type
+ The Spark SQL data type returned by the function (defaults to ``StringType``).
+
+ Returns
+ -------
+ A callable that, applied to columns, yields a :class:`UserDefinedFunction` expression.
+ """
+ # This is when @udf / @udf(DataType()) is used
+ if function is None or isinstance(function, (str, pyspark.sql.types.DataType)):
+ return_type = function or return_type
+ # Overload without an explicit return type falls back to StringType.
+ if return_type is None:
+ return_type = pyspark.sql.types.StringType()
+ return functools.partial(_create_udf, return_type=return_type)
+ else:
+ return _create_udf(function, return_type)
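+
+# Illustrative sketch: the decorator form wraps a Python callable so it can be used
+# inside expressions; the UDF itself is registered lazily via to_plan(). For example,
+# assuming a connect DataFrame `df` with a column `name`:
+#   @udf(pyspark.sql.types.StringType())
+#   def to_upper(x):
+#       return x.upper()
+#
+#   df.select(to_upper(df.name))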
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
new file mode 100644
index 000000000000..4fe57d922837
--- /dev/null
+++ b/python/pyspark/sql/connect/functions.py
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.connect.column import ColumnRef, LiteralExpression
+from pyspark.sql.connect.column import PrimitiveType
+
+# TODO(SPARK-40538) Add support for the missing PySpark functions.
+
+
+def col(x: str) -> ColumnRef:
+ return ColumnRef(x)
+
+
+def lit(x: PrimitiveType) -> LiteralExpression:
+ return LiteralExpression(x)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
new file mode 100644
index 000000000000..d404236610d5
--- /dev/null
+++ b/python/pyspark/sql/connect/plan.py
@@ -0,0 +1,464 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import (
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+ TYPE_CHECKING,
+)
+
+import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.column import (
+ ColumnOrString,
+ ColumnRef,
+ Expression,
+ ExpressionOrString,
+ SortOrder,
+)
+
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect.client import RemoteSparkSession
+
+
+class InputValidationError(Exception):
+ pass
+
+
+class LogicalPlan(object):
+
+ INDENT = 2
+
+ def __init__(self, child: Optional["LogicalPlan"]) -> None:
+ self._child = child
+
+ def unresolved_attr(self, *colNames: str) -> proto.Expression:
+ """Creates an unresolved attribute from a column name."""
+ exp = proto.Expression()
+ exp.unresolved_attribute.parts.extend(list(colNames))
+ return exp
+
+ def to_attr_or_expression(
+ self, col: ColumnOrString, session: "RemoteSparkSession"
+ ) -> proto.Expression:
+ """Returns either an instance of an unresolved attribute or the serialized
+ expression value of the column."""
+ if type(col) is str:
+ return self.unresolved_attr(cast(str, col))
+ else:
+ return cast(ColumnRef, col).to_plan(session)
+
+ def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+ ...
+
+ def _verify(self, session: "RemoteSparkSession") -> bool:
+ """This method is used to verify that the current logical plan
+ can be serialized to Proto and back and afterwards is identical."""
+ plan = proto.Plan()
+ plan.root.CopyFrom(self.plan(session))
+
+ serialized_plan = plan.SerializeToString()
+ test_plan = proto.Plan()
+ test_plan.ParseFromString(serialized_plan)
+
+ return test_plan == plan
+
+ def collect(self, session: "RemoteSparkSession" = None, debug: bool = False):
+ plan = proto.Plan()
+ plan.root.CopyFrom(self.plan(session))
+
+ if debug:
+ print(plan)
+
+ return plan
+
+ def print(self, indent=0) -> str:
+ ...
+
+ def _repr_html_(self):
+ ...
+
+
+class Read(LogicalPlan):
+ def __init__(self, table_name: str) -> None:
+ super().__init__(None)
+ self.table_name = table_name
+
+ def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+ plan = proto.Relation()
+ plan.read.named_table.parts.extend(self.table_name.split("."))
+ return plan
+
+ def print(self, indent=0) -> str:
+ return f"{' ' * indent}<Read table_name={self.table_name}>\n"
+
+ def _repr_html_(self):
+ return f"""
+ <ul>
+ <li>
+ <b>Read</b><br />
+ table name: {self.table_name}
+ </li>
+ </ul>
+ """
+
+
+class Project(LogicalPlan):
+ """Logical plan object for a projection.
+
+ All input arguments are directly serialized into the corresponding protocol buffer
+ objects. This class only provides very limited error handling and input validation.
+
+ To be compatible with PySpark, we validate that all input arguments are
+ expressions so that they can be serialized and sent to the server.
+
+ """
+
+ def __init__(self, child: Optional["LogicalPlan"], *columns: ExpressionOrString) -> None:
+ super().__init__(child)
+ self._raw_columns = list(columns)
+ self.alias = None
+ self._verify_expressions()
+
+ def _verify_expressions(self):
+ """Ensures that all input arguments are instances of Expression."""
+ for c in self._raw_columns:
+ if not isinstance(c, Expression):
+ raise InputValidationError(f"Only Expressions can be used for projections: '{c}'.")
+
+ def withAlias(self, alias) -> LogicalPlan:
+ self.alias = alias
+ return self
+
+ def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+ assert self._child is not None
+ proj_exprs = [
+ c.to_plan(session)
+ if isinstance(c, Expression)
+ else self.unresolved_attr(*cast(str, c).split("."))
+ for c in self._raw_columns
+ ]
+ common = proto.RelationCommon()
+ if self.alias is not None:
+ common.alias = self.alias
+
+ plan = proto.Relation()
+ plan.project.input.CopyFrom(self._child.plan(session))
+ plan.project.expressions.extend(proj_exprs)
+ plan.common.CopyFrom(common)
+ return plan
+
+ def print(self, indent=0) -> str:
+ c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+ return f"{' ' * indent}<Project cols={self._raw_columns}>\n{c_buf}"
+
+ def _repr_html_(self):
+ return f"""
+ <ul>
+ <li>
+ <b>Project</b><br />
+ Columns: {",".join([str(c) for c in self._raw_columns])}
+ {self._child._repr_html_()}
+ </li>
+ </ul>
+ """
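As a rough illustration of how these plan objects compose (a sketch only, mirroring the plan-only tests added later in this change; the table and column names are illustrative):

    from pyspark.sql.connect.functions import col
    from pyspark.sql.connect.plan import Project, Read

    # Read -> Project, built and inspected entirely on the client.
    plan = Project(Read("db.my_table"), col("id"), col("name"))
    print(plan.print())        # indented textual tree of the plan
    assert plan._verify(None)  # proto serialization round-trips unchanged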
diff --git a/python/pyspark/sql/connect/proto/__init__.py b/python/pyspark/sql/connect/proto/__init__.py
new file mode 100644
index 000000000000..f00b1a74c1d8
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/__init__.py
@@ -0,0 +1,23 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.connect.proto.base_pb2_grpc import *
+from pyspark.sql.connect.proto.base_pb2 import *
+from pyspark.sql.connect.proto.types_pb2 import *
+from pyspark.sql.connect.proto.commands_pb2 import *
+from pyspark.sql.connect.proto.expressions_pb2 import *
+from pyspark.sql.connect.proto.relations_pb2 import *
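The wildcard re-exports above make the generated messages available under a single namespace; as a small hand-rolled sketch, the same named-table read that `plan.Read` produces can be assembled directly (the table name is illustrative):

    import pyspark.sql.connect.proto as proto

    # Assemble a Plan whose root relation reads the table `db.my_table`.
    rel = proto.Relation()
    rel.read.named_table.parts.extend(["db", "my_table"])

    plan = proto.Plan()
    plan.root.CopyFrom(rel)
    print(plan)  # text rendering of the protobuf message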
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
new file mode 100644
index 000000000000..3adb77f77d60
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: spark/connect/base.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from pyspark.sql.connect.proto import (
+ commands_pb2 as spark_dot_connect_dot_commands__pb2,
+)
+from pyspark.sql.connect.proto import (
+ relations_pb2 as spark_dot_connect_dot_relations__pb2,
+)
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xdb\x01\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x43\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName"\xc4\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12?\n\tcsv_batch\x18\x03 \x01(\x0b\x32 .spark.connect.Response.CSVBatchH\x00R\x08\x63svBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a;\n\x08\x43SVBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\tR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42M\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3'
+)
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.base_pb2", globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = (
+ b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto"
+ )
+ _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
+ _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
+ _PLAN._serialized_start = 104
+ _PLAN._serialized_end = 220
+ _REQUEST._serialized_start = 223
+ _REQUEST._serialized_end = 442
+ _REQUEST_USERCONTEXT._serialized_start = 375
+ _REQUEST_USERCONTEXT._serialized_end = 442
+ _RESPONSE._serialized_start = 445
+ _RESPONSE._serialized_end = 1409
+ _RESPONSE_ARROWBATCH._serialized_start = 671
+ _RESPONSE_ARROWBATCH._serialized_end = 846
+ _RESPONSE_CSVBATCH._serialized_start = 848
+ _RESPONSE_CSVBATCH._serialized_end = 907
+ _RESPONSE_METRICS._serialized_start = 910
+ _RESPONSE_METRICS._serialized_end = 1394
+ _RESPONSE_METRICS_METRICOBJECT._serialized_start = 994
+ _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1304
+ _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1192
+ _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1304
+ _RESPONSE_METRICS_METRICVALUE._serialized_start = 1306
+ _RESPONSE_METRICS_METRICVALUE._serialized_end = 1394
+ _ANALYZERESPONSE._serialized_start = 1412
+ _ANALYZERESPONSE._serialized_end = 1567
+ _SPARKCONNECTSERVICE._serialized_start = 1570
+ _SPARKCONNECTSERVICE._serialized_end = 1732
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
new file mode 100644
index 000000000000..77307603f6eb
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
+
+
+class SparkConnectServiceStub(object):
+ """Main interface for the SparkConnect service."""
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.ExecutePlan = channel.unary_stream(
+ "/spark.connect.SparkConnectService/ExecutePlan",
+ request_serializer=spark_dot_connect_dot_base__pb2.Request.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.Response.FromString,
+ )
+ self.AnalyzePlan = channel.unary_unary(
+ "/spark.connect.SparkConnectService/AnalyzePlan",
+ request_serializer=spark_dot_connect_dot_base__pb2.Request.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzeResponse.FromString,
+ )
+
+
+class SparkConnectServiceServicer(object):
+ """Main interface for the SparkConnect service."""
+
+ def ExecutePlan(self, request, context):
+ """Executes a request that contains the query and returns a stream of [[Response]]."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def AnalyzePlan(self, request, context):
+ """Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+
+def add_SparkConnectServiceServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ "ExecutePlan": grpc.unary_stream_rpc_method_handler(
+ servicer.ExecutePlan,
+ request_deserializer=spark_dot_connect_dot_base__pb2.Request.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.Response.SerializeToString,
+ ),
+ "AnalyzePlan": grpc.unary_unary_rpc_method_handler(
+ servicer.AnalyzePlan,
+ request_deserializer=spark_dot_connect_dot_base__pb2.Request.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.AnalyzeResponse.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ "spark.connect.SparkConnectService", rpc_method_handlers
+ )
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+# This class is part of an EXPERIMENTAL API.
+class SparkConnectService(object):
+ """Main interface for the SparkConnect service."""
+
+ @staticmethod
+ def ExecutePlan(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_stream(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/ExecutePlan",
+ spark_dot_connect_dot_base__pb2.Request.SerializeToString,
+ spark_dot_connect_dot_base__pb2.Response.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def AnalyzePlan(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/AnalyzePlan",
+ spark_dot_connect_dot_base__pb2.Request.SerializeToString,
+ spark_dot_connect_dot_base__pb2.AnalyzeResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
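Purely as an illustrative sketch of the generated stub, not part of this change: the endpoint address and request field values below are assumptions (the server side of this patch listens on a configurable port), so treat them as placeholders.

    import grpc

    import pyspark.sql.connect.proto as proto
    from pyspark.sql.connect.proto import SparkConnectServiceStub

    # Placeholder endpoint; adjust to wherever the SparkConnectService is running.
    channel = grpc.insecure_channel("localhost:15002")
    stub = SparkConnectServiceStub(channel)

    request = proto.Request()
    request.user_context.user_id = "test_user"
    request.plan.root.read.named_table.parts.extend(["my_table"])

    # ExecutePlan is server-streaming: iterate over the returned Response messages.
    for response in stub.ExecutePlan(request):
        print(response.client_id)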
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
new file mode 100644
index 000000000000..f5bb6ad5628a
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: spark/connect/commands.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"i\n\x07\x43ommand\x12N\n\x0f\x63reate_function\x18\x01 \x01(\x0b\x32#.spark.connect.CreateScalarFunctionH\x00R\x0e\x63reateFunctionB\x0e\n\x0c\x63ommand_type"\x8f\x04\n\x14\x43reateScalarFunction\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x12P\n\x08language\x18\x02 \x01(\x0e\x32\x34.spark.connect.CreateScalarFunction.FunctionLanguageR\x08language\x12\x1c\n\ttemporary\x18\x03 \x01(\x08R\ttemporary\x12:\n\x0e\x61rgument_types\x18\x04 \x03(\x0b\x32\x13.spark.connect.TypeR\rargumentTypes\x12\x34\n\x0breturn_type\x18\x05 \x01(\x0b\x32\x13.spark.connect.TypeR\nreturnType\x12\x31\n\x13serialized_function\x18\x06 \x01(\x0cH\x00R\x12serializedFunction\x12\'\n\x0eliteral_string\x18\x07 \x01(\tH\x00R\rliteralString"\x8b\x01\n\x10\x46unctionLanguage\x12!\n\x1d\x46UNCTION_LANGUAGE_UNSPECIFIED\x10\x00\x12\x19\n\x15\x46UNCTION_LANGUAGE_SQL\x10\x01\x12\x1c\n\x18\x46UNCTION_LANGUAGE_PYTHON\x10\x02\x12\x1b\n\x17\x46UNCTION_LANGUAGE_SCALA\x10\x03\x42\x15\n\x13\x66unction_definitionBM\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3'
+)
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.commands_pb2", globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = (
+ b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto"
+ )
+ _COMMAND._serialized_start = 74
+ _COMMAND._serialized_end = 179
+ _CREATESCALARFUNCTION._serialized_start = 182
+ _CREATESCALARFUNCTION._serialized_end = 709
+ _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_start = 547
+ _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_end = 686
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
new file mode 100644
index 000000000000..16f3325d4146
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: spark/connect/expressions.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
+from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xd8\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x1a\x97\x10\n\x07Literal\x12\x1a\n\x07\x62oolean\x18\x01 \x01(\x08H\x00R\x07\x62oolean\x12\x10\n\x02i8\x18\x02 \x01(\x05H\x00R\x02i8\x12\x12\n\x03i16\x18\x03 \x01(\x05H\x00R\x03i16\x12\x12\n\x03i32\x18\x05 \x01(\x05H\x00R\x03i32\x12\x12\n\x03i64\x18\x07 \x01(\x03H\x00R\x03i64\x12\x14\n\x04\x66p32\x18\n \x01(\x02H\x00R\x04\x66p32\x12\x14\n\x04\x66p64\x18\x0b \x01(\x01H\x00R\x04\x66p64\x12\x18\n\x06string\x18\x0c \x01(\tH\x00R\x06string\x12\x18\n\x06\x62inary\x18\r \x01(\x0cH\x00R\x06\x62inary\x12\x1e\n\ttimestamp\x18\x0e \x01(\x03H\x00R\ttimestamp\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x14\n\x04time\x18\x11 \x01(\x03H\x00R\x04time\x12l\n\x16interval_year_to_month\x18\x13 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalYearToMonthH\x00R\x13intervalYearToMonth\x12l\n\x16interval_day_to_second\x18\x14 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalDayToSecondH\x00R\x13intervalDayToSecond\x12\x1f\n\nfixed_char\x18\x15 \x01(\tH\x00R\tfixedChar\x12\x46\n\x08var_char\x18\x16 \x01(\x0b\x32).spark.connect.Expression.Literal.VarCharH\x00R\x07varChar\x12#\n\x0c\x66ixed_binary\x18\x17 \x01(\x0cH\x00R\x0b\x66ixedBinary\x12\x45\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x42\n\x06struct\x18\x19 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x1a \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12#\n\x0ctimestamp_tz\x18\x1b \x01(\x03H\x00R\x0btimestampTz\x12\x14\n\x04uuid\x18\x1c \x01(\x0cH\x00R\x04uuid\x12)\n\x04null\x18\x1d \x01(\x0b\x32\x13.spark.connect.TypeH\x00R\x04null\x12<\n\x04list\x18\x1e \x01(\x0b\x32&.spark.connect.Expression.Literal.ListH\x00R\x04list\x12\x39\n\nempty_list\x18\x1f \x01(\x0b\x32\x18.spark.connect.Type.ListH\x00R\temptyList\x12\x36\n\tempty_map\x18 \x01(\x0b\x32\x17.spark.connect.Type.MapH\x00R\x08\x65mptyMap\x12R\n\x0cuser_defined\x18! 
\x01(\x0b\x32-.spark.connect.Expression.Literal.UserDefinedH\x00R\x0buserDefined\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1a\x37\n\x07VarChar\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12\x16\n\x06length\x18\x02 \x01(\rR\x06length\x1aS\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value\x12\x1c\n\tprecision\x18\x02 \x01(\x05R\tprecision\x12\x14\n\x05scale\x18\x03 \x01(\x05R\x05scale\x1a\xce\x01\n\x03Map\x12M\n\nkey_values\x18\x01 \x03(\x0b\x32..spark.connect.Expression.Literal.Map.KeyValueR\tkeyValues\x1ax\n\x08KeyValue\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value\x1a\x43\n\x13IntervalYearToMonth\x12\x14\n\x05years\x18\x01 \x01(\x05R\x05years\x12\x16\n\x06months\x18\x02 \x01(\x05R\x06months\x1ag\n\x13IntervalDayToSecond\x12\x12\n\x04\x64\x61ys\x18\x01 \x01(\x05R\x04\x64\x61ys\x12\x18\n\x07seconds\x18\x02 \x01(\x05R\x07seconds\x12"\n\x0cmicroseconds\x18\x03 \x01(\x05R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x41\n\x04List\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a`\n\x0bUserDefined\x12%\n\x0etype_reference\x18\x01 \x01(\rR\rtypeReference\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05valueB\x0e\n\x0cliteral_type\x1a+\n\x13UnresolvedAttribute\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x1a\x63\n\x12UnresolvedFunction\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpressionB\x0b\n\texpr_typeBM\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3'
+)
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.expressions_pb2", globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = (
+ b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto"
+ )
+ _EXPRESSION._serialized_start = 105
+ _EXPRESSION._serialized_end = 2753
+ _EXPRESSION_LITERAL._serialized_start = 471
+ _EXPRESSION_LITERAL._serialized_end = 2542
+ _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1769
+ _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1824
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1826
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1909
+ _EXPRESSION_LITERAL_MAP._serialized_start = 1912
+ _EXPRESSION_LITERAL_MAP._serialized_end = 2118
+ _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 1998
+ _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2118
+ _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2120
+ _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2187
+ _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2189
+ _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2292
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 2294
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 2361
+ _EXPRESSION_LITERAL_LIST._serialized_start = 2363
+ _EXPRESSION_LITERAL_LIST._serialized_end = 2428
+ _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2430
+ _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2526
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2544
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2587
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2589
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2688
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2690
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2740
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
new file mode 100644
index 000000000000..b9f74dc23806
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: spark/connect/relations.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from pyspark.sql.connect.proto import (
+ expressions_pb2 as spark_dot_connect_dot_expressions__pb2,
+)
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa6\x04\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05\x66\x65tch\x18\x08 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SqlH\x00R\x03sql\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03Sql\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"z\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a"\n\nNamedTable\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05partsB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd8\x02\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12)\n\x02on\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x02on\x12.\n\x03how\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x03how"\x98\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x13\n\x0fJOIN_TYPE_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x12\n\x0eJOIN_TYPE_ANTI\x10\x05"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"d\n\x05\x46\x65tch\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit\x12\x16\n\x06offset\x18\x03 \x01(\x05R\x06offset"\x8b\x04\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12I\n\rgrouping_sets\x18\x02 \x03(\x0b\x32$.spark.connect.Aggregate.GroupingSetR\x0cgroupingSets\x12<\n\x08measures\x18\x03 \x03(\x0b\x32 .spark.connect.Aggregate.MeasureR\x08measures\x1a]\n\x0bGroupingSet\x12N\n\x15\x61ggregate_expressions\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x1a\x84\x01\n\x07Measure\x12\x46\n\x08\x66unction\x18\x01 \x01(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x08\x66unction\x12\x31\n\x06\x66ilter\x18\x02 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\x06\x66ilter\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42M\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3'
+)
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.relations_pb2", globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = (
+ b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto"
+ )
+ _RELATION._serialized_start = 82
+ _RELATION._serialized_end = 632
+ _UNKNOWN._serialized_start = 634
+ _UNKNOWN._serialized_end = 643
+ _RELATIONCOMMON._serialized_start = 645
+ _RELATIONCOMMON._serialized_end = 716
+ _SQL._serialized_start = 718
+ _SQL._serialized_end = 745
+ _READ._serialized_start = 747
+ _READ._serialized_end = 869
+ _READ_NAMEDTABLE._serialized_start = 822
+ _READ_NAMEDTABLE._serialized_end = 856
+ _PROJECT._serialized_start = 871
+ _PROJECT._serialized_end = 988
+ _FILTER._serialized_start = 990
+ _FILTER._serialized_end = 1102
+ _JOIN._serialized_start = 1105
+ _JOIN._serialized_end = 1449
+ _JOIN_JOINTYPE._serialized_start = 1297
+ _JOIN_JOINTYPE._serialized_end = 1449
+ _UNION._serialized_start = 1452
+ _UNION._serialized_end = 1657
+ _UNION_UNIONTYPE._serialized_start = 1573
+ _UNION_UNIONTYPE._serialized_end = 1657
+ _FETCH._serialized_start = 1659
+ _FETCH._serialized_end = 1759
+ _AGGREGATE._serialized_start = 1762
+ _AGGREGATE._serialized_end = 2285
+ _AGGREGATE_GROUPINGSET._serialized_start = 1959
+ _AGGREGATE_GROUPINGSET._serialized_end = 2052
+ _AGGREGATE_MEASURE._serialized_start = 2055
+ _AGGREGATE_MEASURE._serialized_end = 2187
+ _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2189
+ _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2285
+ _SORT._serialized_start = 2288
+ _SORT._serialized_end = 2790
+ _SORT_SORTFIELD._serialized_start = 2408
+ _SORT_SORTFIELD._serialized_end = 2596
+ _SORT_SORTDIRECTION._serialized_start = 2598
+ _SORT_SORTDIRECTION._serialized_end = 2706
+ _SORT_SORTNULLS._serialized_start = 2708
+ _SORT_SORTNULLS._serialized_end = 2790
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
new file mode 100644
index 000000000000..27247ca4e0c8
--- /dev/null
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: spark/connect/types.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xea%\n\x04Type\x12\x31\n\x04\x62ool\x18\x01 \x01(\x0b\x32\x1b.spark.connect.Type.BooleanH\x00R\x04\x62ool\x12(\n\x02i8\x18\x02 \x01(\x0b\x32\x16.spark.connect.Type.I8H\x00R\x02i8\x12+\n\x03i16\x18\x03 \x01(\x0b\x32\x17.spark.connect.Type.I16H\x00R\x03i16\x12+\n\x03i32\x18\x05 \x01(\x0b\x32\x17.spark.connect.Type.I32H\x00R\x03i32\x12+\n\x03i64\x18\x07 \x01(\x0b\x32\x17.spark.connect.Type.I64H\x00R\x03i64\x12.\n\x04\x66p32\x18\n \x01(\x0b\x32\x18.spark.connect.Type.FP32H\x00R\x04\x66p32\x12.\n\x04\x66p64\x18\x0b \x01(\x0b\x32\x18.spark.connect.Type.FP64H\x00R\x04\x66p64\x12\x34\n\x06string\x18\x0c \x01(\x0b\x32\x1a.spark.connect.Type.StringH\x00R\x06string\x12\x34\n\x06\x62inary\x18\r \x01(\x0b\x32\x1a.spark.connect.Type.BinaryH\x00R\x06\x62inary\x12=\n\ttimestamp\x18\x0e \x01(\x0b\x32\x1d.spark.connect.Type.TimestampH\x00R\ttimestamp\x12.\n\x04\x64\x61te\x18\x10 \x01(\x0b\x32\x18.spark.connect.Type.DateH\x00R\x04\x64\x61te\x12.\n\x04time\x18\x11 \x01(\x0b\x32\x18.spark.connect.Type.TimeH\x00R\x04time\x12G\n\rinterval_year\x18\x13 \x01(\x0b\x32 .spark.connect.Type.IntervalYearH\x00R\x0cintervalYear\x12\x44\n\x0cinterval_day\x18\x14 \x01(\x0b\x32\x1f.spark.connect.Type.IntervalDayH\x00R\x0bintervalDay\x12\x44\n\x0ctimestamp_tz\x18\x1d \x01(\x0b\x32\x1f.spark.connect.Type.TimestampTZH\x00R\x0btimestampTz\x12.\n\x04uuid\x18 \x01(\x0b\x32\x18.spark.connect.Type.UUIDH\x00R\x04uuid\x12>\n\nfixed_char\x18\x15 \x01(\x0b\x32\x1d.spark.connect.Type.FixedCharH\x00R\tfixedChar\x12\x37\n\x07varchar\x18\x16 \x01(\x0b\x32\x1b.spark.connect.Type.VarCharH\x00R\x07varchar\x12\x44\n\x0c\x66ixed_binary\x18\x17 \x01(\x0b\x32\x1f.spark.connect.Type.FixedBinaryH\x00R\x0b\x66ixedBinary\x12\x37\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32\x1b.spark.connect.Type.DecimalH\x00R\x07\x64\x65\x63imal\x12\x34\n\x06struct\x18\x19 \x01(\x0b\x32\x1a.spark.connect.Type.StructH\x00R\x06struct\x12.\n\x04list\x18\x1b \x01(\x0b\x32\x18.spark.connect.Type.ListH\x00R\x04list\x12+\n\x03map\x18\x1c \x01(\x0b\x32\x17.spark.connect.Type.MapH\x00R\x03map\x12?\n\x1buser_defined_type_reference\x18\x1f \x01(\rH\x00R\x18userDefinedTypeReference\x1a\x86\x01\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x81\x01\n\x02I8\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x82\x01\n\x03I16\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x82\x01\n\x03I32\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x82\x01\n\x03I64\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04\x46P32\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04\x46P64\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 
\x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x85\x01\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x85\x01\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x88\x01\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04Time\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x8a\x01\n\x0bTimestampTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x8b\x01\n\x0cIntervalYear\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x8a\x01\n\x0bIntervalDay\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04UUID\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xa0\x01\n\tFixedChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x9e\x01\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xa2\x01\n\x0b\x46ixedBinary\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xba\x01\n\x07\x44\x65\x63imal\x12\x14\n\x05scale\x18\x01 \x01(\x05R\x05scale\x12\x1c\n\tprecision\x18\x02 \x01(\x05R\tprecision\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x04 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xb0\x01\n\x06Struct\x12)\n\x05types\x18\x01 \x03(\x0b\x32\x13.spark.connect.TypeR\x05types\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xac\x01\n\x04List\x12\'\n\x04type\x18\x01 \x01(\x0b\x32\x13.spark.connect.TypeR\x04type\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 
\x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xd4\x01\n\x03Map\x12%\n\x03key\x18\x01 \x01(\x0b\x32\x13.spark.connect.TypeR\x03key\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x13.spark.connect.TypeR\x05value\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x04 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability"^\n\x0bNullability\x12\x1b\n\x17NULLABILITY_UNSPECIFIED\x10\x00\x12\x18\n\x14NULLABILITY_NULLABLE\x10\x01\x12\x18\n\x14NULLABILITY_REQUIRED\x10\x02\x42\x06\n\x04kindBM\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3'
+)
+
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.types_pb2", globals())
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = (
+ b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto"
+ )
+ _TYPE._serialized_start = 45
+ _TYPE._serialized_end = 4887
+ _TYPE_BOOLEAN._serialized_start = 1366
+ _TYPE_BOOLEAN._serialized_end = 1500
+ _TYPE_I8._serialized_start = 1503
+ _TYPE_I8._serialized_end = 1632
+ _TYPE_I16._serialized_start = 1635
+ _TYPE_I16._serialized_end = 1765
+ _TYPE_I32._serialized_start = 1768
+ _TYPE_I32._serialized_end = 1898
+ _TYPE_I64._serialized_start = 1901
+ _TYPE_I64._serialized_end = 2031
+ _TYPE_FP32._serialized_start = 2034
+ _TYPE_FP32._serialized_end = 2165
+ _TYPE_FP64._serialized_start = 2168
+ _TYPE_FP64._serialized_end = 2299
+ _TYPE_STRING._serialized_start = 2302
+ _TYPE_STRING._serialized_end = 2435
+ _TYPE_BINARY._serialized_start = 2438
+ _TYPE_BINARY._serialized_end = 2571
+ _TYPE_TIMESTAMP._serialized_start = 2574
+ _TYPE_TIMESTAMP._serialized_end = 2710
+ _TYPE_DATE._serialized_start = 2713
+ _TYPE_DATE._serialized_end = 2844
+ _TYPE_TIME._serialized_start = 2847
+ _TYPE_TIME._serialized_end = 2978
+ _TYPE_TIMESTAMPTZ._serialized_start = 2981
+ _TYPE_TIMESTAMPTZ._serialized_end = 3119
+ _TYPE_INTERVALYEAR._serialized_start = 3122
+ _TYPE_INTERVALYEAR._serialized_end = 3261
+ _TYPE_INTERVALDAY._serialized_start = 3264
+ _TYPE_INTERVALDAY._serialized_end = 3402
+ _TYPE_UUID._serialized_start = 3405
+ _TYPE_UUID._serialized_end = 3536
+ _TYPE_FIXEDCHAR._serialized_start = 3539
+ _TYPE_FIXEDCHAR._serialized_end = 3699
+ _TYPE_VARCHAR._serialized_start = 3702
+ _TYPE_VARCHAR._serialized_end = 3860
+ _TYPE_FIXEDBINARY._serialized_start = 3863
+ _TYPE_FIXEDBINARY._serialized_end = 4025
+ _TYPE_DECIMAL._serialized_start = 4028
+ _TYPE_DECIMAL._serialized_end = 4214
+ _TYPE_STRUCT._serialized_start = 4217
+ _TYPE_STRUCT._serialized_end = 4393
+ _TYPE_LIST._serialized_start = 4396
+ _TYPE_LIST._serialized_end = 4568
+ _TYPE_MAP._serialized_start = 4571
+ _TYPE_MAP._serialized_end = 4783
+ _TYPE_NULLABILITY._serialized_start = 4785
+ _TYPE_NULLABILITY._serialized_end = 4879
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
new file mode 100644
index 000000000000..4aac4eed08c5
--- /dev/null
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.connect.data_frame import DataFrame
+from pyspark.sql.connect.plan import Read
+
+
+class DataFrameReader:
+ """
+ TODO(SPARK-40539) Achieve parity with PySpark.
+ """
+
+ def __init__(self, client):
+ self._client = client
+
+ def table(self, tableName: str) -> "DataFrame":
+ df = DataFrame.withPlan(Read(tableName), self._client)
+ return df
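A short usage sketch under the same assumptions as the end-to-end tests below (a reachable Spark Connect endpoint and an illustrative table name):

    from pyspark.sql.connect.client import RemoteSparkSession

    # DataFrameReader.table wraps the table in a Read plan and returns a DataFrame.
    client = RemoteSparkSession(user_id="test_user")
    df = client.read.table("my_table")
    rows = df.limit(10).collect()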
diff --git a/python/pyspark/sql/tests/connect/__init__.py b/python/pyspark/sql/tests/connect/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/connect/test_column_expressions.py b/python/pyspark/sql/tests/connect/test_column_expressions.py
new file mode 100644
index 000000000000..1f067bf79956
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_column_expressions.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.tests.connect.utils import PlanOnlyTestFixture
+
+import pyspark.sql.connect as c
+import pyspark.sql.connect.plan as p
+import pyspark.sql.connect.column as col
+
+import pyspark.sql.connect.functions as fun
+
+
+class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
+ def test_simple_column_expressions(self):
+ df = c.DataFrame.withPlan(p.Read("table"))
+
+ c1 = df.col_name
+ assert isinstance(c1, col.ColumnRef)
+ c2 = df["col_name"]
+ assert isinstance(c2, col.ColumnRef)
+ c3 = fun.col("col_name")
+ assert isinstance(c3, col.ColumnRef)
+
+ # All Protos should be identical
+ cp1 = c1.to_plan(None)
+ cp2 = c2.to_plan(None)
+ cp3 = c3.to_plan(None)
+
+ assert cp1 is not None
+ assert cp1 == cp2 == cp3
+
+ def test_column_literals(self):
+ df = c.DataFrame.withPlan(p.Read("table"))
+ lit_df = df.select(fun.lit(10))
+ self.assertIsNotNone(lit_df._plan.collect(None))
+
+ self.assertIsNotNone(fun.lit(10).to_plan(None))
+ plan = fun.lit(10).to_plan(None)
+ self.assertEqual(plan.literal.i32, 10)
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_column_expressions import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_plan_only.py b/python/pyspark/sql/tests/connect/test_plan_only.py
new file mode 100644
index 000000000000..9e6d30cbe1fd
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_plan_only.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.connect import DataFrame
+from pyspark.sql.connect.plan import Read
+from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
+from pyspark.sql.tests.connect.utils.spark_connect_test_utils import PlanOnlyTestFixture
+from pyspark.sql.types import StringType
+
+
+class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
+ """These test cases exercise the interface to the proto plan
+ generation but do not call Spark."""
+
+ def test_simple_project(self):
+ def read_table(x):
+ return DataFrame.withPlan(Read(x), self.connect)
+
+ self.connect.set_hook("readTable", read_table)
+
+ plan = self.connect.readTable(self.tbl_name)._plan.collect(self.connect)
+ self.assertIsNotNone(plan.root, "Root relation must be set")
+ self.assertIsNotNone(plan.root.read)
+
+ def test_simple_udf(self):
+ def udf_mock(*args, **kwargs):
+ return "internal_name"
+
+ self.connect.set_hook("register_udf", udf_mock)
+
+ u = udf(lambda x: "Martin", StringType())
+ self.assertIsNotNone(u)
+ expr = u("ThisCol", "ThatCol", "OtherCol")
+ self.assertTrue(isinstance(expr, UserDefinedFunction))
+ u_plan = expr.to_plan(self.connect)
+ assert u_plan is not None
+
+ def test_all_the_plans(self):
+ def read_table(x):
+ return DataFrame.withPlan(Read(x), self.connect)
+
+ self.connect.set_hook("readTable", read_table)
+
+ df = self.connect.readTable(self.tbl_name)
+ df = df.select(df.col1).filter(df.col2 == 2).sort(df.col3.asc())
+ plan = df._plan.collect(self.connect)
+ self.assertIsNotNone(plan.root, "Root relation must be set")
+ self.assertIsNotNone(plan.root.read)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_plan_only import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_select_ops.py b/python/pyspark/sql/tests/connect/test_select_ops.py
new file mode 100644
index 000000000000..818f82b33e86
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_select_ops.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.tests.connect.utils import PlanOnlyTestFixture
+from pyspark.sql.connect import DataFrame
+from pyspark.sql.connect.functions import col
+from pyspark.sql.connect.plan import Read, InputValidationError
+
+
+class SparkConnectSelectOpsSuite(PlanOnlyTestFixture):
+ def test_select_with_literal(self):
+ df = DataFrame.withPlan(Read("table"))
+ self.assertIsNotNone(df.select(col("name"))._plan.collect())
+ self.assertRaises(InputValidationError, df.select, "name")
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_select_ops import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_spark_connect.py b/python/pyspark/sql/tests/connect/test_spark_connect.py
new file mode 100644
index 000000000000..cac17b7397dc
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_spark_connect.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any
+import uuid
+import unittest
+import tempfile
+
+from pyspark.sql import SparkSession, Row
+from pyspark.sql.connect.client import RemoteSparkSession
+from pyspark.sql.connect.function_builder import udf
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class SparkConnectSQLTestCase(ReusedPySparkTestCase):
+ """Parent test fixture class for all Spark Connect related
+ test cases."""
+
+ @classmethod
+ def setUpClass(cls: Any) -> None:
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ cls.hive_available = True
+ # Create the new Spark Session
+ cls.spark = SparkSession(cls.sc)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ cls.df = cls.sc.parallelize(cls.testData).toDF()
+
+ # Load test data
+ cls.spark_connect_test_data()
+
+ @classmethod
+ def spark_connect_test_data(cls: Any) -> None:
+ # Set up the remote Spark session and the test table
+ cls.tbl_name = f"tbl{uuid.uuid4()}".replace("-", "")
+ cls.connect = RemoteSparkSession(user_id="test_user")
+ df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
+ # Since we might create multiple Spark sessions, we persist the test data as a
+ # regular table so that it is also visible to the Spark Connect session.
+ df.write.saveAsTable(cls.tbl_name)
+
+
+class SparkConnectTests(SparkConnectSQLTestCase):
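+ """End-to-end tests that exercise the Spark Connect Python client against
+ the locally running Spark Connect GRPC service."""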
+ def test_simple_read(self) -> None:
+ """Tests that we can access the Spark Connect GRPC service locally."""
+ df = self.connect.read.table(self.tbl_name)
+ data = df.limit(10).collect()
+ # Check that the limit is applied
+ assert len(data.index) == 10
+
+ def test_simple_udf(self) -> None:
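+ """Applies a trivial scalar Python UDF to a column of the test table through Spark Connect."""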
+ def conv_udf(x) -> str:
+ return "Martin"
+
+ u = udf(conv_udf)
+ df = self.connect.read.table(self.tbl_name)
+ result = df.select(u(df.id)).collect()
+ assert result is not None
+
+ def test_simple_explain_string(self) -> None:
+ df = self.connect.read.table(self.tbl_name).limit(10)
+ result = df.explain()
+ assert len(result) > 0
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_spark_connect import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/utils/__init__.py b/python/pyspark/sql/tests/connect/utils/__init__.py
new file mode 100644
index 000000000000..b95812c8a297
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/utils/__init__.py
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.tests.connect.utils.spark_connect_test_utils import (
+ PlanOnlyTestFixture, # noqa: F401
+)
diff --git a/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py b/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py
new file mode 100644
index 000000000000..34bf49db4945
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, Dict
+import functools
+import unittest
+import uuid
+
+
+class MockRemoteSession:
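+ """Minimal stand-in for RemoteSparkSession used in plan-only tests: tests register
+ named hooks via ``set_hook`` and any attribute access dispatches to the matching
+ hook, so no GRPC connection is required."""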
+ def __init__(self) -> None:
+ self.hooks: Dict[str, Any] = {}
+
+ def set_hook(self, name: str, hook: Any) -> None:
+ self.hooks[name] = hook
+
+ def __getattr__(self, item: str) -> Any:
+ if item not in self.hooks:
+ raise LookupError(f"{item} is not defined as a method hook in MockRemoteSession")
+ return functools.partial(self.hooks[item])
+
+
+class PlanOnlyTestFixture(unittest.TestCase):
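+ """Base fixture for tests that only build and inspect logical plans. It provides
+ a ``MockRemoteSession`` as ``cls.connect`` and a unique table name in ``cls.tbl_name``."""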
+ @classmethod
+ def setUpClass(cls: Any) -> None:
+ cls.connect = MockRemoteSession()
+ cls.tbl_name = f"tbl{uuid.uuid4()}".replace("-", "")
diff --git a/python/run-tests.py b/python/run-tests.py
index d43ed8e96f40..af4c6f1c94be 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -110,6 +110,15 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
metastore_dir = os.path.join(metastore_dir, str(uuid.uuid4()))
os.mkdir(metastore_dir)
+ # Check if we should enable the SparkConnectPlugin
+ additional_config = []
+ if test_name.startswith("pyspark.sql.tests.connect"):
+ # Add the configuration that enables the Spark Connect server plugin for these tests
+ additional_config += [
+ "--conf",
+ "spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
+ ]
+
# Also override the JVM's temp directory by setting driver and executor options.
java_options = "-Djava.io.tmpdir={0}".format(tmp_dir)
java_options = java_options + " -Dio.netty.tryReflectionSetAccessible=true -Xss4M"
@@ -117,8 +126,10 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
"--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
- "pyspark-shell"
]
+ spark_args += additional_config
+ spark_args += ["pyspark-shell"]
+
env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
output_prefix = get_valid_filename(pyspark_python + "__" + test_name + "__").lstrip("_")