Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pyspark/sql/connect/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@
# limitations under the License.
#
from typing import Union
from datetime import date, time, datetime

# Scalar Python types grouped under a single alias.
PrimitiveType = Union[str, int, bool, float]
# The primitive scalars plus the date/time family. typing flattens nested
# unions, so the members are listed directly.
LiteralType = Union[PrimitiveType, date, time, datetime]
56 changes: 53 additions & 3 deletions python/pyspark/sql/connect/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import uuid
from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any

import decimal
import datetime

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect._typing import PrimitiveType
Expand Down Expand Up @@ -87,7 +89,7 @@ class LiteralExpression(Expression):
The Python types are converted best effort into the relevant proto types. On the Spark Connect
server side, the proto types are converted to the Catalyst equivalents."""

def __init__(self, value: PrimitiveType) -> None:
def __init__(self, value: Any) -> None:
super().__init__()
self._value = value

Expand All @@ -99,11 +101,59 @@ def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression"
value_type = type(self._value)
exp = proto.Expression()
if value_type is int:
exp.literal.i32 = cast(int, self._value)
exp.literal.i64 = cast(int, self._value)
elif value_type is bool:
exp.literal.boolean = cast(bool, self._value)
elif value_type is str:
exp.literal.string = cast(str, self._value)
elif value_type is float:
exp.literal.fp64 = cast(float, self._value)
elif value_type is decimal.Decimal:
d_v = cast(decimal.Decimal, self._value)
v_tuple = d_v.as_tuple()
exp.literal.decimal.scale = abs(v_tuple.exponent)
exp.literal.decimal.precision = len(v_tuple.digits) - abs(v_tuple.exponent)
# TODO: the unscaled value is not encoded yet (presumably as two's complement), so Decimal is rejected below.
raise ValueError("Python Decimal not supported.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this branch if it is not implemented yet?

elif value_type is bytes:
exp.literal.binary = self._value
elif value_type is datetime.datetime:
# Microseconds since epoch.
dt = cast(datetime.datetime, self._value)
v = dt - datetime.datetime(1970, 1, 1, 0, 0, 0, 0)
exp.literal.timestamp = int(v / datetime.timedelta(microseconds=1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elif value_type is datetime.time:
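# Nanoseconds of the day.
# NOTE(review): the seconds term below is scaled by 1000 (milliseconds) before the
# microsecond field is added, so the sum mixes units; confirm the unit the proto
# `time` field expects and scale both terms consistently.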
tv = cast(datetime.time, self._value)
offset = (tv.second + tv.minute * 60 + tv.hour * 3600) * 1000 + tv.microsecond
exp.literal.time = int(offset * 1000)
elif value_type is datetime.date:
# Days since epoch.
days_since_epoch = (cast(datetime.date, self._value) - datetime.date(1970, 1, 1)).days
exp.literal.date = days_since_epoch
elif value_type is uuid.UUID:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could remove this elif so that the else branch throws the exception?

raise ValueError("Python UUID type not supported.")
elif value_type is list:
lv = cast(list, self._value)
for k in lv:
if type(k) is LiteralExpression:
exp.literal.list.values.append(k.to_plan(session).literal)
else:
exp.literal.list.values.append(LiteralExpression(k).to_plan(session).literal)
elif value_type is dict:
mv = cast(dict, self._value)
for k in mv:
kv = proto.Expression.Literal.Map.KeyValue()
if type(k) is LiteralExpression:
kv.key.CopyFrom(k.to_plan(session).literal)
else:
kv.key.CopyFrom(LiteralExpression(k).to_plan(session).literal)

if type(mv[k]) is LiteralExpression:
kv.value.CopyFrom(mv[k].to_plan(session).literal)
else:
kv.value.CopyFrom(LiteralExpression(mv[k]).to_plan(session).literal)
exp.literal.map.key_values.append(kv)
else:
raise ValueError(f"Could not convert literal for type {type(self._value)}")

Expand Down
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.
#
from pyspark.sql.connect.column import ColumnRef, LiteralExpression
from pyspark.sql.connect.column import PrimitiveType

from typing import Any

# TODO(SPARK-40538) Add support for the missing PySpark functions.

Expand All @@ -24,5 +25,5 @@ def col(x: str) -> ColumnRef:
return ColumnRef(x)


def lit(x: PrimitiveType) -> LiteralExpression:
def lit(x: Any) -> LiteralExpression:
return LiteralExpression(x)
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import uuid
from typing import cast
import unittest
import decimal
import datetime

from pyspark.testing.connectutils import PlanOnlyTestFixture
from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message

Expand Down Expand Up @@ -49,14 +52,88 @@ def test_simple_column_expressions(self):
self.assertEqual(cp1, cp2)
self.assertEqual(cp2, cp3)

def test_binary_literal(self):
    """A bytes value comes back unchanged in the literal's binary field."""
    payload = b"binary\0\0asas"
    plan = fun.lit(payload).to_plan(None)
    self.assertEqual(plan.literal.binary, payload)

def test_map_literal(self):
    """Dict values are planned as map literals, for plain and pre-wrapped values."""
    # Plain Python keys and values.
    plain_plan = fun.lit({"this": "is", 12: [12, 32, 43]}).to_plan(None)
    plain_entries = plain_plan.literal.map.key_values
    self.assertEqual(2, len(plain_entries))
    self.assertEqual("this", plain_entries[0].key.string)
    self.assertEqual(12, plain_entries[1].key.i64)

    # One value is already a literal expression.
    wrapped_plan = fun.lit({"this": fun.lit("is"), 12: [12, 32, 43]}).to_plan(None)
    wrapped_entries = wrapped_plan.literal.map.key_values
    self.assertEqual(2, len(wrapped_entries))
    self.assertEqual("is", wrapped_entries[0].value.string)

def test_uuid_literal(self):
    """Planning a UUID literal raises ValueError."""
    uuid_lit = fun.lit(uuid.uuid4())
    self.assertRaises(ValueError, uuid_lit.to_plan, None)

def test_column_literals(self):
    """An int literal can be selected on a DataFrame and is planned as an i64 literal."""
    df = c.DataFrame.withPlan(p.Read("table"))
    lit_df = df.select(fun.lit(10))
    self.assertIsNotNone(lit_df._plan.to_proto(None))

    plan = fun.lit(10).to_plan(None)
    self.assertIsNotNone(plan)
    # Compare by value: whether two equal ints are the same object is an
    # interpreter detail, so an identity check is not a reliable assertion.
    self.assertEqual(plan.literal.i64, 10)

def test_numeric_literal_types(self):
    """Int and float literals can be planned; a Decimal literal raises ValueError."""
    supported = (fun.lit(10), fun.lit(10.1))
    unsupported = fun.lit(decimal.Decimal(99))

    # Decimal is not supported yet.
    with self.assertRaises(ValueError):
        self.assertIsNotNone(unsupported.to_plan(None))

    for numeric_lit in supported:
        self.assertIsNotNone(numeric_lit.to_plan(None))

def test_datetime_literal_types(self):
    """Test the different timestamp, date, and time types."""
    timestamp_lit = fun.lit(datetime.datetime.now())
    timestamp_plan = timestamp_lit.to_plan(None)
    self.assertIsNotNone(timestamp_lit.to_plan(None))
    # Lower bound on the encoded timestamp for the current time.
    self.assertGreater(timestamp_plan.literal.timestamp, 10000000000000)

    # Date and time literals only need to produce a plan.
    date_lit = fun.lit(datetime.date.today())
    time_lit = fun.lit(datetime.time())
    for temporal_lit in (date_lit, time_lit):
        self.assertIsNotNone(temporal_lit.to_plan(None))

def test_list_to_literal(self):
    """Test conversion of lists to literals"""
    # Empty, single-type, and mixed-type lists: build every literal first,
    # then plan each one in turn.
    samples = ([], [1, 2, 3, 4], ["ooo", 1, "asas", 2.3])
    list_lits = [fun.lit(sample) for sample in samples]
    for list_lit in list_lits:
        self.assertIsNotNone(list_lit.to_plan(None))

    # A list whose elements are already literal expressions.
    nested_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
    self.assertIsNotNone(nested_plan)

def test_column_expressions(self):
"""Test a more complex combination of expressions and their translation into
Expand All @@ -76,7 +153,7 @@ def test_column_expressions(self):
lit_fun = expr_plan.unresolved_function.arguments[1]
self.assertIsInstance(lit_fun, ProtoExpression)
self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
self.assertEqual(lit_fun.literal.i32, 10)
self.assertEqual(lit_fun.literal.i64, 10)

mod_fun = expr_plan.unresolved_function.arguments[0]
self.assertIsInstance(mod_fun, ProtoExpression)
Expand Down