diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 3a9157010e9b..e218b8ff9267 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -242,7 +242,7 @@ jobs: - name: Install Python packages (Python 3.8) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) run: | - python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting + python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting grpcio protobuf python3.8 -m pip list # Run the tests. - name: Run tests @@ -333,6 +333,8 @@ jobs: pyspark-pandas - >- pyspark-pandas-slow + - >- + pyspark-sql-connect env: MODULES_TO_TEST: ${{ matrix.modules }} HADOOP_PROFILE: ${{ inputs.hadoop }} @@ -576,7 +578,7 @@ jobs: # See also https://issues.apache.org/jira/browse/SPARK-38279. python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' python3.9 -m pip install ipython_genutils # See SPARK-38517 - python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' + python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' grpcio protobuf python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 apt-get update -y apt-get install -y ruby ruby-dev diff --git a/assembly/pom.xml b/assembly/pom.xml index f37edcd7e49f..218bf3679504 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -74,6 +74,11 @@ spark-repl_${scala.binary.version} ${project.version} + + org.apache.spark + spark-connect_${scala.binary.version} + ${project.version} + + + + 4.0.0 + + org.apache.spark + spark-parent_2.12 + 3.4.0-SNAPSHOT + ../pom.xml + + + spark-connect_2.12 + jar + Spark Project Connect + https://spark.apache.org/ + + connect + 3.21.1 + 31.0.1-jre + 1.47.0 + 6.0.53 + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + provided + + + com.google.guava + guava + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + provided + + + com.google.guava + guava + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + com.google.guava + guava + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-tags_${scala.binary.version} + ${project.version} + provided + + + com.google.guava + guava + + + + + + com.google.guava + guava + ${guava.version} + compile + + + com.google.guava + failureaccess + 1.0.1 + + + com.google.protobuf + protobuf-java + ${protobuf.version} + compile + + + io.grpc + grpc-netty-shaded + ${io.grpc.version} + + + io.grpc + grpc-protobuf + ${io.grpc.version} + + + io.grpc + grpc-services + ${io.grpc.version} + + + io.grpc + grpc-stub + ${io.grpc.version} + + + org.apache.tomcat + annotations-api + ${tomcat.annotations.api.version} + provided + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + + + + + + kr.motd.maven + os-maven-plugin + 1.6.2 + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-sources + generate-sources + + 
add-source + + + + src/main/scala-${scala.binary.version} + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/gen-java + + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} + grpc-java + io.grpc:protoc-gen-grpc-java:${io.grpc.version}:exe:${os.detected.classifier} + src/main/protobuf + + + + + compile + compile-custom + test-compile + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + com.google.guava:* + io.grpc:*: + com.google.protobuf:* + + + + + com.google.common + ${spark.shade.packageName}.connect.guava + + com.google.common.** + + + + com.google.thirdparty + ${spark.shade.packageName}.connect.guava + + com.google.thirdparty.** + + + + com.google.protobuf + ${spark.shade.packageName}.connect.protobuf + + com.google.protobuf.** + + + + io.grpc + ${spark.shade.packageName}.connect.grpc + + + + + + + diff --git a/connect/src/main/buf.gen.yaml b/connect/src/main/buf.gen.yaml new file mode 100644 index 000000000000..01e31d3c8b4c --- /dev/null +++ b/connect/src/main/buf.gen.yaml @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version: v1 +plugins: + - remote: buf.build/protocolbuffers/plugins/cpp:v3.20.0-1 + out: gen/proto/cpp + - remote: buf.build/protocolbuffers/plugins/csharp:v3.20.0-1 + out: gen/proto/csharp + - remote: buf.build/protocolbuffers/plugins/java:v3.20.0-1 + out: gen/proto/java + - remote: buf.build/protocolbuffers/plugins/python:v3.20.0-1 + out: gen/proto/python + - remote: buf.build/grpc/plugins/python:v1.47.0-1 + out: gen/proto/python + - remote: buf.build/protocolbuffers/plugins/go:v1.28.0-1 + out: gen/proto/go + opt: + - paths=source_relative + - remote: buf.build/grpc/plugins/go:v1.2.0-1 + out: gen/proto/go + opt: + - paths=source_relative + - require_unimplemented_servers=false + - remote: buf.build/grpc/plugins/ruby:v1.47.0-1 + out: gen/proto/ruby + - remote: buf.build/protocolbuffers/plugins/ruby:v21.2.0-1 + out: gen/proto/ruby diff --git a/connect/src/main/buf.work.yaml b/connect/src/main/buf.work.yaml new file mode 100644 index 000000000000..a02dead420cd --- /dev/null +++ b/connect/src/main/buf.work.yaml @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version: v1 +directories: + - protobuf diff --git a/connect/src/main/protobuf/buf.yaml b/connect/src/main/protobuf/buf.yaml new file mode 100644 index 000000000000..496e97af3fa0 --- /dev/null +++ b/connect/src/main/protobuf/buf.yaml @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version: v1 +breaking: + use: + - FILE +lint: + use: + - DEFAULT diff --git a/connect/src/main/protobuf/spark/connect/base.proto b/connect/src/main/protobuf/spark/connect/base.proto new file mode 100644 index 000000000000..450f60d6aa5d --- /dev/null +++ b/connect/src/main/protobuf/spark/connect/base.proto @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = 'proto3'; + +package spark.connect; + +import "spark/connect/commands.proto"; +import "spark/connect/relations.proto"; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +// A [[Plan]] is the structure that carries the runtime information for the execution from the +// client to the server. A [[Plan]] can either be of the type [[Relation]] which is a reference +// to the underlying logical plan or it can be of the [[Command]] type that is used to execute +// commands on the server. +message Plan { + oneof op_type { + Relation root = 1; + Command command = 2; + } +} + +// A request to be executed by the service. +message Request { + // The client_id is set by the client to be able to collate streaming responses from + // different queries. + string client_id = 1; + // User context + UserContext user_context = 2; + // The logical plan to be executed / analyzed. 
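+  // For illustration only (not normative), a plan that reads a single table would look like
+  // `root { read { named_table { parts: "my_table" } } }` in protobuf text format.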
+ Plan plan = 3; + + // User Context is used to refer to one particular user session that is executing + // queries in the backend. + message UserContext { + string user_id = 1; + string user_name = 2; + } +} + +// The response of a query, can be one or more for each request. Responses belonging to the +// same input query, carry the same `client_id`. +message Response { + string client_id = 1; + + // Result type + oneof result_type { + ArrowBatch batch = 2; + CSVBatch csv_batch = 3; + } + + // Metrics for the query execution. Typically, this field is only present in the last + // batch of results and then represent the overall state of the query execution. + Metrics metrics = 4; + + // Batch results of metrics. + message ArrowBatch { + int64 row_count = 1; + int64 uncompressed_bytes = 2; + int64 compressed_bytes = 3; + bytes data = 4; + bytes schema = 5; + } + + message CSVBatch { + int64 row_count = 1; + string data = 2; + } + + message Metrics { + + repeated MetricObject metrics = 1; + + message MetricObject { + string name = 1; + int64 plan_id = 2; + int64 parent = 3; + map execution_metrics = 4; + } + + message MetricValue { + string name = 1; + int64 value = 2; + string metric_type = 3; + } + } +} + +// Response to performing analysis of the query. Contains relevant metadata to be able to +// reason about the performance. +message AnalyzeResponse { + string client_id = 1; + repeated string column_names = 2; + repeated string column_types = 3; + + // The extended explain string as produced by Spark. + string explain_string = 4; +} + +// Main interface for the SparkConnect service. +service SparkConnectService { + + // Executes a request that contains the query and returns a stream of [[Response]]. + rpc ExecutePlan(Request) returns (stream Response) {} + + // Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query. + rpc AnalyzePlan(Request) returns (AnalyzeResponse) {} +} + diff --git a/connect/src/main/protobuf/spark/connect/commands.proto b/connect/src/main/protobuf/spark/connect/commands.proto new file mode 100644 index 000000000000..425857b842e5 --- /dev/null +++ b/connect/src/main/protobuf/spark/connect/commands.proto @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = 'proto3'; + +import "spark/connect/types.proto"; + +package spark.connect; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +// A [[Command]] is an operation that is executed by the server that does not directly consume or +// produce a relational result. +message Command { + oneof command_type { + CreateScalarFunction create_function = 1; + } +} + +// Simple message that is used to create a scalar function based on the provided function body. 
+// +// This message is used to register for example a Python UDF in the session catalog by providing +// the serialized method body. +// +// TODO(SPARK-40532) It is required to add the interpreter / language version to the command +// parameters. +message CreateScalarFunction { + // Fully qualified name of the function including the catalog / schema names. + repeated string parts = 1; + FunctionLanguage language = 2; + bool temporary = 3; + repeated Type argument_types = 4; + Type return_type = 5; + + // How the function body is defined: + oneof function_definition { + // As a raw string serialized: + bytes serialized_function = 6; + // As a code literal + string literal_string = 7; + } + + enum FunctionLanguage { + FUNCTION_LANGUAGE_UNSPECIFIED = 0; + FUNCTION_LANGUAGE_SQL = 1; + FUNCTION_LANGUAGE_PYTHON = 2; + FUNCTION_LANGUAGE_SCALA = 3; + } +} diff --git a/connect/src/main/protobuf/spark/connect/expressions.proto b/connect/src/main/protobuf/spark/connect/expressions.proto new file mode 100644 index 000000000000..6b72a646623c --- /dev/null +++ b/connect/src/main/protobuf/spark/connect/expressions.proto @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = 'proto3'; + +import "spark/connect/types.proto"; +import "google/protobuf/any.proto"; + +package spark.connect; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +// Expression used to refer to fields, functions and similar. This can be used everywhere +// expressions in SQL appear. +message Expression { + + oneof expr_type { + Literal literal = 1; + UnresolvedAttribute unresolved_attribute = 2; + UnresolvedFunction unresolved_function = 3; + ExpressionString expression_string = 4; + } + + message Literal { + oneof literal_type { + bool boolean = 1; + int32 i8 = 2; + int32 i16 = 3; + int32 i32 = 5; + int64 i64 = 7; + float fp32 = 10; + double fp64 = 11; + string string = 12; + bytes binary = 13; + // Timestamp in units of microseconds since the UNIX epoch. + int64 timestamp = 14; + // Date in units of days since the UNIX epoch. + int32 date = 16; + // Time in units of microseconds past midnight + int64 time = 17; + IntervalYearToMonth interval_year_to_month = 19; + IntervalDayToSecond interval_day_to_second = 20; + string fixed_char = 21; + VarChar var_char = 22; + bytes fixed_binary = 23; + Decimal decimal = 24; + Struct struct = 25; + Map map = 26; + // Timestamp in units of microseconds since the UNIX epoch. + int64 timestamp_tz = 27; + bytes uuid = 28; + Type null = 29; // a typed null literal + List list = 30; + Type.List empty_list = 31; + Type.Map empty_map = 32; + UserDefined user_defined = 33; + } + + // whether the literal type should be treated as a nullable type. 
Applies to + // all members of union other than the Typed null (which should directly + // declare nullability). + bool nullable = 50; + + // optionally points to a type_variation_anchor defined in this plan. + // Applies to all members of union other than the Typed null (which should + // directly declare the type variation). + uint32 type_variation_reference = 51; + + message VarChar { + string value = 1; + uint32 length = 2; + } + + message Decimal { + // little-endian twos-complement integer representation of complete value + // (ignoring precision) Always 16 bytes in length + bytes value = 1; + // The maximum number of digits allowed in the value. + // the maximum precision is 38. + int32 precision = 2; + // declared scale of decimal literal + int32 scale = 3; + } + + message Map { + message KeyValue { + Literal key = 1; + Literal value = 2; + } + + repeated KeyValue key_values = 1; + } + + message IntervalYearToMonth { + int32 years = 1; + int32 months = 2; + } + + message IntervalDayToSecond { + int32 days = 1; + int32 seconds = 2; + int32 microseconds = 3; + } + + message Struct { + // A possibly heterogeneously typed list of literals + repeated Literal fields = 1; + } + + message List { + // A homogeneously typed list of literals + repeated Literal values = 1; + } + + message UserDefined { + // points to a type_anchor defined in this plan + uint32 type_reference = 1; + + // the value of the literal, serialized using some type-specific + // protobuf message + google.protobuf.Any value = 2; + } + } + + // An unresolved attribute that is not explicitly bound to a specific column, but the column + // is resolved during analysis by name. + message UnresolvedAttribute { + repeated string parts = 1; + } + + // An unresolved function is not explicitly bound to one explicit function, but the function + // is resolved during analysis following Sparks name resolution rules. + message UnresolvedFunction { + repeated string parts = 1; + repeated Expression arguments = 2; + } + + // Expression as string. + message ExpressionString { + string expression = 1; + } + +} diff --git a/connect/src/main/protobuf/spark/connect/relations.proto b/connect/src/main/protobuf/spark/connect/relations.proto new file mode 100644 index 000000000000..adbe178da992 --- /dev/null +++ b/connect/src/main/protobuf/spark/connect/relations.proto @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = 'proto3'; + +package spark.connect; + +import "spark/connect/expressions.proto"; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +// The main [[Relation]] type. Fundamentally, a relation is a typed container +// that has exactly one explicit relation type set. 
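+// For example (illustrative only), a relation that merely wraps a SQL query sets just the
+// `sql` field: `sql { query: "SELECT 1" }` in protobuf text format.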
+// +// When adding new relation types, they have to be registered here. +message Relation { + RelationCommon common = 1; + oneof rel_type { + Read read = 2; + Project project = 3; + Filter filter = 4; + Join join = 5; + Union union = 6; + Sort sort = 7; + Fetch fetch = 8; + Aggregate aggregate = 9; + SQL sql = 10; + + Unknown unknown = 999; + } +} + +// Used for testing purposes only. +message Unknown {} + +// Common metadata of all relations. +message RelationCommon { + string source_info = 1; + string alias = 2; +} + +// Relation that uses a SQL query to generate the output. +message SQL { + string query = 1; +} + +// Relation that reads from a file / table or other data source. Does not have additional +// inputs. +message Read { + oneof read_type { + NamedTable named_table = 1; + } + + message NamedTable { + repeated string parts = 1; + } +} + +// Projection of a bag of expressions for a given input relation. +// +// The input relation must be specified. +// The projected expression can be an arbitrary expression. +message Project { + Relation input = 1; + repeated Expression expressions = 3; +} + +// Relation that applies a boolean expression `condition` on each row of `input` to produce +// the output result. +message Filter { + Relation input = 1; + Expression condition = 2; +} + +// Relation of type [[Join]]. +// +// `left` and `right` must be present. +message Join { + Relation left = 1; + Relation right = 2; + Expression on = 3; + JoinType how = 4; + + enum JoinType { + JOIN_TYPE_UNSPECIFIED = 0; + JOIN_TYPE_INNER = 1; + JOIN_TYPE_OUTER = 2; + JOIN_TYPE_LEFT_OUTER = 3; + JOIN_TYPE_RIGHT_OUTER = 4; + JOIN_TYPE_ANTI = 5; + } +} + +// Relation of type [[Union]], at least one input must be set. +message Union { + repeated Relation inputs = 1; + UnionType union_type = 2; + + enum UnionType { + UNION_TYPE_UNSPECIFIED = 0; + UNION_TYPE_DISTINCT = 1; + UNION_TYPE_ALL = 2; + } +} + +// Relation of type [[Fetch]] that is used to read `limit` / `offset` rows from the input relation. +message Fetch { + Relation input = 1; + int32 limit = 2; + int32 offset = 3; +} + +// Relation of type [[Aggregate]]. +message Aggregate { + Relation input = 1; + + // Grouping sets are used in rollups + repeated GroupingSet grouping_sets = 2; + + // Measures + repeated Measure measures = 3; + + message GroupingSet { + repeated Expression aggregate_expressions = 1; + } + + message Measure { + AggregateFunction function = 1; + // Conditional filter for SUM(x FILTER WHERE x < 10) + Expression filter = 2; + } + + message AggregateFunction { + string name = 1; + repeated Expression arguments = 2; + } +} + +// Relation of type [[Sort]]. +message Sort { + Relation input = 1; + repeated SortField sort_fields = 2; + + message SortField { + Expression expression = 1; + SortDirection direction = 2; + SortNulls nulls = 3; + } + + enum SortDirection { + SORT_DIRECTION_UNSPECIFIED = 0; + SORT_DIRECTION_ASCENDING = 1; + SORT_DIRECTION_DESCENDING = 2; + } + + enum SortNulls { + SORT_NULLS_UNSPECIFIED = 0; + SORT_NULLS_FIRST = 1; + SORT_NULLS_LAST = 2; + } +} diff --git a/connect/src/main/protobuf/spark/connect/types.proto b/connect/src/main/protobuf/spark/connect/types.proto new file mode 100644 index 000000000000..c46afa2afc65 --- /dev/null +++ b/connect/src/main/protobuf/spark/connect/types.proto @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = 'proto3'; + +package spark.connect; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +// This message describes the logical [[Type]] of something. It does not carry the value +// itself but only describes it. +message Type { + oneof kind { + Boolean bool = 1; + I8 i8 = 2; + I16 i16 = 3; + I32 i32 = 5; + I64 i64 = 7; + FP32 fp32 = 10; + FP64 fp64 = 11; + String string = 12; + Binary binary = 13; + Timestamp timestamp = 14; + Date date = 16; + Time time = 17; + IntervalYear interval_year = 19; + IntervalDay interval_day = 20; + TimestampTZ timestamp_tz = 29; + UUID uuid = 32; + + FixedChar fixed_char = 21; + VarChar varchar = 22; + FixedBinary fixed_binary = 23; + Decimal decimal = 24; + + Struct struct = 25; + List list = 27; + Map map = 28; + + uint32 user_defined_type_reference = 31; + } + + enum Nullability { + NULLABILITY_UNSPECIFIED = 0; + NULLABILITY_NULLABLE = 1; + NULLABILITY_REQUIRED = 2; + } + + message Boolean { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I8 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I16 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I32 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I64 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message FP32 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message FP64 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message String { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Binary { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Timestamp { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Date { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Time { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message TimestampTZ { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message IntervalYear { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message IntervalDay { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message UUID { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + // Start compound types. 
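+  // Compound types are composed from the simple types above. As an illustrative example, a
+  // nullable list of 32-bit integers would be described as
+  // `list { type { i32 {} } nullability: NULLABILITY_NULLABLE }` in protobuf text format.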
+ message FixedChar { + int32 length = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message VarChar { + int32 length = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message FixedBinary { + int32 length = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message Decimal { + int32 scale = 1; + int32 precision = 2; + uint32 type_variation_reference = 3; + Nullability nullability = 4; + } + + message Struct { + repeated Type types = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message List { + Type type = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message Map { + Type key = 1; + Type value = 2; + uint32 type_variation_reference = 3; + Nullability nullability = 4; + } +} diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala new file mode 100644 index 000000000000..d262947015cb --- /dev/null +++ b/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Unstable +import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} +import org.apache.spark.sql.connect.service.SparkConnectService + +/** + * This is the main entry point for Spark Connect. + * + * To decouple the build of Spark Connect and its dependencies from the core of Spark, we + * implement it as a Driver Plugin. To enable Spark Connect, simply make sure that the appropriate + * JAR is available in the CLASSPATH and the driver plugin is configured to load this class. + */ +@Unstable +class SparkConnectPlugin extends SparkPlugin { + + /** + * Return the plugin's driver-side component. + * + * @return The driver-side component. 
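+   *
+   * For example (illustrative, deployment specifics may differ), with the Spark Connect jar on
+   * the driver classpath the plugin can be enabled via
+   * `--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin`.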
+ */ + override def driverPlugin(): DriverPlugin = new DriverPlugin { + + override def init( + sc: SparkContext, + pluginContext: PluginContext): util.Map[String, String] = { + SparkConnectService.start() + Map.empty[String, String].asJava + } + + override def shutdown(): Unit = { + SparkConnectService.stop() + } + } + + override def executorPlugin(): ExecutorPlugin = null +} diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala new file mode 100644 index 000000000000..865c7543609b --- /dev/null +++ b/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.command + +import scala.collection.JavaConverters._ + +import com.google.common.collect.{Lists, Maps} + +import org.apache.spark.annotation.{Since, Unstable} +import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} +import org.apache.spark.connect.proto +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.types.StringType + + +@Unstable +@Since("3.4.0") +class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) { + + lazy val pythonVersion = + sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) + + def process(): Unit = { + command.getCommandTypeCase match { + case proto.Command.CommandTypeCase.CREATE_FUNCTION => + handleCreateScalarFunction(command.getCreateFunction) + case _ => throw new UnsupportedOperationException(s"$command not supported.") + } + } + + /** + * This is a helper function that registers a new Python function in the SparkSession. + * + * Right now this function is very rudimentary and bare-bones just to showcase how it + * is possible to remotely serialize a Python function and execute it on the Spark cluster. + * If the Python version on the client and server diverge, the execution of the function that + * is serialized will most likely fail. + * + * @param cf + */ + def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = { + val function = SimplePythonFunction( + cf.getSerializedFunction.toByteArray, + Maps.newHashMap(), + Lists.newArrayList(), + pythonVersion, + "3.9", // TODO(SPARK-40532) This needs to be an actual Python version. 
+ Lists.newArrayList(), + null) + + val udf = UserDefinedPythonFunction( + cf.getPartsList.asScala.head, + function, + StringType, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = false) + + session.udf.registerPython(cf.getPartsList.asScala.head, udf) + } + +} diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala new file mode 100644 index 000000000000..ee2e05a1e606 --- /dev/null +++ b/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Since, Unstable} +import org.apache.spark.connect.proto +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.{expressions, plans} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.types._ + +final case class InvalidPlanInput( + private val message: String = "", + private val cause: Throwable = None.orNull) + extends Exception(message, cause) + +@Unstable +@Since("3.4.0") +class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { + + def transform(): LogicalPlan = { + transformRelation(plan) + } + + // The root of the query plan is a relation and we apply the transformations to it. 
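+  // As an illustrative example, a proto relation of the shape
+  //   `project { input { read { named_table { parts: "t" } } } }`
+  // becomes a logical Project over UnresolvedRelation(Seq("t")); since no expressions are
+  // given, the projection falls back to an UnresolvedStar (`*`).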
+ private def transformRelation(rel: proto.Relation): LogicalPlan = { + val common = if (rel.hasCommon) { + Some(rel.getCommon) + } else { + None + } + + rel.getRelTypeCase match { + case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead, common) + case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject, common) + case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter) + case proto.Relation.RelTypeCase.FETCH => transformFetch(rel.getFetch) + case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin) + case proto.Relation.RelTypeCase.UNION => transformUnion(rel.getUnion) + case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) + case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) + case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) + case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => + throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") + case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") + } + } + + private def transformSql(sql: proto.SQL): LogicalPlan = { + session.sessionState.sqlParser.parsePlan(sql.getQuery) + } + + private def transformReadRel( + rel: proto.Read, + common: Option[proto.RelationCommon]): LogicalPlan = { + val baseRelation = rel.getReadTypeCase match { + case proto.Read.ReadTypeCase.NAMED_TABLE => + val child = UnresolvedRelation(rel.getNamedTable.getPartsList.asScala.toSeq) + if (common.nonEmpty && common.get.getAlias.nonEmpty) { + SubqueryAlias(identifier = common.get.getAlias, child = child) + } else { + child + } + case _ => throw InvalidPlanInput() + } + baseRelation + } + + private def transformFilter(rel: proto.Filter): LogicalPlan = { + assert(rel.hasInput) + val baseRel = transformRelation(rel.getInput) + logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel) + } + + private def transformProject( + rel: proto.Project, + common: Option[proto.RelationCommon]): LogicalPlan = { + val baseRel = transformRelation(rel.getInput) + val projection = if (rel.getExpressionsCount == 0) { + Seq(UnresolvedStar(Option.empty)) + } else { + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) + } + val project = logical.Project(projectList = projection.toSeq, child = baseRel) + if (common.nonEmpty && common.get.getAlias.nonEmpty) { + logical.SubqueryAlias(identifier = common.get.getAlias, child = project) + } else { + project + } + } + + private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = { + UnresolvedAttribute(exp.getUnresolvedAttribute.getPartsList.asScala.toSeq) + } + + private def transformExpression(exp: proto.Expression): Expression = { + exp.getExprTypeCase match { + case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral) + case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE => + transformUnresolvedExpression(exp) + case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION => + transformScalarFunction(exp.getUnresolvedFunction) + case _ => throw InvalidPlanInput() + } + } + + /** + * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. + * + * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp, + * Duration, Period. 
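+   *
+   * For example, a literal message carrying `i32 = 42` maps to `expressions.Literal(42)`
+   * (an IntegerType literal) and `fp64` maps to a DoubleType literal, as sketched below.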
+ * @param lit + * @return + * Expression + */ + private def transformLiteral(lit: proto.Expression.Literal): Expression = { + lit.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => expressions.Literal(lit.getBoolean) + case proto.Expression.Literal.LiteralTypeCase.I8 => expressions.Literal(lit.getI8, ByteType) + case proto.Expression.Literal.LiteralTypeCase.I16 => + expressions.Literal(lit.getI16, ShortType) + case proto.Expression.Literal.LiteralTypeCase.I32 => expressions.Literal(lit.getI32) + case proto.Expression.Literal.LiteralTypeCase.I64 => expressions.Literal(lit.getI64) + case proto.Expression.Literal.LiteralTypeCase.FP32 => + expressions.Literal(lit.getFp32, FloatType) + case proto.Expression.Literal.LiteralTypeCase.FP64 => + expressions.Literal(lit.getFp64, DoubleType) + case proto.Expression.Literal.LiteralTypeCase.STRING => expressions.Literal(lit.getString) + case proto.Expression.Literal.LiteralTypeCase.BINARY => + expressions.Literal(lit.getBinary, BinaryType) + // Microseconds since unix epoch. + case proto.Expression.Literal.LiteralTypeCase.TIME => + expressions.Literal(lit.getTime, TimestampType) + // Days since UNIX epoch. + case proto.Expression.Literal.LiteralTypeCase.DATE => + expressions.Literal(lit.getDate, DateType) + case _ => throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") + } + } + + private def transformFetch(limit: proto.Fetch): LogicalPlan = { + logical.Limit( + child = transformRelation(limit.getInput), + limitExpr = expressions.Literal(limit.getLimit, IntegerType)) + } + + private def lookupFunction(name: String, args: Seq[Expression]): Expression = { + UnresolvedFunction(Seq(name), args, isDistinct = false) + } + + /** + * Translates a scalar function from proto to the Catalyst expression. + * + * TODO(SPARK-40546) We need to homogenize the function names for binary operators. + * + * @param fun Proto representation of the function call. + * @return + */ + private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { + val funName = fun.getPartsList.asScala.mkString(".") + funName match { + case "gt" => + assert(fun.getArgumentsCount == 2, "`gt` function must have two arguments.") + expressions.GreaterThan( + transformExpression(fun.getArguments(0)), + transformExpression(fun.getArguments(1))) + case "eq" => + assert(fun.getArgumentsCount == 2, "`eq` function must have two arguments.") + expressions.EqualTo( + transformExpression(fun.getArguments(0)), + transformExpression(fun.getArguments(1))) + case _ => + lookupFunction(funName, fun.getArgumentsList.asScala.map(transformExpression).toSeq) + } + } + + private def transformUnion(u: proto.Union): LogicalPlan = { + assert(u.getInputsCount == 2, "Union must have 2 inputs") + val plan = logical.Union(transformRelation(u.getInputs(0)), transformRelation(u.getInputs(1))) + + u.getUnionType match { + case proto.Union.UnionType.UNION_TYPE_DISTINCT => logical.Distinct(plan) + case proto.Union.UnionType.UNION_TYPE_ALL => plan + case _ => + throw InvalidPlanInput(s"Unsupported set operation ${u.getUnionTypeValue}") + } + } + + private def transformJoin(rel: proto.Join): LogicalPlan = { + assert(rel.hasLeft && rel.hasRight, "Both join sides must be present") + logical.Join( + left = transformRelation(rel.getLeft), + right = transformRelation(rel.getRight), + // TODO(SPARK-40534) Support additional join types and configuration. 
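+      // For now every proto Join is planned as a Catalyst inner join; a missing or
+      // unsupported `on` expression fails in transformExpression below.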
+ joinType = plans.Inner, + condition = Some(transformExpression(rel.getOn)), + hint = logical.JoinHint.NONE) + } + + private def transformSort(rel: proto.Sort): LogicalPlan = { + assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.") + logical.Sort( + child = transformRelation(rel.getInput), + global = true, + order = rel.getSortFieldsList.asScala.map(transformSortOrderExpression).toSeq) + } + + private def transformSortOrderExpression(so: proto.Sort.SortField): expressions.SortOrder = { + expressions.SortOrder( + child = transformUnresolvedExpression(so.getExpression), + direction = so.getDirection match { + case proto.Sort.SortDirection.SORT_DIRECTION_DESCENDING => expressions.Descending + case _ => expressions.Ascending + }, + nullOrdering = so.getNulls match { + case proto.Sort.SortNulls.SORT_NULLS_LAST => expressions.NullsLast + case _ => expressions.NullsFirst + }, + sameOrderExpressions = Seq.empty) + } + + private def transformAggregate(rel: proto.Aggregate): LogicalPlan = { + assert(rel.hasInput) + assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported") + + val groupingSet = rel.getGroupingSetsList.asScala.take(1) + val ge = groupingSet + .flatMap(f => f.getAggregateExpressionsList.asScala) + .map(transformExpression) + .map { + case x @ UnresolvedAttribute(_) => x + case x => UnresolvedAlias(x) + } + + logical.Aggregate( + child = transformRelation(rel.getInput), + groupingExpressions = ge.toSeq, + aggregateExpressions = + (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq) + } + + private def transformAggregateExpression( + exp: proto.Aggregate.Measure): expressions.NamedExpression = { + val fun = exp.getFunction.getName + UnresolvedAlias( + UnresolvedFunction( + name = fun, + arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq, + isDistinct = false)) + } + +} diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala new file mode 100644 index 000000000000..3357ad26f6c9 --- /dev/null +++ b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import com.google.common.base.Ticker +import com.google.common.cache.CacheBuilder +import io.grpc.{Server, Status} +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder +import io.grpc.protobuf.services.ProtoReflectionService +import io.grpc.stub.StreamObserver + +import org.apache.spark.SparkEnv +import org.apache.spark.annotation.{Since, Unstable} +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, SparkConnectServiceGrpc} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.execution.ExtendedMode + +/** + * The SparkConnectService implementation. + * + * This class implements the service stub from the generated code of GRPC. + * + * @param debug + * delegates debug behavior to the handlers. + */ +@Unstable +@Since("3.4.0") +class SparkConnectService( + debug: Boolean) + extends SparkConnectServiceGrpc.SparkConnectServiceImplBase + with Logging { + + /** + * This is the main entry method for Spark Connect and all calls to execute a plan. + * + * The plan execution is delegated to the [[SparkConnectStreamHandler]]. All error handling + * should be directly implemented in the deferred implementation. But this method catches + * generic errors. + * + * @param request + * @param responseObserver + */ + override def executePlan(request: Request, responseObserver: StreamObserver[Response]): Unit = { + try { + new SparkConnectStreamHandler(responseObserver).handle(request) + } catch { + case e: Throwable => + log.error("Error executing plan.", e) + responseObserver.onError( + Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) + } + } + + /** + * Analyze a plan to provide metadata and debugging information. + * + * This method is called to generate the explain plan for a SparkConnect plan. In its simplest + * implementation, the plan that is generated by the [[SparkConnectPlanner]] is used to build a + * [[Dataset]] and derive the explain string from the query execution details. + * + * Errors during planning are returned via the [[StreamObserver]] interface. 
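+   * The resulting [[AnalyzeResponse]] echoes the client id and carries the resolved column
+   * names and types together with the extended explain string of the plan.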
+ * + * @param request + * @param responseObserver + */ + override def analyzePlan( + request: Request, + responseObserver: StreamObserver[AnalyzeResponse]): Unit = { + try { + val session = + SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session + + val logicalPlan = request.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.ROOT => + new SparkConnectPlanner(request.getPlan.getRoot, session).transform() + case _ => + responseObserver.onError( + new UnsupportedOperationException( + s"${request.getPlan.getOpTypeCase} not supported for analysis.")) + return + } + val ds = Dataset.ofRows(session, logicalPlan) + val explainString = ds.queryExecution.explainString(ExtendedMode) + + val resp = proto.AnalyzeResponse + .newBuilder() + .setExplainString(explainString) + .setClientId(request.getClientId) + + resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava) + resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava) + responseObserver.onNext(resp.build()) + responseObserver.onCompleted() + } catch { + case e: Throwable => + log.error("Error analyzing plan.", e) + responseObserver.onError( + Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) + } + } +} + +/** + * Object used for referring to SparkSessions in the SessionCache. + * + * @param userId + * @param session + */ +@Unstable +@Since("3.4.0") +private[connect] case class SessionHolder(userId: String, session: SparkSession) + +/** + * Static instance of the SparkConnectService. + * + * Used to start the overall SparkConnect service and provides global state to manage the + * different SparkSession from different users connecting to the cluster. + */ +@Unstable +@Since("3.4.0") +object SparkConnectService { + + private val CACHE_SIZE = 100 + + private val CACHE_TIMEOUT_SECONDS = 3600 + + // Type alias for the SessionCacheKey. Right now this is a String but allows us to switch to a + // different or complex type easily. + private type SessionCacheKey = String; + + private var server: Server = _ + + private val userSessionMapping = + cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, SessionHolder]() + + // Simple builder for creating the cache of Sessions. + private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): CacheBuilder[Object, Object] = { + var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker()) + if (cacheSize >= 0) { + cacheBuilder = cacheBuilder.maximumSize(cacheSize) + } + if (timeoutSeconds >= 0) { + cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS) + } + cacheBuilder + } + + /** + * Based on the `key` find or create a new SparkSession. + */ + private[connect] def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = { + userSessionMapping.get( + key, + () => { + SessionHolder(key, newIsolatedSession()) + }) + } + + private def newIsolatedSession(): SparkSession = { + SparkSession.active.newSession() + } + + /** + * Starts the GRPC Serivce. + * + * TODO(SPARK-40536) Make port number configurable. + */ + def startGRPCService(): Unit = { + val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true) + val port = 15002 + val sb = NettyServerBuilder + .forPort(port) + .addService(new SparkConnectService(debugMode)) + + // If debug mode is configured, load the ProtoReflection service so that tools like + // grpcurl can introspect the API for debugging. 
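+    // For example (illustrative), `grpcurl -plaintext localhost:15002 list` would then
+    // enumerate the exposed services, including spark.connect.SparkConnectService.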
+ if (debugMode) { + sb.addService(ProtoReflectionService.newInstance()) + } + server = sb.build + server.start() + } + + // Starts the service + def start(): Unit = { + startGRPCService() + } + + def stop(): Unit = { + if (server != null) { + server.shutdownNow() + } + } +} + diff --git a/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala new file mode 100644 index 000000000000..52b807f63bb0 --- /dev/null +++ b/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import scala.collection.JavaConverters._ + +import com.google.protobuf.ByteString +import io.grpc.stub.StreamObserver + +import org.apache.spark.annotation.{Since, Unstable} +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{Request, Response} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.ArrowUtils + + +@Unstable +@Since("3.4.0") +class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { + + def handle(v: Request): Unit = { + val session = + SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session + v.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v) + case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v) + case _ => + throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.") + } + } + + def handlePlan(session: SparkSession, request: proto.Request): Unit = { + // Extract the plan from the request and convert it to a logical plan + val planner = new SparkConnectPlanner(request.getPlan.getRoot, session) + val rows = + Dataset.ofRows(session, planner.transform()) + processRows(request.getClientId, rows) + } + + private def processRows(clientId: String, rows: DataFrame) = { + val timeZoneId = SQLConf.get.sessionLocalTimeZone + val schema = + ByteString.copyFrom(ArrowUtils.toArrowSchema(rows.schema, timeZoneId).toByteArray) + + val textSchema = rows.schema.fields.map(f => f.name).mkString("|") + val data = rows.collect().map(x => x.toSeq.mkString("|")).mkString("\n") + val bbb = 
proto.Response.CSVBatch.newBuilder + .setRowCount(-1) + .setData(textSchema ++ "\n" ++ data) + .build() + val response = proto.Response.newBuilder().setClientId(clientId).setCsvBatch(bbb).build() + + // Send all the data + responseObserver.onNext(response) + responseObserver.onNext(sendMetricsToResponse(clientId, rows)) + responseObserver.onCompleted() + } + + def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = { + // Send a last batch with the metrics + Response + .newBuilder() + .setClientId(clientId) + .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan)) + .build() + } + + def handleCommand(session: SparkSession, request: Request): Unit = { + val command = request.getPlan.getCommand + val planner = new SparkConnectCommandPlanner(session, command) + planner.process() + responseObserver.onCompleted() + } +} + +object MetricGenerator extends AdaptiveSparkPlanHelper { + def buildMetrics(p: SparkPlan): Response.Metrics = { + val b = Response.Metrics.newBuilder + b.addAllMetrics(transformPlan(p, p.id).asJava) + b.build() + } + + def transformChildren(p: SparkPlan): Seq[Response.Metrics.MetricObject] = { + allChildren(p).flatMap(c => transformPlan(c, p.id)) + } + + def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { + case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) + case s: QueryStageExec => Seq(s.plan) + case _ => p.children + } + + def transformPlan(p: SparkPlan, parentId: Int): Seq[Response.Metrics.MetricObject] = { + val mv = p.metrics.map(m => + m._1 -> Response.Metrics.MetricValue.newBuilder + .setName(m._2.name.getOrElse("")) + .setValue(m._2.value) + .setMetricType(m._2.metricType) + .build()) + val mo = Response.Metrics.MetricObject + .newBuilder() + .setName(p.nodeName) + .setPlanId(p.id) + .putAllExecutionMetrics(mv.asJava) + .build() + Seq(mo) ++ transformChildren(p) + } + +} diff --git a/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala new file mode 100644 index 000000000000..3c18994ba7f8 --- /dev/null +++ b/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkFunSuite, TestUtils} +import org.apache.spark.connect.proto +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * Testing trait for SparkConnect tests with some helper methods to make it easier to create new + * test cases. 
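+ * For example, `transform(readRel)` below is expected to resolve into an `UnresolvedRelation`
+ * for the table named in [[readRel]].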
+ */ +trait SparkConnectPlanTest { + def transform(rel: proto.Relation): LogicalPlan = { + new SparkConnectPlanner(rel, None.orNull).transform() + } + + def readRel: proto.Relation = + proto.Relation + .newBuilder() + .setRead( + proto.Read + .newBuilder() + .setNamedTable(proto.Read.NamedTable.newBuilder().addParts("table")) + .build()) + .build() +} + +trait SparkConnectSessionTest { + protected var spark: SparkSession + +} + +/** + * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to + * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are + * generated. + */ +class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { + + protected var spark: SparkSession = null + + override def beforeAll(): Unit = { + super.beforeAll() + TestUtils.configTestLog4j2("INFO") + } + + test("Simple Limit") { + assertThrows[IndexOutOfBoundsException] { + new SparkConnectPlanner( + proto.Relation.newBuilder.setFetch(proto.Fetch.newBuilder.setLimit(10).build()).build(), + None.orNull) + .transform() + } + } + + test("InvalidInputs") { + // No Relation Set + intercept[IndexOutOfBoundsException]( + new SparkConnectPlanner(proto.Relation.newBuilder().build(), None.orNull).transform()) + + intercept[InvalidPlanInput]( + new SparkConnectPlanner( + proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build(), + None.orNull).transform()) + + } + + test("Simple Read") { + val read = proto.Read.newBuilder().build() + // Invalid read without Table name. + intercept[InvalidPlanInput](transform(proto.Relation.newBuilder.setRead(read).build())) + val readWithTable = read.toBuilder + .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build()) + .build() + val res = transform(proto.Relation.newBuilder.setRead(readWithTable).build()) + assert(res !== null) + assert(res.nodeName == "UnresolvedRelation") + } + + test("Simple Sort") { + val sort = proto.Sort.newBuilder + .addAllSortFields(Seq(proto.Sort.SortField.newBuilder().build()).asJava) + .build() + intercept[IndexOutOfBoundsException]( + transform(proto.Relation.newBuilder().setSort(sort).build()), + "No Input set.") + + val f = proto.Sort.SortField + .newBuilder() + .setNulls(proto.Sort.SortNulls.SORT_NULLS_LAST) + .setDirection(proto.Sort.SortDirection.SORT_DIRECTION_DESCENDING) + .setExpression(proto.Expression.newBuilder + .setUnresolvedAttribute( + proto.Expression.UnresolvedAttribute.newBuilder.addAllParts(Seq("col").asJava).build()) + .build()) + .build() + + val res = transform( + proto.Relation.newBuilder + .setSort(proto.Sort.newBuilder.addAllSortFields(Seq(f).asJava).setInput(readRel)) + .build()) + assert(res.nodeName == "Sort") + } + + test("Simple Union") { + intercept[AssertionError]( + transform(proto.Relation.newBuilder.setUnion(proto.Union.newBuilder.build()).build)) + val union = proto.Relation.newBuilder + .setUnion(proto.Union.newBuilder.addAllInputs(Seq(readRel, readRel).asJava).build()) + .build() + val msg = intercept[InvalidPlanInput] { + transform(union) + } + assert(msg.getMessage.contains("Unsupported set operation")) + + val res = transform( + proto.Relation.newBuilder + .setUnion( + proto.Union.newBuilder + .addAllInputs(Seq(readRel, readRel).asJava) + .setUnionType(proto.Union.UnionType.UNION_TYPE_ALL) + .build()) + .build()) + assert(res.nodeName == "Union") + } + + test("Simple Join") { + + val incompleteJoin = + proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build() + 
intercept[AssertionError](transform(incompleteJoin)) + + // Cartesian Product not supported. + intercept[InvalidPlanInput] { + val simpleJoin = proto.Relation.newBuilder + .setJoin(proto.Join.newBuilder.setLeft(readRel).setRight(readRel)) + .build() + transform(simpleJoin) + } + + // Construct a simple Join. + val unresolvedAttribute = proto.Expression + .newBuilder() + .setUnresolvedAttribute( + proto.Expression.UnresolvedAttribute.newBuilder().addAllParts(Seq("left").asJava).build()) + .build() + + val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction( + proto.Expression.UnresolvedFunction.newBuilder + .addAllParts(Seq("eq").asJava) + .addArguments(unresolvedAttribute) + .addArguments(unresolvedAttribute) + .build()) + + val simpleJoin = proto.Relation.newBuilder + .setJoin( + proto.Join.newBuilder.setLeft(readRel).setRight(readRel).setOn(joinCondition).build()) + .build() + + val res = transform(simpleJoin) + assert(res.nodeName == "Join") + assert(res != null) + + } + + test("Simple Projection") { + val project = proto.Project.newBuilder + .setInput(readRel) + .addExpressions( + proto.Expression.newBuilder + .setLiteral(proto.Expression.Literal.newBuilder.setI32(32)) + .build()) + .build() + + val res = transform(proto.Relation.newBuilder.setProject(project).build()) + assert(res.nodeName == "Project") + + } + + test("Simple Aggregation") { + val unresolvedAttribute = proto.Expression + .newBuilder() + .setUnresolvedAttribute( + proto.Expression.UnresolvedAttribute.newBuilder().addAllParts(Seq("left").asJava).build()) + .build() + + val agg = proto.Aggregate.newBuilder + .setInput(readRel) + .addAllMeasures( + Seq( + proto.Aggregate.Measure.newBuilder + .setFunction(proto.Aggregate.AggregateFunction.newBuilder + .setName("sum") + .addArguments(unresolvedAttribute)) + .build()).asJava) + .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder + .addAggregateExpressions(unresolvedAttribute) + .build()) + .build() + + val res = transform(proto.Relation.newBuilder.setAggregate(agg).build()) + assert(res.nodeName == "Aggregate") + } + +} diff --git a/dev/.rat-excludes b/dev/.rat-excludes index e1cc000c064f..23806bcd3069 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -139,3 +139,6 @@ exported_table/* ansible-for-test-node/* node_modules spark-events-broken/* +# Spark Connect related files with custom licence +any.proto +empty.proto diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index c6555e0463db..e20319eceefa 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -42,7 +42,7 @@ ARG APT_INSTALL="apt-get install --no-install-recommends -y" # We should use the latest Sphinx version once this is fixed. # TODO(SPARK-35375): Jinja2 3.0.0+ causes error when building with Sphinx. # See also https://issues.apache.org/jira/browse/SPARK-35375. -ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17" +ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.48.1 protobuf==4.21.5" ARG GEM_PKGS="bundler:2.2.9" # Install extra needed repos and refresh. 
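The planner suite above drives SparkConnectPlanner with hand-built Relation protos, and the pip pins added here (grpcio, protobuf) exist because the Python client introduced later in this patch needs them at runtime. As an orientation sketch only (not part of the patch), the same "read a named table" message can be assembled from the generated Python stubs under `pyspark.sql.connect.proto`:

```python
# Sketch only: mirrors the Read relation the Scala planner tests construct, using the
# generated Python protos. Needs the protobuf runtime pinned above; grpcio is only
# required once the gRPC service is actually called.
import pyspark.sql.connect.proto as proto

rel = proto.Relation()
rel.read.named_table.parts.append("table")   # equivalent of Read.NamedTable.addParts("table")

plan = proto.Plan()
plan.root.CopyFrom(rel)                      # ROOT plans are dispatched to handlePlan on the server

print(plan)                                  # protobuf text format of the plan
```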
diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index bb9a8cad170a..4b04a4997d82 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -6,7 +6,9 @@ ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar aircompressor/0.21//aircompressor-0.21.jar algebra_2.12/2.0.1//algebra_2.12-2.0.1.jar +animal-sniffer-annotations/1.19//animal-sniffer-annotations-1.19.jar annotations/17.0.0//annotations-17.0.0.jar +annotations/4.1.1.4//annotations-4.1.1.4.jar antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar antlr4-runtime/4.8//antlr4-runtime-4.8.jar aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar @@ -63,10 +65,20 @@ datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar derby/10.14.2.0//derby-10.14.2.0.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar +error_prone_annotations/2.10.0//error_prone_annotations-2.10.0.jar +failureaccess/1.0.1//failureaccess-1.0.1.jar flatbuffers-java/1.12.0//flatbuffers-java-1.12.0.jar gcs-connector/hadoop2-2.2.7/shaded/gcs-connector-hadoop2-2.2.7-shaded.jar generex/1.0.2//generex-1.0.2.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar +grpc-api/1.47.0//grpc-api-1.47.0.jar +grpc-context/1.47.0//grpc-context-1.47.0.jar +grpc-core/1.47.0//grpc-core-1.47.0.jar +grpc-netty-shaded/1.47.0//grpc-netty-shaded-1.47.0.jar +grpc-protobuf-lite/1.47.0//grpc-protobuf-lite-1.47.0.jar +grpc-protobuf/1.47.0//grpc-protobuf-1.47.0.jar +grpc-services/1.47.0//grpc-services-1.47.0.jar +grpc-stub/1.47.0//grpc-stub-1.47.0.jar gson/2.2.4//gson-2.2.4.jar guava/14.0.1//guava-14.0.1.jar guice-servlet/3.0//guice-servlet-3.0.jar @@ -112,6 +124,7 @@ httpclient/4.5.13//httpclient-4.5.13.jar httpcore/4.4.14//httpcore-4.4.14.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.0//ivy-2.5.0.jar +j2objc-annotations/1.3//j2objc-annotations-1.3.jar jackson-annotations/2.13.4//jackson-annotations-2.13.4.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.13.4//jackson-core-2.13.4.jar @@ -231,7 +244,10 @@ parquet-encoding/1.12.3//parquet-encoding-1.12.3.jar parquet-format-structures/1.12.3//parquet-format-structures-1.12.3.jar parquet-hadoop/1.12.3//parquet-hadoop-1.12.3.jar parquet-jackson/1.12.3//parquet-jackson-1.12.3.jar +perfmark-api/0.25.0//perfmark-api-0.25.0.jar pickle/1.2//pickle-1.2.jar +proto-google-common-protos/2.0.1//proto-google-common-protos-2.0.1.jar +protobuf-java-util/3.19.2//protobuf-java-util-3.19.2.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 555579e6446f..be322eb78674 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -10,7 +10,9 @@ aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar aliyun-java-sdk-ram/3.1.0//aliyun-java-sdk-ram-3.1.0.jar aliyun-sdk-oss/3.13.0//aliyun-sdk-oss-3.13.0.jar +animal-sniffer-annotations/1.19//animal-sniffer-annotations-1.19.jar annotations/17.0.0//annotations-17.0.0.jar +annotations/4.1.1.4//annotations-4.1.1.4.jar antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar antlr4-runtime/4.8//antlr4-runtime-4.8.jar aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar @@ -60,10 +62,20 @@ 
datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar derby/10.14.2.0//derby-10.14.2.0.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar +error_prone_annotations/2.10.0//error_prone_annotations-2.10.0.jar +failureaccess/1.0.1//failureaccess-1.0.1.jar flatbuffers-java/1.12.0//flatbuffers-java-1.12.0.jar gcs-connector/hadoop3-2.2.7/shaded/gcs-connector-hadoop3-2.2.7-shaded.jar generex/1.0.2//generex-1.0.2.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar +grpc-api/1.47.0//grpc-api-1.47.0.jar +grpc-context/1.47.0//grpc-context-1.47.0.jar +grpc-core/1.47.0//grpc-core-1.47.0.jar +grpc-netty-shaded/1.47.0//grpc-netty-shaded-1.47.0.jar +grpc-protobuf-lite/1.47.0//grpc-protobuf-lite-1.47.0.jar +grpc-protobuf/1.47.0//grpc-protobuf-1.47.0.jar +grpc-services/1.47.0//grpc-services-1.47.0.jar +grpc-stub/1.47.0//grpc-stub-1.47.0.jar gson/2.2.4//gson-2.2.4.jar guava/14.0.1//guava-14.0.1.jar hadoop-aliyun/3.3.4//hadoop-aliyun-3.3.4.jar @@ -100,6 +112,7 @@ httpcore/4.4.14//httpcore-4.4.14.jar ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.0//ivy-2.5.0.jar +j2objc-annotations/1.3//j2objc-annotations-1.3.jar jackson-annotations/2.13.4//jackson-annotations-2.13.4.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.13.4//jackson-core-2.13.4.jar @@ -218,7 +231,10 @@ parquet-encoding/1.12.3//parquet-encoding-1.12.3.jar parquet-format-structures/1.12.3//parquet-format-structures-1.12.3.jar parquet-hadoop/1.12.3//parquet-hadoop-1.12.3.jar parquet-jackson/1.12.3//parquet-jackson-1.12.3.jar +perfmark-api/0.25.0//perfmark-api-0.25.0.jar pickle/1.2//pickle-1.2.jar +proto-google-common-protos/2.0.1//proto-google-common-protos-2.0.1.jar +protobuf-java-util/3.19.2//protobuf-java-util-3.19.2.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 2bf0be3822db..1fc885edb520 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -65,3 +65,6 @@ RUN Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='ht # See more in SPARK-39735 ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" + +# Add Python deps for Spark Connect. +RUN python3.9 -m pip install grpcio protobuf diff --git a/dev/requirements.txt b/dev/requirements.txt index 7771b97a7320..7803e4f0736b 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -45,3 +45,7 @@ PyGithub # pandas API on Spark Code formatter. 
black==22.6.0 + +# Spark Connect +grpcio==1.48.1 +protobuf==4.21.5 \ No newline at end of file diff --git a/dev/run-tests.py b/dev/run-tests.py index 4cfb0639ed3a..fc34b4c61b0a 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -251,6 +251,7 @@ def build_spark_sbt(extra_profiles): sbt_goals = [ "Test/package", # Build test jars as some tests depend on them "streaming-kinesis-asl-assembly/assembly", + "connect/assembly", # Build Spark Connect assembly ] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2b9d52693794..422d353ec044 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -472,6 +472,24 @@ def __hash__(self): ], ) +pyspark_sql = Module( + name="pyspark-sql-connect", + dependencies=[pyspark_core, hive, avro], + source_file_regexes=["python/pyspark/sql/connect"], + python_test_goals=[ + # doctests + # No doctests yet. + # unittests + "pyspark.sql.tests.connect.test_column_expressions", + "pyspark.sql.tests.connect.test_plan_only", + "pyspark.sql.tests.connect.test_select_ops", + "pyspark.sql.tests.connect.test_spark_connect", + ], + excluded_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and + # they aren't available there + ], +) pyspark_resource = Module( name="pyspark-resource", diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py index 94928fa8730c..800b7e0d9325 100755 --- a/dev/sparktestsupport/utils.py +++ b/dev/sparktestsupport/utils.py @@ -109,22 +109,22 @@ def determine_modules_to_test(changed_modules, deduplicated=True): >>> [x.name for x in determine_modules_to_test([modules.sql])] ... # doctest: +NORMALIZE_WHITESPACE ['sql', 'avro', 'docker-integration-tests', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', - 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', + 'hive-thriftserver', 'pyspark-sql', 'pyspark-sql-connect', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-ml'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sparkr, modules.sql], deduplicated=False)]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', 'mllib', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-sql', - 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'pyspark-sql-connect', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sql, modules.core], deduplicated=False)]) ... 
# doctest: +NORMALIZE_WHITESPACE ['avro', 'catalyst', 'core', 'docker-integration-tests', 'examples', 'graphx', 'hive', 'hive-thriftserver', 'mllib', 'mllib-local', 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-resource', 'pyspark-sql', - 'pyspark-streaming', 'repl', 'root', 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', - 'streaming-kafka-0-10', 'streaming-kinesis-asl'] + 'pyspark-sql-connect', 'pyspark-streaming', 'repl', 'root', 'sparkr', 'sql', + 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', 'streaming-kinesis-asl'] """ modules_to_test = set() for module in changed_modules: diff --git a/dev/tox.ini b/dev/tox.ini index 464b9b959fa1..f44cbe54ddf6 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -51,4 +51,5 @@ exclude = python/pyspark/worker.pyi, python/pyspark/java_gateway.pyi, dev/ansible-for-test-node/*, + python/pyspark/sql/connect/proto/*, max-line-length = 100 diff --git a/pom.xml b/pom.xml index 9f382692e536..5cac5e30bde8 100644 --- a/pom.xml +++ b/pom.xml @@ -100,6 +100,7 @@ connector/kafka-0-10-assembly connector/kafka-0-10-sql connector/avro + connect diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0c887e0e70ed..e7e24954eb16 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -39,6 +39,8 @@ import sbtassembly.AssemblyPlugin.autoImport._ import spray.revolver.RevolverPlugin._ +import sbtprotoc.ProtocPlugin.autoImport._ + object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile @@ -50,12 +52,14 @@ object BuildCommons { val streamingProjects@Seq(streaming, streamingKafka010) = Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) + val connect = ProjectRef(buildLocation, "connect") + val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "tags", "sketch", "kvstore" - ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, sparkGangliaLgpl, streamingKinesisAsl, @@ -79,6 +83,11 @@ object BuildCommons { val testTempDir = s"$sparkHome/target/tmp" val javaVersion = settingKey[String]("source and target JVM version for javac and scalac") + + // Google Protobuf version used for generating the protobuf. + val protoVersion = "3.21.1" + // GRPC version used for Spark Connect. + val gprcVersion = "1.47.0" } object SparkBuild extends PomBuild { @@ -357,7 +366,11 @@ object SparkBuild extends PomBuild { // To prevent intermittent compilation failures, see also SPARK-33297 // Apparently we can remove this when we use JDK 11. - Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat + Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat, + + // Setting version for the protobuf compiler. This has to be propagated to every sub-project + // even if the project is not using it. 
+ PB.protocVersion := protoVersion, ) def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { @@ -377,7 +390,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, tokenProviderKafka010, sqlKafka010 + unsafe, tags, tokenProviderKafka010, sqlKafka010, connect ).contains(x) } @@ -418,6 +431,8 @@ object SparkBuild extends PomBuild { /* Hive console settings */ enable(Hive.settings)(hive) + enable(SparkConnect.settings)(connect) + // SPARK-14738 - Remove docker tests from main Spark build // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) @@ -593,6 +608,60 @@ object Core { ) } + +object SparkConnect { + + import BuildCommons.protoVersion + + private val shadePrefix = "org.sparkproject.connect" + val shadeJar = taskKey[Unit]("Shade the Jars") + + lazy val settings = Seq( + // Setting version for the protobuf compiler. This has to be propagated to every sub-project + // even if the project is not using it. + PB.protocVersion := BuildCommons.protoVersion, + + // For some reason the resolution from the imported Maven build does not work for some + // of these dependendencies that we need to shade later on. + libraryDependencies ++= Seq( + "io.grpc" % "protoc-gen-grpc-java" % BuildCommons.gprcVersion asProtocPlugin(), + "org.scala-lang" % "scala-library" % "2.12.16" % "provided", + "com.google.guava" % "guava" % "31.0.1-jre", + "com.google.guava" % "failureaccess" % "1.0.1", + "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" + ), + + dependencyOverrides ++= Seq( + "com.google.guava" % "guava" % "31.0.1-jre", + "com.google.guava" % "failureaccess" % "1.0.1", + "com.google.protobuf" % "protobuf-java" % protoVersion + ), + + (Compile / PB.targets) := Seq( + PB.gens.java -> (Compile / sourceManaged).value, + PB.gens.plugin("grpc-java") -> (Compile / sourceManaged).value + ), + + (assembly / test) := false, + + (assembly / logLevel) := Level.Info, + + (assembly / assemblyShadeRules) := Seq( + ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@0").inAll, + ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll, + ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll, + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll, + ), + + (assembly / assemblyMergeStrategy) := { + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard + // Drop all proto files that are not needed as artifacts of the build. + case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard + case _ => MergeStrategy.first + }, + ) +} + object Unsafe { lazy val settings = Seq( // This option is needed to suppress warnings from sun.misc.Unsafe usage @@ -741,6 +810,8 @@ object ExcludedDependencies { */ object OldDeps { + import BuildCommons.protoVersion + lazy val project = Project("oldDeps", file("dev")) .settings(oldDepsSettings) .disablePlugins(com.typesafe.sbt.pom.PomReaderPlugin) @@ -753,6 +824,9 @@ object OldDeps { } def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( + // Setting version for the protobuf compiler. This has to be propagated to every sub-project + // even if the project is not using it. 
+ PB.protocVersion := protoVersion, name := "old-deps", libraryDependencies := allPreviousArtifactKeys.value.flatten ) @@ -1033,10 +1107,10 @@ object Unidoc { (ScalaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010), + yarn, tags, streamingKafka010, sqlKafka010, connect), (JavaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010), + yarn, tags, streamingKafka010, sqlKafka010, connect), (ScalaUnidoc / unidoc / unidocAllClasspaths) := { ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value) @@ -1118,14 +1192,25 @@ object CopyDependencies { throw new IOException("Failed to create jars directory.") } + // For the SparkConnect build, we manually call the assembly target to + // produce the shaded Jar which happens automatically in the case of Maven. + // Later, when the dependencies are copied, we manually copy the shaded Jar only. + val fid = (LocalProject("connect")/assembly).value + (Compile / dependencyClasspath).value.map(_.data) .filter { jar => jar.isFile() } + // Do not copy the Spark Connect JAR as it is unshaded in the SBT build. .foreach { jar => val destJar = new File(dest, jar.getName()) if (destJar.isFile()) { destJar.delete() } - Files.copy(jar.toPath(), destJar.toPath()) + if (jar.getName.contains("spark-connect") && + !SbtPomKeys.profiles.value.contains("noshade-connect")) { + Files.copy(fid.toPath, destJar.toPath) + } else { + Files.copy(jar.toPath(), destJar.toPath()) + } } }, (Compile / packageBin / crossTarget) := destPath.value, diff --git a/project/plugins.sbt b/project/plugins.sbt index 9f55e21dc9d2..19f9da7351dc 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -44,3 +44,5 @@ libraryDependencies += "org.ow2.asm" % "asm-commons" % "9.3" addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3") addSbtPlugin("com.typesafe.sbt" % "sbt-pom-reader" % "2.2.0") + +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.1") diff --git a/python/mypy.ini b/python/mypy.ini index efaa3dc97d3c..4f015641af28 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -23,6 +23,18 @@ show_error_codes = True warn_unused_ignores = True warn_redundant_casts = True +; TODO(SPARK-40537) reenable mypi support. +[mypy-pyspark.sql.connect.*] +disallow_untyped_defs = False +ignore_missing_imports = True +ignore_errors = True + +; TODO(SPARK-40537) reenable mypi support. +[mypy-pyspark.sql.tests.connect.*] +disallow_untyped_defs = False +ignore_missing_imports = True +ignore_errors = True + ; Allow untyped def in internal modules and tests [mypy-pyspark.daemon] @@ -138,3 +150,10 @@ ignore_missing_imports = True [mypy-tabulate.*] ignore_missing_imports = True + +[mypy-google.protobuf.*] +ignore_missing_imports = True + +; Ignore errors for proto generated code +[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto] +ignore_errors = True diff --git a/python/pyspark/sql/connect/README.md b/python/pyspark/sql/connect/README.md new file mode 100644 index 000000000000..e79e9aae9dd2 --- /dev/null +++ b/python/pyspark/sql/connect/README.md @@ -0,0 +1,38 @@ + +# [EXPERIMENTAL] Spark Connect + +**Spark Connect is a strictly experimental feature and under heavy development. 
+All APIs should be considered volatile and should not be used in production.** + +This module contains the implementation of Spark Connect which is a logical plan +facade for the implementation in Spark. Spark Connect is directly integrated into the build +of Spark. To enable it, you only need to activate the driver plugin for Spark Connect. + + + + +## Build + +1. Build Spark as usual per the documentation. +2. Build and package the Spark Connect package + ```bash + ./build/mvn -Phive package + ``` + or + ```shell + ./build/sbt -Phive package + ``` + +## Run Spark Shell + +```bash +./bin/spark-shell --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin +``` + +## Run Tests + + +```bash +./run-tests --testnames 'pyspark.sql.tests.connect.test_spark_connect' +``` + diff --git a/python/pyspark/sql/connect/__init__.py b/python/pyspark/sql/connect/__init__.py new file mode 100644 index 000000000000..c748f8f6590e --- /dev/null +++ b/python/pyspark/sql/connect/__init__.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Currently Spark Connect is very experimental and the APIs to interact with +Spark through this API are can be changed at any time without warning.""" + + +from pyspark.sql.connect.data_frame import DataFrame # noqa: F401 diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py new file mode 100644 index 000000000000..df42411d42cb --- /dev/null +++ b/python/pyspark/sql/connect/client.py @@ -0,0 +1,180 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import io +import logging +import typing +import uuid + +import grpc +import pandas +import pandas as pd +import pyarrow as pa + +import pyspark.sql.connect.proto as pb2 +import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib +from pyspark import cloudpickle +from pyspark.sql.connect.data_frame import DataFrame +from pyspark.sql.connect.readwriter import DataFrameReader +from pyspark.sql.connect.plan import SQL + + +NumericType = typing.Union[int, float] + +logging.basicConfig(level=logging.INFO) + + +class MetricValue: + def __init__(self, name: str, value: NumericType, type: str): + self._name = name + self._type = type + self._value = value + + def __repr__(self) -> str: + return f"<{self._name}={self._value} ({self._type})>" + + @property + def name(self) -> str: + return self._name + + @property + def value(self) -> NumericType: + return self._value + + @property + def metric_type(self) -> str: + return self._type + + +class PlanMetrics: + def __init__(self, name: str, id: str, parent: str, metrics: typing.List[MetricValue]): + self._name = name + self._id = id + self._parent_id = parent + self._metrics = metrics + + def __repr__(self) -> str: + return f"Plan({self._name})={self._metrics}" + + @property + def name(self) -> str: + return self._name + + @property + def plan_id(self) -> str: + return self._id + + @property + def parent_plan_id(self) -> str: + return self._parent_id + + @property + def metrics(self) -> typing.List[MetricValue]: + return self._metrics + + +class AnalyzeResult: + def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str): + self.cols = cols + self.types = types + self.explain_string = explain + + @classmethod + def fromProto(cls, pb: typing.Any) -> "AnalyzeResult": + return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string) + + +class RemoteSparkSession(object): + """Conceptually the remote spark session that communicates with the server""" + + def __init__(self, user_id: str, host: str = None, port: int = 15002): + self._host = "localhost" if host is None else host + self._port = port + self._user_id = user_id + self._channel = grpc.insecure_channel(f"{self._host}:{self._port}") + self._stub = grpc_lib.SparkConnectServiceStub(self._channel) + + # Create the reader + self.read = DataFrameReader(self) + + def register_udf(self, function, return_type) -> str: + """Create a temporary UDF in the session catalog on the other side. 
We generate a + temporary name for it.""" + name = f"fun_{uuid.uuid4().hex}" + fun = pb2.CreateScalarFunction() + fun.parts.append(name) + fun.serialized_function = cloudpickle.dumps((function, return_type)) + + req = pb2.Request() + req.user_context.user_id = self._user_id + req.plan.command.create_function.CopyFrom(fun) + + self._execute_and_fetch(req) + return name + + def _build_metrics(self, metrics: "pb2.Response.Metrics") -> typing.List[PlanMetrics]: + return [ + PlanMetrics( + x.name, + x.plan_id, + x.parent, + [MetricValue(k, v.value, v.metric_type) for k, v in x.execution_metrics.items()], + ) + for x in metrics.metrics + ] + + def sql(self, sql_string: str) -> "DataFrame": + return DataFrame.withPlan(SQL(sql_string), self) + + def collect(self, plan: pb2.Plan) -> pandas.DataFrame: + req = pb2.Request() + req.user_context.user_id = self._user_id + req.plan.CopyFrom(plan) + return self._execute_and_fetch(req) + + def analyze(self, plan: pb2.Plan) -> AnalyzeResult: + req = pb2.Request() + req.user_context.user_id = self._user_id + req.plan.CopyFrom(plan) + + resp = self._stub.AnalyzePlan(req) + return AnalyzeResult.fromProto(resp) + + def _process_batch(self, b) -> pandas.DataFrame: + if b.batch is not None and len(b.batch.data) > 0: + with pa.ipc.open_stream(b.data) as rd: + return rd.read_pandas() + elif b.csv_batch is not None and len(b.csv_batch.data) > 0: + return pd.read_csv(io.StringIO(b.csv_batch.data), delimiter="|") + + def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]: + m = None + result_dfs = [] + + for b in self._stub.ExecutePlan(req): + if b.metrics is not None: + m = b.metrics + result_dfs.append(self._process_batch(b)) + + if len(result_dfs) > 0: + df = pd.concat(result_dfs) + # Attach the metrics to the DataFrame attributes. + df.attrs["metrics"] = self._build_metrics(m) + return df + else: + return None diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py new file mode 100644 index 000000000000..391771f14fc9 --- /dev/null +++ b/python/pyspark/sql/connect/column.py @@ -0,0 +1,181 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Union, cast, get_args, TYPE_CHECKING + +import pyspark.sql.connect.proto as proto + +PrimitiveType = Union[str, int, bool, float] +ExpressionOrString = Union[str, "Expression"] +ColumnOrString = Union[str, "ColumnRef"] + +if TYPE_CHECKING: + from pyspark.sql.connect.client import RemoteSparkSession + import pyspark.sql.connect.proto as proto + + +class Expression(object): + """ + Expression base class. + """ + + def __init__(self) -> None: # type: ignore[name-defined] + pass + + def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": # type: ignore + ... + + def __str__(self) -> str: + ... 
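As a usage sketch for the client defined above (illustrative, not part of the patch): it assumes a Spark Connect server is already running locally via the plugin command from the README, and it uses the default port 15002 with an arbitrary user id.

```python
# Illustrative sketch: drive RemoteSparkSession against a locally running
# Spark Connect plugin (see the README's spark-shell command). The user id is
# a placeholder; host/port default to localhost:15002.
from pyspark.sql.connect.client import RemoteSparkSession

session = RemoteSparkSession(user_id="example_user")

# sql() builds a SQL plan; collect() serializes it into a Request, streams the
# Response batches back, and concatenates them into a pandas DataFrame.
pdf = session.sql("SELECT 1 AS id").collect()
print(pdf)

# The final Response message carries execution metrics, attached as DataFrame attrs.
print(pdf.attrs.get("metrics"))
```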
+ + +class LiteralExpression(Expression): + """A literal expression. + + The Python types are converted best effort into the relevant proto types. On the Spark Connect + server side, the proto types are converted to the Catalyst equivalents.""" + + def __init__(self, value: PrimitiveType) -> None: # type: ignore[name-defined] + super().__init__() + self._value = value + + def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + """Converts the literal expression to the literal in proto. + + TODO(SPARK-40533) This method always assumes the largest type and can thus + create weird interpretations of the literal.""" + value_type = type(self._value) + exp = proto.Expression() + if value_type is int: + exp.literal.i32 = cast(int, self._value) + elif value_type is str: + exp.literal.string = cast(str, self._value) + elif value_type is float: + exp.literal.fp64 = cast(float, self._value) + else: + raise ValueError(f"Could not convert literal for type {type(self._value)}") + + return exp + + def __str__(self) -> str: + return f"Literal({self._value})" + + +def _bin_op(name: str, doc: str = "binary function", reverse=False): + def _(self: "ColumnRef", other) -> Expression: + if isinstance(other, get_args(PrimitiveType)): + other = LiteralExpression(other) + if not reverse: + return ScalarFunctionExpression(name, self, other) + else: + return ScalarFunctionExpression(name, other, self) + + return _ + + +class ColumnRef(Expression): + """Represents a column reference. There is no guarantee that this column + actually exists. In the context of this project, we refer by its name and + treat it as an unresolved attribute. Attributes that have the same fully + qualified name are identical""" + + @classmethod + def from_qualified_name(cls, name) -> "ColumnRef": + return ColumnRef(*name.split(".")) + + def __init__(self, *parts: str) -> None: # type: ignore[name-defined] + super().__init__() + self._parts: List[str] = list(filter(lambda x: x is not None, list(parts))) + + def name(self) -> str: + """Returns the qualified name of the column reference.""" + return ".".join(self._parts) + + __gt__ = _bin_op("gt") + __lt__ = _bin_op("lt") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("modulo") + __radd__ = _bin_op("plus", reverse=True) + __rsub__ = _bin_op("minus", reverse=True) + __rmul__ = _bin_op("multiply", reverse=True) + __rdiv__ = _bin_op("divide", reverse=True) + __rtruediv__ = _bin_op("divide", reverse=True) + __pow__ = _bin_op("pow") + __rpow__ = _bin_op("pow", reverse=True) + __ge__ = _bin_op("greterEquals") + __le__ = _bin_op("lessEquals") + + def __eq__(self, other) -> Expression: # type: ignore[override] + """Returns a binary expression with the current column as the left + side and the other expression as the right side. 
+ """ + if isinstance(other, get_args(PrimitiveType)): + other = LiteralExpression(other) + return ScalarFunctionExpression("eq", self, other) + + def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + """Returns the Proto representation of the expression.""" + expr = proto.Expression() + expr.unresolved_attribute.parts.extend(self._parts) + return expr + + def desc(self): + return SortOrder(self, ascending=False) + + def asc(self): + return SortOrder(self, ascending=True) + + def __str__(self) -> str: + return f"Column({'.'.join(self._parts)})" + + +class SortOrder(Expression): + def __init__(self, col: ColumnRef, ascending=True, nullsLast=True) -> None: + super().__init__() + self.ref = col + self.ascending = ascending + self.nullsLast = nullsLast + + def __str__(self) -> str: + return str(self.ref) + " ASC" if self.ascending else " DESC" + + def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + return self.ref.to_plan() + + +class ScalarFunctionExpression(Expression): + def __init__( + self, + op: str, + *args: Expression, + ) -> None: + super().__init__() + self._args = args + self._op = op + + def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + fun = proto.Expression() + fun.unresolved_function.parts.append(self._op) + fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args]) + return fun + + def __str__(self) -> str: + return f"({self._op} ({', '.join([str(x) for x in self._args])}))" diff --git a/python/pyspark/sql/connect/data_frame.py b/python/pyspark/sql/connect/data_frame.py new file mode 100644 index 000000000000..b229cad19803 --- /dev/null +++ b/python/pyspark/sql/connect/data_frame.py @@ -0,0 +1,236 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, + TYPE_CHECKING, +) + +import pyspark.sql.connect.plan as plan +from pyspark.sql.connect.column import ( + ColumnOrString, + ColumnRef, + Expression, + ExpressionOrString, + LiteralExpression, +) + +if TYPE_CHECKING: + from pyspark.sql.connect.client import RemoteSparkSession + + +ColumnOrName = Union[ColumnRef, str] + + +class GroupingFrame(object): + + MeasuresType = Union[Sequence[Tuple[ExpressionOrString, str]], Dict[str, str]] + OptMeasuresType = Optional[MeasuresType] + + def __init__(self, df: "DataFrame", *grouping_cols: Union[ColumnRef, str]) -> None: + self._df = df + self._grouping_cols = [x if isinstance(x, ColumnRef) else df[x] for x in grouping_cols] + + def agg(self, exprs: MeasuresType = None) -> "DataFrame": + + # Normalize the dictionary into a list of tuples. 
+ if isinstance(exprs, Dict): + measures = list(exprs.items()) + elif isinstance(exprs, List): + measures = exprs + else: + measures = [] + + res = DataFrame.withPlan( + plan.Aggregate( + child=self._df._plan, + grouping_cols=self._grouping_cols, + measures=measures, + ), + session=self._df._session, + ) + return res + + def _map_cols_to_dict(self, fun: str, cols: List[Union[ColumnRef, str]]) -> Dict[str, str]: + return {x if isinstance(x, str) else cast(ColumnRef, x).name(): fun for x in cols} + + def min(self, *cols: Union[ColumnRef, str]) -> "DataFrame": + expr = self._map_cols_to_dict("min", list(cols)) + return self.agg(expr) + + def max(self, *cols: Union[ColumnRef, str]) -> "DataFrame": + expr = self._map_cols_to_dict("max", list(cols)) + return self.agg(expr) + + def sum(self, *cols: Union[ColumnRef, str]) -> "DataFrame": + expr = self._map_cols_to_dict("sum", list(cols)) + return self.agg(expr) + + def count(self) -> "DataFrame": + return self.agg([(LiteralExpression(1), "count")]) + + +class DataFrame(object): + """Every DataFrame object essentially is a Relation that is refined using the + member functions. Calling a method on a dataframe will essentially return a copy + of the DataFrame with the changes applied. + """ + + def __init__(self, data: List[Any] = None, schema: List[str] = None): + """Creates a new data frame""" + self._schema = schema + self._plan: Optional[plan.LogicalPlan] = None + self._cache: Dict[str, Any] = {} + self._session: "RemoteSparkSession" = None + + @classmethod + def withPlan(cls, plan: plan.LogicalPlan, session=None) -> "DataFrame": + """Main initialization method used to construct a new data frame with a child plan.""" + new_frame = DataFrame() + new_frame._plan = plan + new_frame._session = session + return new_frame + + def select(self, *cols: ColumnRef) -> "DataFrame": + return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session) + + def agg(self, exprs: Dict[str, str]) -> "DataFrame": + return self.groupBy().agg(exprs) + + def alias(self, alias): + return DataFrame.withPlan(plan.Project(self._plan).withAlias(alias), session=self._session) + + def approxQuantile(self, col, probabilities, relativeError): + ... + + def colRegex(self, regex) -> "DataFrame": + ... + + @property + def columns(self) -> List[str]: + """Returns the list of columns of the current data frame.""" + if self._plan is None: + return [] + if "columns" not in self._cache and self._plan is not None: + pdd = self.limit(0).collect() + # Translate to standard pytho array + self._cache["columns"] = pdd.columns.values + return self._cache["columns"] + + def count(self): + """Returns the number of rows in the data frame""" + return self.agg([(LiteralExpression(1), "count")]).collect().iloc[0, 0] + + def crossJoin(self, other): + ... + + def coalesce(self, num_partitions: int) -> "DataFrame": + ... + + def describe(self, cols): + ... 
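A small, server-free sketch of the expression DSL from column.py above: the overloaded operators fold ColumnRef and plain Python literals into a ScalarFunctionExpression tree that serializes directly to the Expression proto. Column names below are placeholders.

```python
# Sketch only: no session is needed to build or serialize expressions; the column
# names are unresolved attributes that the server-side planner resolves later.
from pyspark.sql.connect.column import ColumnRef

expr = (ColumnRef("price") + 1) < ColumnRef("threshold")

print(expr)
# e.g. (lt ((plus (Column(price), Literal(1))), Column(threshold)))

# ColumnRef and LiteralExpression ignore the session argument, so None is fine here.
print(expr.to_plan(None))   # nested unresolved_function / unresolved_attribute / literal nodes
```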
+ + def distinct(self) -> "DataFrame": + """Returns all distinct rows.""" + all_cols = self.columns() + gf = self.groupBy(*all_cols) + return gf.agg() + + def drop(self, *cols: ColumnOrString): + all_cols = self.columns() + dropped = set([c.name() if isinstance(c, ColumnRef) else self[c].name() for c in cols]) + filter(lambda x: x in dropped, all_cols) + + def filter(self, condition: Expression) -> "DataFrame": + return DataFrame.withPlan( + plan.Filter(child=self._plan, filter=condition), session=self._session + ) + + def first(self): + return self.head(1) + + def groupBy(self, *cols: ColumnOrString): + return GroupingFrame(self, *cols) + + def head(self, n: int): + self.limit(n) + return self.collect() + + def join(self, other, on, how=None): + return DataFrame.withPlan( + plan.Join(left=self._plan, right=other._plan, on=on, how=how), + session=self._session, + ) + + def limit(self, n): + return DataFrame.withPlan(plan.Limit(child=self._plan, limit=n), session=self._session) + + def sort(self, *cols: ColumnOrName): + """Sort by a specific column""" + return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session) + + def show(self, n: int, truncate: Optional[Union[bool, int]], vertical: Optional[bool]): + ... + + def union(self, other) -> "DataFrame": + return self.unionAll(other) + + def unionAll(self, other: "DataFrame") -> "DataFrame": + if other._plan is None: + raise ValueError("Argument to Union does not contain a valid plan.") + return DataFrame.withPlan(plan.UnionAll(self._plan, other._plan), session=self._session) + + def where(self, condition): + return self.filter(condition) + + def _get_alias(self): + p = self._plan + while p is not None: + if isinstance(p, plan.Project) and p.alias: + return p.alias + p = p._child + + def __getattr__(self, name) -> "ColumnRef": + return self[name] + + def __getitem__(self, name) -> "ColumnRef": + # Check for alias + alias = self._get_alias() + return ColumnRef(alias, name) + + def _print_plan(self) -> str: + if self._plan: + return self._plan.print() + return "" + + def collect(self): + query = self._plan.collect(self._session) + return self._session.collect(query) + + def toPandas(self): + return self.collect() + + def explain(self) -> str: + query = self._plan.collect(self._session) + return self._session.analyze(query).explain_string diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py new file mode 100644 index 000000000000..9f4e457030c5 --- /dev/null +++ b/python/pyspark/sql/connect/function_builder.py @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import functools +from typing import TYPE_CHECKING + +import pyspark.sql.types +from pyspark.sql.connect.column import ( + ColumnOrString, + ColumnRef, + Expression, + ExpressionOrString, + ScalarFunctionExpression, +) + +if TYPE_CHECKING: + from pyspark.sql.connect.client import RemoteSparkSession + + +def _build(name: str, *args: ExpressionOrString) -> ScalarFunctionExpression: + """ + Simple wrapper function that converts the arguments into the appropriate types. + Parameters + ---------- + name Name of the function to be called. + args The list of arguments. + + Returns + ------- + :class:`ScalarFunctionExpression` + """ + cols = [x if isinstance(x, Expression) else ColumnRef.from_qualified_name(x) for x in args] + return ScalarFunctionExpression(name, *cols) + + +class FunctionBuilder: + """This class is used to build arbitrary functions used in expressions""" + + def __getattr__(self, name): + def _(*args: ExpressionOrString) -> ScalarFunctionExpression: + return _build(name, *args) + + _.__doc__ = f"""Function to apply {name}""" + return _ + + +functions = FunctionBuilder() + + +class UserDefinedFunction(Expression): + """A user defied function is an expresison that has a reference to the actual + Python callable attached. During plan generation, the client sends a command to + the server to register the UDF before execution. The expression object can be + reused and is not attached to a specific execution. If the internal name of + the temporary function is set, it is assumed that the registration has already + happened.""" + + def __init__(self, func, return_type=pyspark.sql.types.StringType(), args=None): + super().__init__() + + self._func_ref = func + self._return_type = return_type + self._args = list(args) + self._func_name = None + + def to_plan(self, session: "RemoteSparkSession") -> Expression: + # Needs to materialize the UDF to the server + # Only do this once per session + func_name = session.register_udf(self._func_ref, self._return_type) + # Func name is used for the actual reference + return _build(func_name, *self._args).to_plan(session) + + def __str__(self): + return f"UserDefinedFunction({self._func_name})" + + +def _create_udf(function, return_type): + def wrapper(*cols: "ColumnOrString"): + return UserDefinedFunction(func=function, return_type=return_type, args=cols) + + return wrapper + + +def udf(function, return_type=pyspark.sql.types.StringType()): + """ + Returns a callable that represents the column once arguments are applied + + Parameters + ---------- + function + return_type + + Returns + ------- + + """ + # This is when @udf / @udf(DataType()) is used + if function is None or isinstance(function, (str, pyspark.sql.types.DataType)): + return_type = function or return_type + # Overload with + if return_type is None: + return_type = pyspark.sql.types.StringType() + return functools.partial(_create_udf, return_type=return_type) + else: + return _create_udf(function, return_type) diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py new file mode 100644 index 000000000000..4fe57d922837 --- /dev/null +++ b/python/pyspark/sql/connect/functions.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.sql.connect.column import ColumnRef, LiteralExpression +from pyspark.sql.connect.column import PrimitiveType + +# TODO(SPARK-40538) Add support for the missing PySpark functions. + + +def col(x: str) -> ColumnRef: + return ColumnRef(x) + + +def lit(x: PrimitiveType) -> LiteralExpression: + return LiteralExpression(x) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py new file mode 100644 index 000000000000..d404236610d5 --- /dev/null +++ b/python/pyspark/sql/connect/plan.py @@ -0,0 +1,464 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + List, + Optional, + Sequence, + Tuple, + Union, + cast, + TYPE_CHECKING, +) + +import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.column import ( + ColumnOrString, + ColumnRef, + Expression, + ExpressionOrString, + SortOrder, +) + + +if TYPE_CHECKING: + from pyspark.sql.connect.client import RemoteSparkSession + + +class InputValidationError(Exception): + pass + + +class LogicalPlan(object): + + INDENT = 2 + + def __init__(self, child: Optional["LogicalPlan"]) -> None: + self._child = child + + def unresolved_attr(self, *colNames: str) -> proto.Expression: + """Creates an unresolved attribute from a column name.""" + exp = proto.Expression() + exp.unresolved_attribute.parts.extend(list(colNames)) + return exp + + def to_attr_or_expression( + self, col: ColumnOrString, session: "RemoteSparkSession" + ) -> proto.Expression: + """Returns either an instance of an unresolved attribute or the serialized + expression value of the column.""" + if type(col) is str: + return self.unresolved_attr(cast(str, col)) + else: + return cast(ColumnRef, col).to_plan(session) + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + ... 
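To tie the pieces together, here is a hedged sketch (not part of the patch) of composing `col()`/`lit()` from functions.py with the relation classes defined further down in plan.py: the DataFrame only accumulates a local plan tree, and nothing contacts a server until the plan is collected. The table name "orders" is a placeholder.

```python
# Sketch only: build and inspect a lazy plan tree without any RPC. Relies on the
# Read, Filter and Limit relations defined below in plan.py.
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.data_frame import DataFrame
from pyspark.sql.connect.functions import col, lit

df = DataFrame.withPlan(plan.Read("orders"))          # no RemoteSparkSession attached yet
cheap = df.filter(col("price") < lit(10)).limit(3)

print(cheap._print_plan())    # walks Limit -> Filter -> Read locally
```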
+ + def _verify(self, session: "RemoteSparkSession") -> bool: + """This method is used to verify that the current logical plan + can be serialized to Proto and back and afterwards is identical.""" + plan = proto.Plan() + plan.root.CopyFrom(self.plan(session)) + + serialized_plan = plan.SerializeToString() + test_plan = proto.Plan() + test_plan.ParseFromString(serialized_plan) + + return test_plan == plan + + def collect(self, session: "RemoteSparkSession" = None, debug: bool = False): + plan = proto.Plan() + plan.root.CopyFrom(self.plan(session)) + + if debug: + print(plan) + + return plan + + def print(self, indent=0) -> str: + ... + + def _repr_html_(self): + ... + + +class Read(LogicalPlan): + def __init__(self, table_name: str) -> None: + super().__init__(None) + self.table_name = table_name + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + plan = proto.Relation() + plan.read.named_table.parts.extend(self.table_name.split(".")) + return plan + + def print(self, indent=0) -> str: + return f"{' ' * indent}\n" + + def _repr_html_(self): + return f""" + + """ + + +class Project(LogicalPlan): + """Logical plan object for a projection. + + All input arguments are directly serialized into the corresponding protocol buffer + objects. This class only provides very limited error handling and input validation. + + To be compatible with PySpark, we validate that the input arguments are all + expressions to be able to serialize them to the server. + + """ + + def __init__(self, child: Optional["LogicalPlan"], *columns: ExpressionOrString) -> None: + super().__init__(child) + self._raw_columns = list(columns) + self.alias = None + self._verify_expressions() + + def _verify_expressions(self): + """Ensures that all input arguments are instances of Expression.""" + for c in self._raw_columns: + if not isinstance(c, Expression): + raise InputValidationError(f"Only Expressions can be used for projections: '{c}'.") + + def withAlias(self, alias) -> LogicalPlan: + self.alias = alias + return self + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + assert self._child is not None + proj_exprs = [ + c.to_plan(session) + if isinstance(c, Expression) + else self.unresolved_attr(*cast(str, c).split(".")) + for c in self._raw_columns + ] + common = proto.RelationCommon() + if self.alias is not None: + common.alias = self.alias + + plan = proto.Relation() + plan.project.input.CopyFrom(self._child.plan(session)) + plan.project.expressions.extend(proj_exprs) + plan.common.CopyFrom(common) + return plan + + def print(self, indent=0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return f"{' ' * indent}\n{c_buf}" + + def _repr_html_(self): + return f""" +
+        <ul>
+           <li>
+              <b>Project</b><br />
+              Columns: {",".join([str(c) for c in self._raw_columns])}
+              {self._child._repr_html_()}
+           </li>
+        </ul>
+ """ + + +class Filter(LogicalPlan): + def __init__(self, child: Optional["LogicalPlan"], filter: Expression) -> None: + super().__init__(child) + self.filter = filter + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.filter.input.CopyFrom(self._child.plan(session)) + plan.filter.condition.CopyFrom(self.filter.to_plan(session)) + return plan + + def print(self, indent=0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return f"{' ' * indent}\n{c_buf}" + + def _repr_html_(self): + return f""" +
+        <ul>
+           <li>
+              <b>Filter</b><br />
+              Condition: {self.filter}
+              {self._child._repr_html_()}
+           </li>
+        </ul>
+ """ + + +class Limit(LogicalPlan): + def __init__(self, child: Optional["LogicalPlan"], limit: int, offset: int = 0) -> None: + super().__init__(child) + self.limit = limit + self.offset = offset + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.fetch.input.CopyFrom(self._child.plan(session)) + plan.fetch.limit = self.limit + return plan + + def print(self, indent=0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return f"{' ' * indent}\n{c_buf}" + + def _repr_html_(self): + return f""" +
+        <ul>
+           <li>
+              <b>Limit</b><br />
+              Limit: {self.limit}<br />
+              Offset: {self.offset}<br />
+              {self._child._repr_html_()}
+           </li>
+        </ul>
+ """ + + +class Sort(LogicalPlan): + def __init__( + self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str] + ) -> None: + super().__init__(child) + self.columns = list(columns) + + def col_to_sort_field( + self, col: Union[SortOrder, ColumnRef, str], session: "RemoteSparkSession" + ) -> proto.Sort.SortField: + if type(col) is SortOrder: + so = cast(SortOrder, col) + sf = proto.Sort.SortField() + sf.expression.CopyFrom(so.ref.to_plan(session)) + sf.direction = ( + proto.Sort.SortDirection.SORT_DIRECTION_ASCENDING + if so.ascending + else proto.Sort.SortDirection.SORT_DIRECTION_DESCENDING + ) + sf.nulls = ( + proto.Sort.SortNulls.SORT_NULLS_FIRST + if not so.nullsLast + else proto.Sort.SortNulls.SORT_NULLS_LAST + ) + return sf + else: + sf = proto.Sort.SortField() + # Check string + if type(col) is ColumnRef: + sf.expression.CopyFrom(cast(ColumnRef, col).to_plan(session)) + else: + sf.expression.CopyFrom(self.unresolved_attr(cast(str, col))) + sf.direction = proto.Sort.SortDirection.SORT_DIRECTION_ASCENDING + sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST + return sf + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.sort.input.CopyFrom(self._child.plan(session)) + plan.sort.sort_fields.extend([self.col_to_sort_field(x, session) for x in self.columns]) + return plan + + def print(self, indent=0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return f"{' ' * indent}\n{c_buf}" + + def _repr_html_(self): + return f""" +
+        <ul>
+           <li>
+              <b>Sort</b><br />
+              {", ".join([str(c) for c in self.columns])}
+              {self._child._repr_html_()}
+           </li>
+        </ul>
+ """ + + +class Aggregate(LogicalPlan): + MeasuresType = Sequence[Tuple[ExpressionOrString, str]] + OptMeasuresType = Optional[MeasuresType] + + def __init__( + self, + child: Optional["LogicalPlan"], + grouping_cols: List[ColumnRef], + measures: OptMeasuresType, + ) -> None: + super().__init__(child) + self.grouping_cols = grouping_cols + self.measures = measures if measures is not None else [] + + def _convert_measure(self, m, session: "RemoteSparkSession"): + exp, fun = m + measure = proto.Aggregate.Measure() + measure.function.name = fun + if type(exp) is str: + measure.function.arguments.append(self.unresolved_attr(exp)) + else: + measure.function.arguments.append(cast(Expression, exp).to_plan(session)) + return measure + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + assert self._child is not None + groupings = [x.to_plan(session) for x in self.grouping_cols] + + agg = proto.Relation() + agg.aggregate.input.CopyFrom(self._child.plan(session)) + agg.aggregate.measures.extend( + list(map(lambda x: self._convert_measure(x, session), self.measures)) + ) + + gs = proto.Aggregate.GroupingSet() + gs.aggregate_expressions.extend(groupings) + agg.aggregate.grouping_sets.append(gs) + return agg + + def print(self, indent=0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return ( + f"{' ' * indent}\n{c_buf}" + ) + + def _repr_html_(self): + return f""" +
+        <ul>
+           <li>
+              <b>Aggregation</b><br />
+              {self._child._repr_html_()}
+           </li>
+        </ul>
+ """ + + +class Join(LogicalPlan): + def __init__( + self, + left: Optional["LogicalPlan"], + right: "LogicalPlan", + on: ColumnOrString, + how: proto.Join.JoinType = proto.Join.JoinType.JOIN_TYPE_INNER, + ) -> None: + super().__init__(left) + self.left = cast(LogicalPlan, left) + self.right = right + self.on = on + if how is None: + how = proto.Join.JoinType.JOIN_TYPE_INNER + self.how = how + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + rel = proto.Relation() + rel.join.left.CopyFrom(self.left.plan(session)) + rel.join.right.CopyFrom(self.right.plan(session)) + rel.join.on.CopyFrom(self.to_attr_or_expression(self.on, session)) + return rel + + def print(self, indent=0) -> str: + i = " " * indent + o = " " * (indent + LogicalPlan.INDENT) + n = indent + LogicalPlan.INDENT * 2 + return ( + f"{i}\n{o}" + f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}" + ) + + def _repr_html_(self): + return f""" +
+        <ul>
+           <li>
+              <b>Join</b><br />
+              Left: {self.left._repr_html_()}
+              Right: {self.right._repr_html_()}
+           </li>
+        </ul>
+        """
+
+
+class UnionAll(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], other: "LogicalPlan") -> None:
+        super().__init__(child)
+        self.other = other
+
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+        assert self._child is not None
+        rel = proto.Relation()
+        rel.union.inputs.extend([self._child.plan(session), self.other.plan(session)])
+        rel.union.union_type = proto.Union.UnionType.UNION_TYPE_ALL
+        # The relation must be returned, otherwise collect() would serialize None.
+        return rel
+
+    def print(self, indent=0) -> str:
+        assert self._child is not None
+        assert self.other is not None
+
+        i = " " * indent
+        o = " " * (indent + LogicalPlan.INDENT)
+        n = indent + LogicalPlan.INDENT * 2
+        return (
+            f"{i}UnionAll\n{o}child1=\n{self._child.print(n)}"
+            f"\n{o}child2=\n{self.other.print(n)}"
+        )
+
+    def _repr_html_(self) -> str:
+        assert self._child is not None
+        assert self.other is not None
+
+        return f"""
+        <ul>
+           <li>
+              <b>Union</b><br />
+              Left: {self._child._repr_html_()}
+              Right: {self.other._repr_html_()}
+           </li>
+        </ul>
+ """ + + +class SQL(LogicalPlan): + def __init__(self, query: str) -> None: + super().__init__(None) + self._query = query + + def plan(self, session: "RemoteSparkSession") -> proto.Relation: + rel = proto.Relation() + rel.sql.query = self._query + return rel + + def print(self, indent=0) -> str: + i = " " * indent + sub_query = self._query.replace("\n", "")[:50] + return f"""{i}""" + + def _repr_html_(self) -> str: + return f""" +
+        <ul>
+           <li>
+              <b>SQL</b><br />
+              Statement: <br />
+              {self._query}
+           </li>
+        </ul>
+ """ diff --git a/python/pyspark/sql/connect/proto/__init__.py b/python/pyspark/sql/connect/proto/__init__.py new file mode 100644 index 000000000000..f00b1a74c1d8 --- /dev/null +++ b/python/pyspark/sql/connect/proto/__init__.py @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.connect.proto.base_pb2_grpc import * +from pyspark.sql.connect.proto.base_pb2 import * +from pyspark.sql.connect.proto.types_pb2 import * +from pyspark.sql.connect.proto.commands_pb2 import * +from pyspark.sql.connect.proto.expressions_pb2 import * +from pyspark.sql.connect.proto.relations_pb2 import * diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py new file mode 100644 index 000000000000..3adb77f77d60 --- /dev/null +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: spark/connect/base.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from pyspark.sql.connect.proto import ( + commands_pb2 as spark_dot_connect_dot_commands__pb2, +) +from pyspark.sql.connect.proto import ( + relations_pb2 as spark_dot_connect_dot_relations__pb2, +) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xdb\x01\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x43\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName"\xc4\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12?\n\tcsv_batch\x18\x03 \x01(\x0b\x32 .spark.connect.Response.CSVBatchH\x00R\x08\x63svBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a;\n\x08\x43SVBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\tR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 
\x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42M\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.base_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = ( + b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto" + ) + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" + _PLAN._serialized_start = 104 + _PLAN._serialized_end = 220 + _REQUEST._serialized_start = 223 + _REQUEST._serialized_end = 442 + _REQUEST_USERCONTEXT._serialized_start = 375 + _REQUEST_USERCONTEXT._serialized_end = 442 + _RESPONSE._serialized_start = 445 + _RESPONSE._serialized_end = 1409 + _RESPONSE_ARROWBATCH._serialized_start = 671 + _RESPONSE_ARROWBATCH._serialized_end = 846 + _RESPONSE_CSVBATCH._serialized_start = 848 + _RESPONSE_CSVBATCH._serialized_end = 907 + _RESPONSE_METRICS._serialized_start = 910 + _RESPONSE_METRICS._serialized_end = 1394 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 994 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1304 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1192 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1304 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1306 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1394 + _ANALYZERESPONSE._serialized_start = 1412 + _ANALYZERESPONSE._serialized_end = 1567 + _SPARKCONNECTSERVICE._serialized_start = 1570 + _SPARKCONNECTSERVICE._serialized_end = 1732 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py new file mode 100644 index 000000000000..77307603f6eb --- /dev/null +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 + + +class SparkConnectServiceStub(object): + """Main interface for the SparkConnect service.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.ExecutePlan = channel.unary_stream( + "/spark.connect.SparkConnectService/ExecutePlan", + request_serializer=spark_dot_connect_dot_base__pb2.Request.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.Response.FromString, + ) + self.AnalyzePlan = channel.unary_unary( + "/spark.connect.SparkConnectService/AnalyzePlan", + request_serializer=spark_dot_connect_dot_base__pb2.Request.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzeResponse.FromString, + ) + + +class SparkConnectServiceServicer(object): + """Main interface for the SparkConnect service.""" + + def ExecutePlan(self, request, context): + """Executes a request that contains the query and returns a stream of [[Response]].""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def AnalyzePlan(self, request, context): + """Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_SparkConnectServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "ExecutePlan": grpc.unary_stream_rpc_method_handler( + servicer.ExecutePlan, + request_deserializer=spark_dot_connect_dot_base__pb2.Request.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.Response.SerializeToString, + ), + "AnalyzePlan": grpc.unary_unary_rpc_method_handler( + servicer.AnalyzePlan, + request_deserializer=spark_dot_connect_dot_base__pb2.Request.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.AnalyzeResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "spark.connect.SparkConnectService", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class SparkConnectService(object): + """Main interface for the SparkConnect service.""" + + @staticmethod + def ExecutePlan( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/spark.connect.SparkConnectService/ExecutePlan", + spark_dot_connect_dot_base__pb2.Request.SerializeToString, + spark_dot_connect_dot_base__pb2.Response.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def AnalyzePlan( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/spark.connect.SparkConnectService/AnalyzePlan", + spark_dot_connect_dot_base__pb2.Request.SerializeToString, + spark_dot_connect_dot_base__pb2.AnalyzeResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py new file mode 100644 index 000000000000..f5bb6ad5628a --- /dev/null +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: spark/connect/commands.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"i\n\x07\x43ommand\x12N\n\x0f\x63reate_function\x18\x01 \x01(\x0b\x32#.spark.connect.CreateScalarFunctionH\x00R\x0e\x63reateFunctionB\x0e\n\x0c\x63ommand_type"\x8f\x04\n\x14\x43reateScalarFunction\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x12P\n\x08language\x18\x02 \x01(\x0e\x32\x34.spark.connect.CreateScalarFunction.FunctionLanguageR\x08language\x12\x1c\n\ttemporary\x18\x03 \x01(\x08R\ttemporary\x12:\n\x0e\x61rgument_types\x18\x04 \x03(\x0b\x32\x13.spark.connect.TypeR\rargumentTypes\x12\x34\n\x0breturn_type\x18\x05 \x01(\x0b\x32\x13.spark.connect.TypeR\nreturnType\x12\x31\n\x13serialized_function\x18\x06 \x01(\x0cH\x00R\x12serializedFunction\x12\'\n\x0eliteral_string\x18\x07 \x01(\tH\x00R\rliteralString"\x8b\x01\n\x10\x46unctionLanguage\x12!\n\x1d\x46UNCTION_LANGUAGE_UNSPECIFIED\x10\x00\x12\x19\n\x15\x46UNCTION_LANGUAGE_SQL\x10\x01\x12\x1c\n\x18\x46UNCTION_LANGUAGE_PYTHON\x10\x02\x12\x1b\n\x17\x46UNCTION_LANGUAGE_SCALA\x10\x03\x42\x15\n\x13\x66unction_definitionBM\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.commands_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = ( + b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto" + ) + _COMMAND._serialized_start = 74 + _COMMAND._serialized_end = 179 + _CREATESCALARFUNCTION._serialized_start = 182 + _CREATESCALARFUNCTION._serialized_end = 709 + _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_start = 547 + _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_end = 686 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py new file mode 100644 index 000000000000..16f3325d4146 --- /dev/null +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. 
DO NOT EDIT! +# source: spark/connect/expressions.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xd8\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x1a\x97\x10\n\x07Literal\x12\x1a\n\x07\x62oolean\x18\x01 \x01(\x08H\x00R\x07\x62oolean\x12\x10\n\x02i8\x18\x02 \x01(\x05H\x00R\x02i8\x12\x12\n\x03i16\x18\x03 \x01(\x05H\x00R\x03i16\x12\x12\n\x03i32\x18\x05 \x01(\x05H\x00R\x03i32\x12\x12\n\x03i64\x18\x07 \x01(\x03H\x00R\x03i64\x12\x14\n\x04\x66p32\x18\n \x01(\x02H\x00R\x04\x66p32\x12\x14\n\x04\x66p64\x18\x0b \x01(\x01H\x00R\x04\x66p64\x12\x18\n\x06string\x18\x0c \x01(\tH\x00R\x06string\x12\x18\n\x06\x62inary\x18\r \x01(\x0cH\x00R\x06\x62inary\x12\x1e\n\ttimestamp\x18\x0e \x01(\x03H\x00R\ttimestamp\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x14\n\x04time\x18\x11 \x01(\x03H\x00R\x04time\x12l\n\x16interval_year_to_month\x18\x13 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalYearToMonthH\x00R\x13intervalYearToMonth\x12l\n\x16interval_day_to_second\x18\x14 \x01(\x0b\x32\x35.spark.connect.Expression.Literal.IntervalDayToSecondH\x00R\x13intervalDayToSecond\x12\x1f\n\nfixed_char\x18\x15 \x01(\tH\x00R\tfixedChar\x12\x46\n\x08var_char\x18\x16 \x01(\x0b\x32).spark.connect.Expression.Literal.VarCharH\x00R\x07varChar\x12#\n\x0c\x66ixed_binary\x18\x17 \x01(\x0cH\x00R\x0b\x66ixedBinary\x12\x45\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x42\n\x06struct\x18\x19 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x1a \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12#\n\x0ctimestamp_tz\x18\x1b \x01(\x03H\x00R\x0btimestampTz\x12\x14\n\x04uuid\x18\x1c \x01(\x0cH\x00R\x04uuid\x12)\n\x04null\x18\x1d \x01(\x0b\x32\x13.spark.connect.TypeH\x00R\x04null\x12<\n\x04list\x18\x1e \x01(\x0b\x32&.spark.connect.Expression.Literal.ListH\x00R\x04list\x12\x39\n\nempty_list\x18\x1f \x01(\x0b\x32\x18.spark.connect.Type.ListH\x00R\temptyList\x12\x36\n\tempty_map\x18 \x01(\x0b\x32\x17.spark.connect.Type.MapH\x00R\x08\x65mptyMap\x12R\n\x0cuser_defined\x18! 
\x01(\x0b\x32-.spark.connect.Expression.Literal.UserDefinedH\x00R\x0buserDefined\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1a\x37\n\x07VarChar\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12\x16\n\x06length\x18\x02 \x01(\rR\x06length\x1aS\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value\x12\x1c\n\tprecision\x18\x02 \x01(\x05R\tprecision\x12\x14\n\x05scale\x18\x03 \x01(\x05R\x05scale\x1a\xce\x01\n\x03Map\x12M\n\nkey_values\x18\x01 \x03(\x0b\x32..spark.connect.Expression.Literal.Map.KeyValueR\tkeyValues\x1ax\n\x08KeyValue\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value\x1a\x43\n\x13IntervalYearToMonth\x12\x14\n\x05years\x18\x01 \x01(\x05R\x05years\x12\x16\n\x06months\x18\x02 \x01(\x05R\x06months\x1ag\n\x13IntervalDayToSecond\x12\x12\n\x04\x64\x61ys\x18\x01 \x01(\x05R\x04\x64\x61ys\x12\x18\n\x07seconds\x18\x02 \x01(\x05R\x07seconds\x12"\n\x0cmicroseconds\x18\x03 \x01(\x05R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x41\n\x04List\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a`\n\x0bUserDefined\x12%\n\x0etype_reference\x18\x01 \x01(\rR\rtypeReference\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05valueB\x0e\n\x0cliteral_type\x1a+\n\x13UnresolvedAttribute\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x1a\x63\n\x12UnresolvedFunction\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05parts\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpressionB\x0b\n\texpr_typeBM\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.expressions_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = ( + b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto" + ) + _EXPRESSION._serialized_start = 105 + _EXPRESSION._serialized_end = 2753 + _EXPRESSION_LITERAL._serialized_start = 471 + _EXPRESSION_LITERAL._serialized_end = 2542 + _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1769 + _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1824 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1826 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1909 + _EXPRESSION_LITERAL_MAP._serialized_start = 1912 + _EXPRESSION_LITERAL_MAP._serialized_end = 2118 + _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 1998 + _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2118 + _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2120 + _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2187 + _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2189 + _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2292 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 2294 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 2361 + _EXPRESSION_LITERAL_LIST._serialized_start = 2363 + _EXPRESSION_LITERAL_LIST._serialized_end = 2428 + _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2430 + _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2526 + 
_EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2544 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2587 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2589 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2688 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2690 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2740 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py new file mode 100644 index 000000000000..b9f74dc23806 --- /dev/null +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: spark/connect/relations.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from pyspark.sql.connect.proto import ( + expressions_pb2 as spark_dot_connect_dot_expressions__pb2, +) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa6\x04\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05\x66\x65tch\x18\x08 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SqlH\x00R\x03sql\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03Sql\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"z\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a"\n\nNamedTable\x12\x14\n\x05parts\x18\x01 
\x03(\tR\x05partsB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd8\x02\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12)\n\x02on\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x02on\x12.\n\x03how\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x03how"\x98\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x13\n\x0fJOIN_TYPE_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x12\n\x0eJOIN_TYPE_ANTI\x10\x05"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"d\n\x05\x46\x65tch\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit\x12\x16\n\x06offset\x18\x03 \x01(\x05R\x06offset"\x8b\x04\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12I\n\rgrouping_sets\x18\x02 \x03(\x0b\x32$.spark.connect.Aggregate.GroupingSetR\x0cgroupingSets\x12<\n\x08measures\x18\x03 \x03(\x0b\x32 .spark.connect.Aggregate.MeasureR\x08measures\x1a]\n\x0bGroupingSet\x12N\n\x15\x61ggregate_expressions\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x1a\x84\x01\n\x07Measure\x12\x46\n\x08\x66unction\x18\x01 \x01(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x08\x66unction\x12\x31\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x06\x66ilter\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42M\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.relations_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = ( + b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto" + ) + _RELATION._serialized_start = 82 + 
_RELATION._serialized_end = 632 + _UNKNOWN._serialized_start = 634 + _UNKNOWN._serialized_end = 643 + _RELATIONCOMMON._serialized_start = 645 + _RELATIONCOMMON._serialized_end = 716 + _SQL._serialized_start = 718 + _SQL._serialized_end = 745 + _READ._serialized_start = 747 + _READ._serialized_end = 869 + _READ_NAMEDTABLE._serialized_start = 822 + _READ_NAMEDTABLE._serialized_end = 856 + _PROJECT._serialized_start = 871 + _PROJECT._serialized_end = 988 + _FILTER._serialized_start = 990 + _FILTER._serialized_end = 1102 + _JOIN._serialized_start = 1105 + _JOIN._serialized_end = 1449 + _JOIN_JOINTYPE._serialized_start = 1297 + _JOIN_JOINTYPE._serialized_end = 1449 + _UNION._serialized_start = 1452 + _UNION._serialized_end = 1657 + _UNION_UNIONTYPE._serialized_start = 1573 + _UNION_UNIONTYPE._serialized_end = 1657 + _FETCH._serialized_start = 1659 + _FETCH._serialized_end = 1759 + _AGGREGATE._serialized_start = 1762 + _AGGREGATE._serialized_end = 2285 + _AGGREGATE_GROUPINGSET._serialized_start = 1959 + _AGGREGATE_GROUPINGSET._serialized_end = 2052 + _AGGREGATE_MEASURE._serialized_start = 2055 + _AGGREGATE_MEASURE._serialized_end = 2187 + _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2189 + _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2285 + _SORT._serialized_start = 2288 + _SORT._serialized_end = 2790 + _SORT_SORTFIELD._serialized_start = 2408 + _SORT_SORTFIELD._serialized_end = 2596 + _SORT_SORTDIRECTION._serialized_start = 2598 + _SORT_SORTDIRECTION._serialized_end = 2706 + _SORT_SORTNULLS._serialized_start = 2708 + _SORT_SORTNULLS._serialized_end = 2790 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py new file mode 100644 index 000000000000..27247ca4e0c8 --- /dev/null +++ b/python/pyspark/sql/connect/proto/types_pb2.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: spark/connect/types.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xea%\n\x04Type\x12\x31\n\x04\x62ool\x18\x01 \x01(\x0b\x32\x1b.spark.connect.Type.BooleanH\x00R\x04\x62ool\x12(\n\x02i8\x18\x02 \x01(\x0b\x32\x16.spark.connect.Type.I8H\x00R\x02i8\x12+\n\x03i16\x18\x03 \x01(\x0b\x32\x17.spark.connect.Type.I16H\x00R\x03i16\x12+\n\x03i32\x18\x05 \x01(\x0b\x32\x17.spark.connect.Type.I32H\x00R\x03i32\x12+\n\x03i64\x18\x07 \x01(\x0b\x32\x17.spark.connect.Type.I64H\x00R\x03i64\x12.\n\x04\x66p32\x18\n \x01(\x0b\x32\x18.spark.connect.Type.FP32H\x00R\x04\x66p32\x12.\n\x04\x66p64\x18\x0b \x01(\x0b\x32\x18.spark.connect.Type.FP64H\x00R\x04\x66p64\x12\x34\n\x06string\x18\x0c \x01(\x0b\x32\x1a.spark.connect.Type.StringH\x00R\x06string\x12\x34\n\x06\x62inary\x18\r \x01(\x0b\x32\x1a.spark.connect.Type.BinaryH\x00R\x06\x62inary\x12=\n\ttimestamp\x18\x0e \x01(\x0b\x32\x1d.spark.connect.Type.TimestampH\x00R\ttimestamp\x12.\n\x04\x64\x61te\x18\x10 \x01(\x0b\x32\x18.spark.connect.Type.DateH\x00R\x04\x64\x61te\x12.\n\x04time\x18\x11 \x01(\x0b\x32\x18.spark.connect.Type.TimeH\x00R\x04time\x12G\n\rinterval_year\x18\x13 \x01(\x0b\x32 .spark.connect.Type.IntervalYearH\x00R\x0cintervalYear\x12\x44\n\x0cinterval_day\x18\x14 \x01(\x0b\x32\x1f.spark.connect.Type.IntervalDayH\x00R\x0bintervalDay\x12\x44\n\x0ctimestamp_tz\x18\x1d \x01(\x0b\x32\x1f.spark.connect.Type.TimestampTZH\x00R\x0btimestampTz\x12.\n\x04uuid\x18 \x01(\x0b\x32\x18.spark.connect.Type.UUIDH\x00R\x04uuid\x12>\n\nfixed_char\x18\x15 \x01(\x0b\x32\x1d.spark.connect.Type.FixedCharH\x00R\tfixedChar\x12\x37\n\x07varchar\x18\x16 \x01(\x0b\x32\x1b.spark.connect.Type.VarCharH\x00R\x07varchar\x12\x44\n\x0c\x66ixed_binary\x18\x17 \x01(\x0b\x32\x1f.spark.connect.Type.FixedBinaryH\x00R\x0b\x66ixedBinary\x12\x37\n\x07\x64\x65\x63imal\x18\x18 \x01(\x0b\x32\x1b.spark.connect.Type.DecimalH\x00R\x07\x64\x65\x63imal\x12\x34\n\x06struct\x18\x19 \x01(\x0b\x32\x1a.spark.connect.Type.StructH\x00R\x06struct\x12.\n\x04list\x18\x1b \x01(\x0b\x32\x18.spark.connect.Type.ListH\x00R\x04list\x12+\n\x03map\x18\x1c \x01(\x0b\x32\x17.spark.connect.Type.MapH\x00R\x03map\x12?\n\x1buser_defined_type_reference\x18\x1f \x01(\rH\x00R\x18userDefinedTypeReference\x1a\x86\x01\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x81\x01\n\x02I8\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x82\x01\n\x03I16\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x82\x01\n\x03I32\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x82\x01\n\x03I64\x12\x38\n\x18type_variation_reference\x18\x01 
\x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04\x46P32\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04\x46P64\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x85\x01\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x85\x01\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x88\x01\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04Time\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x8a\x01\n\x0bTimestampTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x8b\x01\n\x0cIntervalYear\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x8a\x01\n\x0bIntervalDay\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x83\x01\n\x04UUID\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x02 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xa0\x01\n\tFixedChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\x9e\x01\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xa2\x01\n\x0b\x46ixedBinary\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xba\x01\n\x07\x44\x65\x63imal\x12\x14\n\x05scale\x18\x01 \x01(\x05R\x05scale\x12\x1c\n\tprecision\x18\x02 \x01(\x05R\tprecision\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x04 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xb0\x01\n\x06Struct\x12)\n\x05types\x18\x01 
\x03(\x0b\x32\x13.spark.connect.TypeR\x05types\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xac\x01\n\x04List\x12\'\n\x04type\x18\x01 \x01(\x0b\x32\x13.spark.connect.TypeR\x04type\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x03 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability\x1a\xd4\x01\n\x03Map\x12%\n\x03key\x18\x01 \x01(\x0b\x32\x13.spark.connect.TypeR\x03key\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x13.spark.connect.TypeR\x05value\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x12\x41\n\x0bnullability\x18\x04 \x01(\x0e\x32\x1f.spark.connect.Type.NullabilityR\x0bnullability"^\n\x0bNullability\x12\x1b\n\x17NULLABILITY_UNSPECIFIED\x10\x00\x12\x18\n\x14NULLABILITY_NULLABLE\x10\x01\x12\x18\n\x14NULLABILITY_REQUIRED\x10\x02\x42\x06\n\x04kindBM\n\x1eorg.apache.spark.connect.protoP\x01Z)github.com/databricks/spark-connect/protob\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spark.connect.types_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = ( + b"\n\036org.apache.spark.connect.protoP\001Z)github.com/databricks/spark-connect/proto" + ) + _TYPE._serialized_start = 45 + _TYPE._serialized_end = 4887 + _TYPE_BOOLEAN._serialized_start = 1366 + _TYPE_BOOLEAN._serialized_end = 1500 + _TYPE_I8._serialized_start = 1503 + _TYPE_I8._serialized_end = 1632 + _TYPE_I16._serialized_start = 1635 + _TYPE_I16._serialized_end = 1765 + _TYPE_I32._serialized_start = 1768 + _TYPE_I32._serialized_end = 1898 + _TYPE_I64._serialized_start = 1901 + _TYPE_I64._serialized_end = 2031 + _TYPE_FP32._serialized_start = 2034 + _TYPE_FP32._serialized_end = 2165 + _TYPE_FP64._serialized_start = 2168 + _TYPE_FP64._serialized_end = 2299 + _TYPE_STRING._serialized_start = 2302 + _TYPE_STRING._serialized_end = 2435 + _TYPE_BINARY._serialized_start = 2438 + _TYPE_BINARY._serialized_end = 2571 + _TYPE_TIMESTAMP._serialized_start = 2574 + _TYPE_TIMESTAMP._serialized_end = 2710 + _TYPE_DATE._serialized_start = 2713 + _TYPE_DATE._serialized_end = 2844 + _TYPE_TIME._serialized_start = 2847 + _TYPE_TIME._serialized_end = 2978 + _TYPE_TIMESTAMPTZ._serialized_start = 2981 + _TYPE_TIMESTAMPTZ._serialized_end = 3119 + _TYPE_INTERVALYEAR._serialized_start = 3122 + _TYPE_INTERVALYEAR._serialized_end = 3261 + _TYPE_INTERVALDAY._serialized_start = 3264 + _TYPE_INTERVALDAY._serialized_end = 3402 + _TYPE_UUID._serialized_start = 3405 + _TYPE_UUID._serialized_end = 3536 + _TYPE_FIXEDCHAR._serialized_start = 3539 + _TYPE_FIXEDCHAR._serialized_end = 3699 + _TYPE_VARCHAR._serialized_start = 3702 + _TYPE_VARCHAR._serialized_end = 3860 + _TYPE_FIXEDBINARY._serialized_start = 3863 + _TYPE_FIXEDBINARY._serialized_end = 4025 + _TYPE_DECIMAL._serialized_start = 4028 + _TYPE_DECIMAL._serialized_end = 4214 + _TYPE_STRUCT._serialized_start = 4217 + _TYPE_STRUCT._serialized_end = 4393 + _TYPE_LIST._serialized_start = 4396 + _TYPE_LIST._serialized_end = 4568 + _TYPE_MAP._serialized_start = 4571 + _TYPE_MAP._serialized_end = 4783 + _TYPE_NULLABILITY._serialized_start = 4785 + _TYPE_NULLABILITY._serialized_end = 4879 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py 
new file mode 100644 index 000000000000..4aac4eed08c5 --- /dev/null +++ b/python/pyspark/sql/connect/readwriter.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.connect.data_frame import DataFrame +from pyspark.sql.connect.plan import Read + + +class DataFrameReader: + """ + TODO(SPARK-40539) Achieve parity with PySpark. + """ + + def __init__(self, client): + self._client = client + + def table(self, tableName: str) -> "DataFrame": + df = DataFrame.withPlan(Read(tableName), self._client) + return df diff --git a/python/pyspark/sql/tests/connect/__init__.py b/python/pyspark/sql/tests/connect/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/python/pyspark/sql/tests/connect/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/connect/test_column_expressions.py b/python/pyspark/sql/tests/connect/test_column_expressions.py new file mode 100644 index 000000000000..1f067bf79956 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_column_expressions.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from pyspark.sql.tests.connect.utils import PlanOnlyTestFixture + +import pyspark.sql.connect as c +import pyspark.sql.connect.plan as p +import pyspark.sql.connect.column as col + +import pyspark.sql.connect.functions as fun + + +class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): + def test_simple_column_expressions(self): + df = c.DataFrame.withPlan(p.Read("table")) + + c1 = df.col_name + assert isinstance(c1, col.ColumnRef) + c2 = df["col_name"] + assert isinstance(c2, col.ColumnRef) + c3 = fun.col("col_name") + assert isinstance(c3, col.ColumnRef) + + # All Protos should be identical + cp1 = c1.to_plan(None) + cp2 = c2.to_plan(None) + cp3 = c3.to_plan(None) + + assert cp1 is not None + assert cp1 == cp2 == cp3 + + def test_column_literals(self): + df = c.DataFrame.withPlan(p.Read("table")) + lit_df = df.select(fun.lit(10)) + self.assertIsNotNone(lit_df._plan.collect(None)) + + self.assertIsNotNone(fun.lit(10).to_plan(None)) + plan = fun.lit(10).to_plan(None) + self.assertIs(plan.literal.i32, 10) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_column_expressions import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_plan_only.py b/python/pyspark/sql/tests/connect/test_plan_only.py new file mode 100644 index 000000000000..9e6d30cbe1fd --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_plan_only.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.sql.connect import DataFrame +from pyspark.sql.connect.plan import Read +from pyspark.sql.connect.function_builder import UserDefinedFunction, udf +from pyspark.sql.tests.connect.utils.spark_connect_test_utils import PlanOnlyTestFixture +from pyspark.sql.types import StringType + + +class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): + """These test cases exercise the interface to the proto plan + generation but do not call Spark.""" + + def test_simple_project(self): + def read_table(x): + return DataFrame.withPlan(Read(x), self.connect) + + self.connect.set_hook("readTable", read_table) + + plan = self.connect.readTable(self.tbl_name)._plan.collect(self.connect) + self.assertIsNotNone(plan.root, "Root relation must be set") + self.assertIsNotNone(plan.root.read) + + def test_simple_udf(self): + def udf_mock(*args, **kwargs): + return "internal_name" + + self.connect.set_hook("register_udf", udf_mock) + + u = udf(lambda x: "Martin", StringType()) + self.assertIsNotNone(u) + expr = u("ThisCol", "ThatCol", "OtherCol") + self.assertTrue(isinstance(expr, UserDefinedFunction)) + u_plan = expr.to_plan(self.connect) + assert u_plan is not None + + def test_all_the_plans(self): + def read_table(x): + return DataFrame.withPlan(Read(x), self.connect) + + self.connect.set_hook("readTable", read_table) + + df = self.connect.readTable(self.tbl_name) + df = df.select(df.col1).filter(df.col2 == 2).sort(df.col3.asc()) + plan = df._plan.collect(self.connect) + self.assertIsNotNone(plan.root, "Root relation must be set") + self.assertIsNotNone(plan.root.read) + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_plan_only import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_select_ops.py b/python/pyspark/sql/tests/connect/test_select_ops.py new file mode 100644 index 000000000000..818f82b33e86 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_select_ops.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+from pyspark.sql.tests.connect.utils import PlanOnlyTestFixture
+from pyspark.sql.connect import DataFrame
+from pyspark.sql.connect.functions import col
+from pyspark.sql.connect.plan import Read, InputValidationError
+
+
+class SparkConnectSelectOpsSuite(PlanOnlyTestFixture):
+    def test_select_with_literal(self):
+        df = DataFrame.withPlan(Read("table"))
+        self.assertIsNotNone(df.select(col("name"))._plan.collect())
+        self.assertRaises(InputValidationError, df.select, "name")
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_select_ops import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_spark_connect.py b/python/pyspark/sql/tests/connect/test_spark_connect.py
new file mode 100644
index 000000000000..cac17b7397dc
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_spark_connect.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any
+import uuid
+import unittest
+import tempfile
+
+from pyspark.sql import SparkSession, Row
+from pyspark.sql.connect.client import RemoteSparkSession
+from pyspark.sql.connect.function_builder import udf
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class SparkConnectSQLTestCase(ReusedPySparkTestCase):
+    """Parent test fixture class for all Spark Connect related
+    test cases."""
+
+    @classmethod
+    def setUpClass(cls: Any) -> None:
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        cls.hive_available = True
+        # Create the new Spark Session
+        cls.spark = SparkSession(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        cls.df = cls.sc.parallelize(cls.testData).toDF()
+
+        # Load test data
+        cls.spark_connect_test_data()
+
+    @classmethod
+    def spark_connect_test_data(cls: Any) -> None:
+        # Set up the remote Spark Connect session.
+        cls.tbl_name = f"tbl{uuid.uuid4()}".replace("-", "")
+        cls.connect = RemoteSparkSession(user_id="test_user")
+        df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
+        # Since the test data is read back through the Spark Connect session (a separate
+        # session), persist it as a table rather than as a temporary view.
+        df.write.saveAsTable(cls.tbl_name)
+
+
+class SparkConnectTests(SparkConnectSQLTestCase):
+    def test_simple_read(self) -> None:
+        """Tests that we can access the Spark Connect gRPC service locally."""
+        df = self.connect.read.table(self.tbl_name)
+        data = df.limit(10).collect()
+        # collect() returns a pandas-like DataFrame; check that the limit was applied.
+        self.assertEqual(len(data.index), 10)
+
+    def test_simple_udf(self) -> None:
+        def conv_udf(x) -> str:
+            return "Martin"
+
+        u = udf(conv_udf)
+        df = self.connect.read.table(self.tbl_name)
+        result = df.select(u(df.id)).collect()
+        self.assertIsNotNone(result)
+
+    def test_simple_explain_string(self) -> None:
+        df = self.connect.read.table(self.tbl_name).limit(10)
+        result = df.explain()
+        self.assertGreater(len(result), 0)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_spark_connect import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/utils/__init__.py b/python/pyspark/sql/tests/connect/utils/__init__.py
new file mode 100644
index 000000000000..b95812c8a297
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/utils/__init__.py
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.tests.connect.utils.spark_connect_test_utils import (  # noqa: F401
+    PlanOnlyTestFixture,  # noqa: F401
+)  # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py b/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py
new file mode 100644
index 000000000000..34bf49db4945
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, Dict
+import functools
+import unittest
+import uuid
+
+
+class MockRemoteSession:
+    def __init__(self) -> None:
+        self.hooks: Dict[str, Any] = {}
+
+    def set_hook(self, name: str, hook: Any) -> None:
+        self.hooks[name] = hook
+
+    def __getattr__(self, item: str) -> Any:
+        if item not in self.hooks:
+            raise LookupError(f"{item} is not defined as a method hook in MockRemoteSession")
+        return functools.partial(self.hooks[item])
+
+
+class PlanOnlyTestFixture(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls: Any) -> None:
+        cls.connect = MockRemoteSession()
+        cls.tbl_name = f"tbl{uuid.uuid4()}".replace("-", "")
diff --git a/python/run-tests.py b/python/run-tests.py
index d43ed8e96f40..af4c6f1c94be 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -110,6 +110,15 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
         metastore_dir = os.path.join(metastore_dir, str(uuid.uuid4()))
     os.mkdir(metastore_dir)
 
+    # Check if we should enable the SparkConnectPlugin
+    additional_config = []
+    if test_name.startswith("pyspark.sql.tests.connect"):
+        # Add the config that enables the Spark Connect plugin for these tests.
+        additional_config += [
+            "--conf",
+            "spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
+        ]
+
     # Also override the JVM's temp directory by setting driver and executor options.
     java_options = "-Djava.io.tmpdir={0}".format(tmp_dir)
     java_options = java_options + " -Dio.netty.tryReflectionSetAccessible=true -Xss4M"
@@ -117,8 +126,10 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
         "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
         "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
         "--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
-        "pyspark-shell"
     ]
+    spark_args += additional_config
+    spark_args += ["pyspark-shell"]
+
     env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
 
     output_prefix = get_valid_filename(pyspark_python + "__" + test_name + "__").lstrip("_")
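
Note for reviewers: the two run-tests.py hunks above work together; the first builds additional_config, and the second splices it into the submit arguments before the trailing "pyspark-shell" token. The minimal standalone sketch below illustrates the resulting behavior. It is not part of the patch; the helper name and sample inputs are invented for illustration only.

    # Sketch of how PYSPARK_SUBMIT_ARGS ends up being assembled after this change.
    # build_pyspark_submit_args and the sample values below are illustrative, not
    # actual helpers in run-tests.py.
    def build_pyspark_submit_args(test_name, java_options, metastore_dir):
        spark_args = [
            "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
            "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
            "--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
        ]
        # Only the Spark Connect tests need the plugin that starts the Connect service.
        if test_name.startswith("pyspark.sql.tests.connect"):
            spark_args += [
                "--conf",
                "spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin",
            ]
        # "pyspark-shell" has to remain the last token, after all --conf pairs.
        spark_args += ["pyspark-shell"]
        return " ".join(spark_args)


    if __name__ == "__main__":
        print(build_pyspark_submit_args(
            "pyspark.sql.tests.connect.test_spark_connect",
            "-Djava.io.tmpdir=/tmp/spark-tests",
            "/tmp/spark-tests/metastore",
        ))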