From 84ff08bd7e7cbd0b94d8fcf310c9e67fead4a9a0 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 23 Nov 2022 12:53:35 +0800
Subject: [PATCH 1/3] init

---
 python/pyspark/sql/connect/_typing.py        | 35 ++++++++++++++++---
 python/pyspark/sql/connect/client.py         | 25 ++++++-------
 python/pyspark/sql/connect/column.py         |  9 ++---
 python/pyspark/sql/connect/dataframe.py      | 12 +++----
 .../pyspark/sql/connect/function_builder.py  | 10 ++++--
 python/pyspark/sql/connect/plan.py           |  4 +--
 python/pyspark/sql/connect/readwriter.py     |  5 +--
 .../pyspark/sql/connect/typing/__init__.pyi  | 35 -------------------
 8 files changed, 61 insertions(+), 74 deletions(-)
 delete mode 100644 python/pyspark/sql/connect/typing/__init__.pyi

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 4e69b2e4aa5ef..acaa51e4403e2 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -14,8 +14,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Union
-from datetime import date, time, datetime
-PrimitiveType = Union[str, int, bool, float]
-LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
+from typing_extensions import Protocol
+from typing import Union, Optional
+import datetime
+import decimal
+
+from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
+from pyspark.sql.connect.function_builder import UserDefinedFunction
+
+ExpressionOrString = Union[Expression, str]
+
+ColumnOrName = Union[Column, str]
+
+PrimitiveType = Union[bool, float, int, str]
+
+OptionalPrimitiveType = Optional[PrimitiveType]
+
+LiteralType = PrimitiveType
+
+DecimalLiteral = decimal.Decimal
+
+DateTimeLiteral = Union[datetime.datetime, datetime.date]
+
+
+class FunctionBuilderCallable(Protocol):
+    def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression:
+        ...
+
+
+class UserDefinedFunctionCallable(Protocol):
+    def __call__(self, *_: ColumnOrName) -> UserDefinedFunction:
+        ...
 
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index fdcf34b7a47e8..deb9ef6f3be6f 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -18,7 +18,6 @@
 
 import logging
 import os
-import typing
 import urllib.parse
 import uuid
 
@@ -35,9 +34,7 @@
 from pyspark.sql.connect.plan import SQL, Range
 from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
 
-from typing import Optional, Any, Union
-
-NumericType = typing.Union[int, float]
+from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
 
 logging.basicConfig(level=logging.INFO)
 
@@ -74,7 +71,7 @@ def __init__(self, url: str) -> None:
         # Python's built-in parser.
         tmp_url = "http" + url[2:]
         self.url = urllib.parse.urlparse(tmp_url)
-        self.params: typing.Dict[str, str] = {}
+        self.params: Dict[str, str] = {}
         if len(self.url.path) > 0 and self.url.path != "/":
             raise AttributeError(
                 f"Path component for connection URI must be empty: {self.url.path}"
             )
@@ -102,7 +99,7 @@ def _extract_attributes(self) -> None:
                 f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
             )
 
-    def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
+    def metadata(self) -> Iterable[Tuple[str, str]]:
         """
         Builds the GRPC specific metadata list to be injected into the request.
         All parameters will be converted to metadata except ones that are explicitly used
@@ -198,7 +195,7 @@ def toChannel(self) -> grpc.Channel:
 
 
 class MetricValue:
-    def __init__(self, name: str, value: NumericType, type: str):
+    def __init__(self, name: str, value: Union[int, float], type: str):
         self._name = name
         self._type = type
         self._value = value
@@ -211,7 +208,7 @@ def name(self) -> str:
         return self._name
 
     @property
-    def value(self) -> NumericType:
+    def value(self) -> Union[int, float]:
         return self._value
 
     @property
@@ -220,7 +217,7 @@ def metric_type(self) -> str:
 
 
 class PlanMetrics:
-    def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]):
+    def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
         self._name = name
         self._id = id
         self._parent_id = parent
@@ -242,7 +239,7 @@ def parent_plan_id(self) -> int:
         return self._parent_id
 
     @property
-    def metrics(self) -> typing.List[MetricValue]:
+    def metrics(self) -> List[MetricValue]:
         return self._metrics
 
 
@@ -252,7 +249,7 @@ def __init__(self, schema: pb2.DataType, explain: str):
         self.explain_string = explain
 
     @classmethod
-    def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
+    def fromProto(cls, pb: Any) -> "AnalyzeResult":
         return AnalyzeResult(pb.schema, pb.explain_string)
 
 
@@ -306,9 +303,7 @@ def register_udf(
         self._execute_and_fetch(req)
         return name
 
-    def _build_metrics(
-        self, metrics: "pb2.ExecutePlanResponse.Metrics"
-    ) -> typing.List[PlanMetrics]:
+    def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
                 x.name,
@@ -450,7 +445,7 @@ def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFram
             return rd.read_pandas()
         return None
 
-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
+    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
        import pandas as pd
 
        m: Optional[pb2.ExecutePlanResponse.Metrics] = None
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index c4ffc54c20b7d..fb42cddaf6906 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -15,14 +15,15 @@
 # limitations under the License.
 #
 import uuid
-from typing import cast, get_args, TYPE_CHECKING, Callable, Any
+from typing import cast, TYPE_CHECKING, Callable, Any
 
 import json
 import decimal
 import datetime
 
 import pyspark.sql.connect.proto as proto
-from pyspark.sql.connect._typing import PrimitiveType
+
+primitive_types = (bool, float, int, str)
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import RemoteSparkSession
@@ -33,7 +34,7 @@ def _bin_op(
     name: str, doc: str = "binary function", reverse: bool = False
 ) -> Callable[["Column", Any], "Expression"]:
     def _(self: "Column", other: Any) -> "Expression":
-        if isinstance(other, get_args(PrimitiveType)):
+        if isinstance(other, primitive_types):
             other = LiteralExpression(other)
         if not reverse:
             return ScalarFunctionExpression(name, self, other)
@@ -70,7 +71,7 @@ def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
         """Returns a binary expression with the current column as the left
         side and the other expression as the right side.
""" - if isinstance(other, get_args(PrimitiveType)): + if isinstance(other, primitive_types): other = LiteralExpression(other) return ScalarFunctionExpression("==", self, other) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 82dc1f6a558dd..8c6e6280b06bc 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -44,11 +44,9 @@ ) if TYPE_CHECKING: - from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString, LiteralType + from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType from pyspark.sql.connect.client import RemoteSparkSession -ColumnOrName = Union[Column, str] - class GroupingFrame(object): @@ -308,7 +306,7 @@ def distinct(self) -> "DataFrame": plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) - def drop(self, *cols: "ColumnOrString") -> "DataFrame": + def drop(self, *cols: "ColumnOrName") -> "DataFrame": _cols = list(cols) if any(not isinstance(c, (str, Column)) for c in _cols): raise TypeError( @@ -342,7 +340,7 @@ def first(self) -> Optional[Row]: """ return self.head() - def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame: + def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame: return GroupingFrame(self, *cols) @overload @@ -414,13 +412,13 @@ def limit(self, n: int) -> "DataFrame": def offset(self, n: int) -> "DataFrame": return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session) - def sort(self, *cols: "ColumnOrString") -> "DataFrame": + def sort(self, *cols: "ColumnOrName") -> "DataFrame": """Sort by a specific column""" return DataFrame.withPlan( plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session ) - def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame": + def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame": """Sort within each partition by a specific column""" return DataFrame.withPlan( plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py index e116e49395487..4a2688d6a0daf 100644 --- a/python/pyspark/sql/connect/function_builder.py +++ b/python/pyspark/sql/connect/function_builder.py @@ -28,9 +28,13 @@ if TYPE_CHECKING: - from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString + from pyspark.sql.connect._typing import ( + ColumnOrName, + ExpressionOrString, + FunctionBuilderCallable, + UserDefinedFunctionCallable, + ) from pyspark.sql.connect.client import RemoteSparkSession - from pyspark.sql.connect.typing import FunctionBuilderCallable, UserDefinedFunctionCallable def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression: @@ -103,7 +107,7 @@ def __str__(self) -> str: def _create_udf( function: Any, return_type: Union[str, pyspark.sql.types.DataType] ) -> "UserDefinedFunctionCallable": - def wrapper(*cols: "ColumnOrString") -> UserDefinedFunction: + def wrapper(*cols: "ColumnOrName") -> UserDefinedFunction: return UserDefinedFunction(func=function, return_type=return_type, args=cols) return wrapper diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index ffb0ce080b30f..8aadc3dc4fa5c 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: - from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString + from 
pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString from pyspark.sql.connect.client import RemoteSparkSession @@ -58,7 +58,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression: return exp def to_attr_or_expression( - self, col: "ColumnOrString", session: "RemoteSparkSession" + self, col: "ColumnOrName", session: "RemoteSparkSession" ) -> proto.Expression: """Returns either an instance of an unresolved attribute or the serialized expression value of the column.""" diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 66e48eeab76be..27aa023ae474f 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -18,17 +18,14 @@ from typing import Dict, Optional -from pyspark.sql.connect.column import PrimitiveType from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.plan import Read, DataSource from pyspark.sql.utils import to_str - -OptionalPrimitiveType = Optional[PrimitiveType] - from typing import TYPE_CHECKING if TYPE_CHECKING: + from pyspark.sql.connect._typing import OptionalPrimitiveType from pyspark.sql.connect.client import RemoteSparkSession diff --git a/python/pyspark/sql/connect/typing/__init__.pyi b/python/pyspark/sql/connect/typing/__init__.pyi deleted file mode 100644 index 43cc28701daef..0000000000000 --- a/python/pyspark/sql/connect/typing/__init__.pyi +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing_extensions import Protocol -from typing import Union -from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column -from pyspark.sql.connect.function_builder import UserDefinedFunction - -ExpressionOrString = Union[str, Expression] - -ColumnOrString = Union[str, Column] - -PrimitiveType = Union[bool, float, int, str] - -LiteralType = PrimitiveType - -class FunctionBuilderCallable(Protocol): - def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: ... - -class UserDefinedFunctionCallable(Protocol): - def __call__(self, *_: ColumnOrString) -> UserDefinedFunction: ... From 6d146e4986e4183af67f61cdaacfde3c47845001 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 23 Nov 2022 13:06:59 +0800 Subject: [PATCH 2/3] address comments --- python/pyspark/sql/connect/column.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index fb42cddaf6906..36f38e0ded286 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -15,7 +15,7 @@ # limitations under the License. 
 #
 import uuid
-from typing import cast, TYPE_CHECKING, Callable, Any
+from typing import cast, get_args, TYPE_CHECKING, Callable, Any
 
 import json
 import decimal
@@ -23,8 +23,6 @@
 
 import pyspark.sql.connect.proto as proto
 
-primitive_types = (bool, float, int, str)
-
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import RemoteSparkSession
     import pyspark.sql.connect.proto as proto
@@ -34,7 +32,9 @@ def _bin_op(
     name: str, doc: str = "binary function", reverse: bool = False
 ) -> Callable[["Column", Any], "Expression"]:
     def _(self: "Column", other: Any) -> "Expression":
-        if isinstance(other, primitive_types):
+        from pyspark.sql.connect._typing import PrimitiveType
+
+        if isinstance(other, get_args(PrimitiveType)):
             other = LiteralExpression(other)
         if not reverse:
             return ScalarFunctionExpression(name, self, other)
@@ -71,7 +71,9 @@ def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
         """Returns a binary expression with the current column as the left
         side and the other expression as the right side.
         """
-        if isinstance(other, primitive_types):
+        from pyspark.sql.connect._typing import PrimitiveType
+
+        if isinstance(other, get_args(PrimitiveType)):
             other = LiteralExpression(other)
         return ScalarFunctionExpression("==", self, other)

From b11f15c95eb55ba2d2a7c1949f9094669361c3ec Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 23 Nov 2022 14:55:00 +0800
Subject: [PATCH 3/3] deal with typing_extensions

---
 python/pyspark/sql/connect/_typing.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index acaa51e4403e2..262d71fcea10a 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -15,7 +15,13 @@
 # limitations under the License.
 #
-from typing_extensions import Protocol
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
 from typing import Union, Optional
 import datetime
 import decimal
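
Editor's note (appended for review, not part of the patch series): the two typing
patterns introduced above can be exercised on their own. The following is a minimal,
self-contained sketch, assuming only the standard library plus the typing_extensions
backport; it illustrates the version-gated Protocol import from PATCH 3/3 and the
get_args()-based isinstance check restored in PATCH 2/3. The Greeter protocol and
make_greeter helper are hypothetical names invented for this illustration and do not
appear in the patches.

    import sys
    from typing import Union

    # typing.Protocol and typing.get_args were added in Python 3.8; older
    # interpreters need the typing_extensions backport, hence the gate.
    if sys.version_info >= (3, 8):
        from typing import Protocol, get_args
    else:
        from typing_extensions import Protocol, get_args

    PrimitiveType = Union[bool, float, int, str]

    # get_args() unpacks a Union alias into a plain tuple of classes, which
    # isinstance() accepts directly; the alias stays the single source of
    # truth instead of a hand-maintained tuple like primitive_types.
    assert get_args(PrimitiveType) == (bool, float, int, str)
    assert isinstance(3.14, get_args(PrimitiveType))

    # A callback Protocol gives a structural type for callables whose
    # signatures plain typing.Callable cannot express, e.g. star-args.
    class Greeter(Protocol):
        def __call__(self, *names: str) -> str:
            ...

    def make_greeter(prefix: str) -> Greeter:
        def greet(*names: str) -> str:
            return prefix + ", ".join(names)

        return greet

    assert make_greeter("hello ")("spark", "connect") == "hello spark, connect"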