From 035b3ae54f37ad13a9de891e949014d312eb49f4 Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Tue, 1 Nov 2022 08:18:03 +0100
Subject: [PATCH 1/2] Support most literal types from Python

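LiteralExpression (and therefore lit()) now handles the common Python
literal types in addition to the existing primitives: bool, int (now mapped
to i64 instead of i32), float, str, bytes, decimal.Decimal, datetime.datetime,
datetime.time, datetime.date, uuid.UUID, list and dict. Decimal values are
recognized but still rejected with a ValueError until the proto encoding is
settled.

A rough usage sketch, based on the new tests (plans are built without a
remote session here, hence to_plan(None)):

    import datetime
    from pyspark.sql.connect.functions import lit

    lit(10).to_plan(None).literal.i64                      # ints now map to i64
    lit(b"binary\0\0asas").to_plan(None).literal.binary
    lit(datetime.date.today()).to_plan(None).literal.date  # days since epoch
    lit({"this": "is", 12: [12, 32, 43]}).to_plan(None).literal.map
    lit([lit(10), lit("str")]).to_plan(None).literal.list  # nested lit() works
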
---
 python/pyspark/sql/connect/_typing.py         |  2 +
 python/pyspark/sql/connect/column.py          | 56 ++++++++++++-
 python/pyspark/sql/connect/functions.py       |  5 +-
 .../test_connect_column_expressions.py        | 84 ++++++++++++++++++-
 4 files changed, 139 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 5cd14111badac..4e69b2e4aa5ef 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -15,5 +15,7 @@
 # limitations under the License.
 #
 from typing import Union
+from datetime import date, time, datetime
 
 PrimitiveType = Union[str, int, bool, float]
+LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 126c45d6b4a8a..92f4564fdc00d 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import uuid
 from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any
 
+import decimal
+import datetime
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect._typing import PrimitiveType
 
@@ -87,7 +89,7 @@ class LiteralExpression(Expression):
     The Python types are converted best effort into the relevant proto types. On the
     Spark Connect server side, the proto types are converted to the Catalyst equivalents."""
 
-    def __init__(self, value: PrimitiveType) -> None:
+    def __init__(self, value: Any) -> None:
         super().__init__()
         self._value = value
 
@@ -99,11 +101,59 @@ def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression"
         value_type = type(self._value)
         exp = proto.Expression()
         if value_type is int:
-            exp.literal.i32 = cast(int, self._value)
+            exp.literal.i64 = cast(int, self._value)
+        elif value_type is bool:
+            exp.literal.boolean = cast(bool, self._value)
         elif value_type is str:
             exp.literal.string = cast(str, self._value)
         elif value_type is float:
             exp.literal.fp64 = cast(float, self._value)
+        elif value_type is decimal.Decimal:
+            d_v = cast(decimal.Decimal, self._value)
+            v_tuple = d_v.as_tuple()
+            exp.literal.decimal.scale = abs(v_tuple.exponent)
+            exp.literal.decimal.precision = len(v_tuple.digits) - abs(v_tuple.exponent)
+            # Two's complement encoding of the unscaled value is not implemented yet.
+            raise ValueError("cannnt....")
+        elif value_type is bytes:
+            exp.literal.binary = self._value
+        elif value_type is datetime.datetime:
+            # Microseconds since epoch.
+            dt = cast(datetime.datetime, self._value)
+            v = dt - datetime.datetime(1970, 1, 1, 0, 0, 0, 0)
+            exp.literal.timestamp = int(v / datetime.timedelta(microseconds=1))
+        elif value_type is datetime.time:
+            # Nanoseconds of the day.
+            tv = cast(datetime.time, self._value)
+            offset = (tv.second + tv.minute * 60 + tv.hour * 3600) * 1000 * 1000 + tv.microsecond
+            exp.literal.time = int(offset * 1000)
+        elif value_type is datetime.date:
+            # Days since epoch.
+            days_since_epoch = (cast(datetime.date, self._value) - datetime.date(1970, 1, 1)).days
+            exp.literal.date = days_since_epoch
+        elif value_type is uuid.UUID:
+            exp.literal.uuid = cast(uuid.UUID, self._value).bytes
+        elif value_type is list:
+            lv = cast(list, self._value)
+            for k in lv:
+                if type(k) is LiteralExpression:
+                    exp.literal.list.values.append(k.to_plan(session).literal)
+                else:
+                    exp.literal.list.values.append(LiteralExpression(k).to_plan(session).literal)
+        elif value_type is dict:
+            mv = cast(dict, self._value)
+            for k in mv:
+                kv = proto.Expression.Literal.Map.KeyValue()
+                if type(k) is LiteralExpression:
+                    kv.key.CopyFrom(k.to_plan(session).literal)
+                else:
+                    kv.key.CopyFrom(LiteralExpression(k).to_plan(session).literal)
+
+                if type(mv[k]) is LiteralExpression:
+                    kv.value.CopyFrom(mv[k].to_plan(session).literal)
+                else:
+                    kv.value.CopyFrom(LiteralExpression(mv[k]).to_plan(session).literal)
+                exp.literal.map.key_values.append(kv)
         else:
             raise ValueError(f"Could not convert literal for type {type(self._value)}")
 
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 4fe57d9228377..880096da45983 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.column import ColumnRef, LiteralExpression
-from pyspark.sql.connect.column import PrimitiveType
+
+from typing import Any
 
 
 # TODO(SPARK-40538) Add support for the missing PySpark functions.
@@ -24,5 +25,5 @@ def col(x: str) -> ColumnRef:
     return ColumnRef(x)
 
 
-def lit(x: PrimitiveType) -> LiteralExpression:
+def lit(x: Any) -> LiteralExpression:
     return LiteralExpression(x)
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 790a987e88090..fa2850ddd1413 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -14,9 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import uuid
 from typing import cast
 import unittest
+import decimal
+import datetime
+
 from pyspark.testing.connectutils import PlanOnlyTestFixture
 from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
 
@@ -49,6 +52,33 @@ def test_simple_column_expressions(self):
         self.assertEqual(cp1, cp2)
         self.assertEqual(cp2, cp3)
 
+    def test_binary_literal(self):
+        val = b"binary\0\0asas"
+        bin_lit = fun.lit(val)
+        bin_lit_p = bin_lit.to_plan(None)
+        self.assertEqual(bin_lit_p.literal.binary, val)
+
+    def test_map_literal(self):
+        val = {"this": "is", 12: [12, 32, 43]}
+        map_lit = fun.lit(val)
+        map_lit_p = map_lit.to_plan(None)
+        self.assertEqual(2, len(map_lit_p.literal.map.key_values))
+        self.assertEqual("this", map_lit_p.literal.map.key_values[0].key.string)
+        self.assertEqual(12, map_lit_p.literal.map.key_values[1].key.i64)
+
+        val = {"this": fun.lit("is"), 12: [12, 32, 43]}
+        map_lit = fun.lit(val)
+        map_lit_p = map_lit.to_plan(None)
+        self.assertEqual(2, len(map_lit_p.literal.map.key_values))
+        self.assertEqual("is", map_lit_p.literal.map.key_values[0].value.string)
+
+    def test_uuid_literal(self):
+        val = uuid.uuid4()
+        lit = fun.lit(val)
+        lit_p = lit.to_plan(None)
+
+        self.assertIsNotNone(lit_p)
+
     def test_column_literals(self):
         df = c.DataFrame.withPlan(p.Read("table"))
         lit_df = df.select(fun.lit(10))
@@ -56,7 +86,55 @@ def test_column_literals(self):
         self.assertIsNotNone(fun.lit(10).to_plan(None))
 
         plan = fun.lit(10).to_plan(None)
-        self.assertIs(plan.literal.i32, 10)
+        self.assertIs(plan.literal.i64, 10)
+
+    def test_numeric_literal_types(self):
+        int_lit = fun.lit(10)
+        float_lit = fun.lit(10.1)
+        decimal_lit = fun.lit(decimal.Decimal(99))
+
+        # Decimal is not supported yet.
+        with self.assertRaises(ValueError):
+            self.assertIsNotNone(decimal_lit.to_plan(None))
+
+        self.assertIsNotNone(int_lit.to_plan(None))
+        self.assertIsNotNone(float_lit.to_plan(None))
+
+    def test_datetime_literal_types(self):
+        """Test the different timestamp, date, and time types."""
+        datetime_lit = fun.lit(datetime.datetime.now())
+
+        p = datetime_lit.to_plan(None)
+        self.assertIsNotNone(datetime_lit.to_plan(None))
+        self.assertGreater(p.literal.timestamp, 10000000000000)
+
+        date_lit = fun.lit(datetime.date.today())
+        time_lit = fun.lit(datetime.time())
+
+        self.assertIsNotNone(date_lit.to_plan(None))
+        self.assertIsNotNone(time_lit.to_plan(None))
+
+    def test_list_to_literal(self):
+        """Test conversion of lists to literals"""
+        empty_list = []
+        single_type = [1, 2, 3, 4]
+        multi_type = ["ooo", 1, "asas", 2.3]
+
+        empty_list_lit = fun.lit(empty_list)
+        single_type_lit = fun.lit(single_type)
+        multi_type_lit = fun.lit(multi_type)
+
+        p = empty_list_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        p = single_type_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        p = multi_type_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
+        self.assertIsNotNone(lit_list_plan)
 
     def test_column_expressions(self):
         """Test a more complex combination of expressions and their translation into
@@ -76,7 +154,7 @@ def test_column_expressions(self):
         lit_fun = expr_plan.unresolved_function.arguments[1]
         self.assertIsInstance(lit_fun, ProtoExpression)
         self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
-        self.assertEqual(lit_fun.literal.i32, 10)
+        self.assertEqual(lit_fun.literal.i64, 10)
 
         mod_fun = expr_plan.unresolved_function.arguments[0]
         self.assertIsInstance(mod_fun, ProtoExpression)

From 44afc3c8fb9b91fc6ad7f9a40b74a74f71596f33 Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Thu, 3 Nov 2022 14:14:37 +0100
Subject: [PATCH 2/2] Remove UUID support

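uuid.UUID values are no longer converted into a literal; lit() now raises a
ValueError for them, just as decimal.Decimal is currently rejected.

A minimal sketch of the behaviour after this patch (it mirrors the updated
test case):

    import uuid
    from pyspark.sql.connect.functions import lit

    try:
        lit(uuid.uuid4()).to_plan(None)
    except ValueError:
        pass  # UUID literals are rejected for now
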
---
 python/pyspark/sql/connect/column.py                        | 4 ++--
 .../sql/tests/connect/test_connect_column_expressions.py    | 5 ++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 92f4564fdc00d..42466fa169922 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -114,7 +114,7 @@ def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression"
             exp.literal.decimal.scale = abs(v_tuple.exponent)
             exp.literal.decimal.precision = len(v_tuple.digits) - abs(v_tuple.exponent)
             # Two's complement encoding of the unscaled value is not implemented yet.
-            raise ValueError("cannnt....")
+            raise ValueError("Python Decimal not supported.")
         elif value_type is bytes:
             exp.literal.binary = self._value
         elif value_type is datetime.datetime:
@@ -132,7 +132,7 @@ def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression"
             days_since_epoch = (cast(datetime.date, self._value) - datetime.date(1970, 1, 1)).days
             exp.literal.date = days_since_epoch
         elif value_type is uuid.UUID:
-            exp.literal.uuid = cast(uuid.UUID, self._value).bytes
+            raise ValueError("Python UUID type not supported.")
         elif value_type is list:
             lv = cast(list, self._value)
             for k in lv:
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index fa2850ddd1413..8773fe4aceba3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -75,9 +75,8 @@ def test_map_literal(self):
     def test_uuid_literal(self):
         val = uuid.uuid4()
         lit = fun.lit(val)
-        lit_p = lit.to_plan(None)
-
-        self.assertIsNotNone(lit_p)
+        with self.assertRaises(ValueError):
+            lit.to_plan(None)
 
     def test_column_literals(self):
         df = c.DataFrame.withPlan(p.Read("table"))