19 changes: 11 additions & 8 deletions python/pyspark/sql/classic/column.py
@@ -35,7 +35,7 @@
 from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError
 from pyspark.errors.utils import with_origin_to_class
 from pyspark.sql.types import DataType
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, enum_to_value
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -52,7 +52,7 @@ def _create_column_from_literal(
     from py4j.java_gateway import JVMView
 
     sc = get_active_spark_context()
-    return cast(JVMView, sc._jvm).functions.lit(literal)
+    return cast(JVMView, sc._jvm).functions.lit(enum_to_value(literal))
 
 
 def _create_column_from_name(name: str) -> "JavaObject":
@@ -163,7 +163,7 @@ def _bin_op(
     other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
 ) -> ParentColumn:
     """Create a method for given binary operator"""
-    jc = other._jc if isinstance(other, ParentColumn) else other
+    jc = other._jc if isinstance(other, ParentColumn) else enum_to_value(other)
     njc = getattr(self._jc, name)(jc)
     return Column(njc)
 
@@ -441,20 +441,23 @@ def endswith(
         return _bin_op("endsWith", self, other)
 
     def like(self: ParentColumn, other: str) -> ParentColumn:
-        njc = getattr(self._jc, "like")(other)
+        njc = getattr(self._jc, "like")(enum_to_value(other))
         return Column(njc)
 
     def rlike(self: ParentColumn, other: str) -> ParentColumn:
-        njc = getattr(self._jc, "rlike")(other)
+        njc = getattr(self._jc, "rlike")(enum_to_value(other))
         return Column(njc)
 
     def ilike(self: ParentColumn, other: str) -> ParentColumn:
-        njc = getattr(self._jc, "ilike")(other)
+        njc = getattr(self._jc, "ilike")(enum_to_value(other))
         return Column(njc)
 
     def substr(
         self, startPos: Union[int, ParentColumn], length: Union[int, ParentColumn]
     ) -> ParentColumn:
+        startPos = enum_to_value(startPos)
+        length = enum_to_value(length)
+
         if type(startPos) != type(length):
             raise PySparkTypeError(
                 errorClass="NOT_SAME_TYPE",
@@ -586,12 +589,12 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
                 errorClass="NOT_COLUMN",
                 messageParameters={"arg_name": "condition", "arg_type": type(condition).__name__},
             )
-        v = value._jc if isinstance(value, Column) else value
+        v = value._jc if isinstance(value, Column) else enum_to_value(value)
         jc = self._jc.when(condition._jc, v)
         return Column(jc)
 
     def otherwise(self, value: Any) -> ParentColumn:
-        v = value._jc if isinstance(value, Column) else value
+        v = value._jc if isinstance(value, Column) else enum_to_value(value)
         jc = self._jc.otherwise(v)
         return Column(jc)
 
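
Every literal-like argument in this file is now routed through `enum_to_value` before it reaches the JVM. The helper itself lives in `pyspark.sql.utils` and its body is not part of this diff; the following is a minimal sketch of the behavior the call sites above assume (an `Enum` member is unwrapped, recursively, to its underlying `.value`, and anything else passes through unchanged):

```python
from enum import Enum
from typing import Any


def enum_to_value(value: Any) -> Any:
    """Unwrap an Enum member to its underlying value; pass other inputs through."""
    # Recurse in case one Enum member's value is itself another Enum member.
    if value is not None and isinstance(value, Enum):
        return enum_to_value(value.value)
    return value
```

With that in place, `like`, `substr`, `when`, `otherwise`, and the shared binary-operator path all accept an enum member (say, a hypothetical `Pattern.ADMIN`) exactly as if `Pattern.ADMIN.value` had been passed.
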
6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/column.py
@@ -32,6 +32,7 @@
 from pyspark.sql.column import Column as ParentColumn
 from pyspark.errors import PySparkTypeError, PySparkAttributeError, PySparkValueError
 from pyspark.sql.types import DataType
+from pyspark.sql.utils import enum_to_value
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.expressions import (
@@ -69,6 +70,7 @@ def _bin_op(
     other: Union[ParentColumn, "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
     reverse: bool = False,
 ) -> ParentColumn:
+    other = enum_to_value(other)
     if other is None or isinstance(
         other,
         (
@@ -351,6 +353,9 @@ def ilike(self: ParentColumn, other: str) -> ParentColumn:
     def substr(
         self, startPos: Union[int, ParentColumn], length: Union[int, ParentColumn]
     ) -> ParentColumn:
+        startPos = enum_to_value(startPos)
+        length = enum_to_value(length)
+
         if type(startPos) != type(length):
             raise PySparkTypeError(
                 errorClass="NOT_SAME_TYPE",
@@ -373,6 +378,7 @@ def substr(
         return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr]))
 
     def __eq__(self, other: Any) -> ParentColumn:  # type: ignore[override]
+        other = enum_to_value(other)
         if other is None or isinstance(
             other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
         ):
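
On the Connect side the unwrap must happen before the `isinstance` checks that decide whether an operand becomes a literal expression or is treated as a column; an `Enum` member would otherwise match neither branch. A hedged usage sketch (the `Priority` enum and the connection string are illustrative, not part of the PR):

```python
from enum import Enum

from pyspark.sql import SparkSession


class Priority(Enum):  # hypothetical example enum
    LOW = 1
    HIGH = 2


# Assumes a Spark Connect server is reachable at this illustrative address.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df = spark.createDataFrame([(1, 1), (2, 2)], ["id", "priority"])

# After this change, both filters build the same comparison against a literal 2.
df.filter(df.priority == Priority.HIGH).show()
df.filter(df.priority == 2).show()
```
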
4 changes: 3 additions & 1 deletion python/pyspark/sql/connect/expressions.py
@@ -77,7 +77,7 @@
 )
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.errors.utils import current_origin
-from pyspark.sql.utils import is_timestamp_ntz_preferred
+from pyspark.sql.utils import is_timestamp_ntz_preferred, enum_to_value
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
@@ -231,6 +231,7 @@ def __init__(self, value: Any, dataType: DataType) -> None:
             ),
         )
 
+        value = enum_to_value(value)
         if isinstance(dataType, NullType):
             assert value is None
 
@@ -295,6 +296,7 @@ def __init__(self, value: Any, dataType: DataType) -> None:
 
     @classmethod
     def _infer_type(cls, value: Any) -> DataType:
+        value = enum_to_value(value)
         if value is None:
             return NullType()
         elif isinstance(value, (bytes, bytearray)):
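
Unwrapping inside `LiteralExpression.__init__` and `_infer_type` means literal construction and type inference both see the enum's payload rather than the `Enum` object itself. A small sketch against the (private) classmethod this hunk touches; `Color` is a hypothetical enum:

```python
from enum import Enum

from pyspark.sql.connect.expressions import LiteralExpression


class Color(Enum):  # hypothetical example enum
    RED = 1


# Inference now sees the unwrapped payload 1, so this resolves to an integer
# type instead of failing on the Enum instance.
print(LiteralExpression._infer_type(Color.RED))
```
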
52 changes: 44 additions & 8 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -69,6 +69,7 @@
     ArrayType,
     StringType,
 )
+from pyspark.sql.utils import enum_to_value as _enum_to_value
 
 # The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf
 # for code reuse.
@@ -448,6 +449,7 @@ def when(condition: Column, value: Any) -> Column:
             messageParameters={"arg_name": "condition", "arg_type": type(condition).__name__},
         )
 
+    value = _enum_to_value(value)
     value_col = value if isinstance(value, Column) else lit(value)
 
     return ConnectColumn(
@@ -576,8 +578,9 @@ def bround(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     if scale is None:
         return _invoke_function_over_columns("bround", col)
     else:
+        scale = _enum_to_value(scale)
         scale = lit(scale) if isinstance(scale, int) else scale
-        return _invoke_function_over_columns("bround", col, scale)
+        return _invoke_function_over_columns("bround", col, scale)  # type: ignore[arg-type]
 
 
 bround.__doc__ = pysparkfuncs.bround.__doc__
@@ -594,8 +597,9 @@ def ceil(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     if scale is None:
         return _invoke_function_over_columns("ceil", col)
     else:
+        scale = _enum_to_value(scale)
         scale = lit(scale) if isinstance(scale, int) else scale
-        return _invoke_function_over_columns("ceil", col, scale)
+        return _invoke_function_over_columns("ceil", col, scale)  # type: ignore[arg-type]
 
 
 ceil.__doc__ = pysparkfuncs.ceil.__doc__
@@ -605,8 +609,9 @@ def ceiling(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     if scale is None:
         return _invoke_function_over_columns("ceiling", col)
     else:
+        scale = _enum_to_value(scale)
         scale = lit(scale) if isinstance(scale, int) else scale
-        return _invoke_function_over_columns("ceiling", col, scale)
+        return _invoke_function_over_columns("ceiling", col, scale)  # type: ignore[arg-type]
 
 
 ceiling.__doc__ = pysparkfuncs.ceiling.__doc__
@@ -686,8 +691,9 @@ def floor(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     if scale is None:
         return _invoke_function_over_columns("floor", col)
     else:
+        scale = _enum_to_value(scale)
         scale = lit(scale) if isinstance(scale, int) else scale
-        return _invoke_function_over_columns("floor", col, scale)
+        return _invoke_function_over_columns("floor", col, scale)  # type: ignore[arg-type]
 
 
 floor.__doc__ = pysparkfuncs.floor.__doc__
@@ -784,6 +790,7 @@ def width_bucket(
     max: "ColumnOrName",
     numBucket: Union["ColumnOrName", int],
 ) -> Column:
+    numBucket = _enum_to_value(numBucket)
     numBucket = lit(numBucket) if isinstance(numBucket, int) else numBucket
     return _invoke_function_over_columns("width_bucket", v, min, max, numBucket)
 
@@ -819,8 +826,9 @@ def round(col: "ColumnOrName", scale: Optional[Union[Column, int]] = None) -> Column:
     if scale is None:
         return _invoke_function_over_columns("round", col)
     else:
+        scale = _enum_to_value(scale)
         scale = lit(scale) if isinstance(scale, int) else scale
-        return _invoke_function_over_columns("round", col, scale)
+        return _invoke_function_over_columns("round", col, scale)  # type: ignore[arg-type]
 
 
 round.__doc__ = pysparkfuncs.round.__doc__
@@ -1487,8 +1495,11 @@ def any_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] = None) -> Column:
     if ignoreNulls is None:
         return _invoke_function_over_columns("any_value", col)
     else:
+        ignoreNulls = _enum_to_value(ignoreNulls)
         ignoreNulls = lit(ignoreNulls) if isinstance(ignoreNulls, bool) else ignoreNulls
-        return _invoke_function_over_columns("any_value", col, ignoreNulls)
+        return _invoke_function_over_columns(
+            "any_value", col, ignoreNulls  # type: ignore[arg-type]
+        )
 
 
 any_value.__doc__ = pysparkfuncs.any_value.__doc__
@@ -1498,8 +1509,11 @@ def first_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] = None) -> Column:
     if ignoreNulls is None:
         return _invoke_function_over_columns("first_value", col)
     else:
+        ignoreNulls = _enum_to_value(ignoreNulls)
         ignoreNulls = lit(ignoreNulls) if isinstance(ignoreNulls, bool) else ignoreNulls
-        return _invoke_function_over_columns("first_value", col, ignoreNulls)
+        return _invoke_function_over_columns(
+            "first_value", col, ignoreNulls  # type: ignore[arg-type]
+        )
 
 
 first_value.__doc__ = pysparkfuncs.first_value.__doc__
@@ -1509,8 +1523,11 @@ def last_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] = None) -> Column:
     if ignoreNulls is None:
         return _invoke_function_over_columns("last_value", col)
     else:
+        ignoreNulls = _enum_to_value(ignoreNulls)
         ignoreNulls = lit(ignoreNulls) if isinstance(ignoreNulls, bool) else ignoreNulls
-        return _invoke_function_over_columns("last_value", col, ignoreNulls)
+        return _invoke_function_over_columns(
+            "last_value", col, ignoreNulls  # type: ignore[arg-type]
+        )
 
 
 last_value.__doc__ = pysparkfuncs.last_value.__doc__
@@ -1628,6 +1645,7 @@ def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
 
 
 def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column:
+    pos = _enum_to_value(pos)
     _pos = lit(pos) if isinstance(pos, int) else _to_col(pos)
     return _invoke_function("array_insert", _to_col(arr), _pos, lit(value))
 
@@ -1711,6 +1729,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
 
 
 def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column:
+    count = _enum_to_value(count)
     _count = lit(count) if isinstance(count, int) else _to_col(count)
     return _invoke_function("array_repeat", _to_col(col), _count)
 
@@ -1901,6 +1920,7 @@ def from_xml(
 
 
 def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
+    index = _enum_to_value(index)
     index = lit(index) if isinstance(index, int) else index
 
     return _invoke_function_over_columns("get", col, index)
@@ -2124,6 +2144,7 @@ def sequence(
 
 
 def schema_of_csv(csv: Union[str, Column], options: Optional[Dict[str, str]] = None) -> Column:
+    csv = _enum_to_value(csv)
     if not isinstance(csv, (str, Column)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_STR",
@@ -2140,6 +2161,7 @@ def schema_of_csv(csv: Union[str, Column], options: Optional[Dict[str, str]] = None) -> Column:
 
 
 def schema_of_json(json: Union[str, Column], options: Optional[Dict[str, str]] = None) -> Column:
+    json = _enum_to_value(json)
     if not isinstance(json, (str, Column)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_STR",
@@ -2156,6 +2178,7 @@ def schema_of_json(json: Union[str, Column], options: Optional[Dict[str, str]] = None) -> Column:
 
 
 def schema_of_xml(xml: Union[str, Column], options: Optional[Dict[str, str]] = None) -> Column:
+    xml = _enum_to_value(xml)
     if not isinstance(xml, (str, Column)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_STR",
@@ -2192,6 +2215,7 @@ def size(col: "ColumnOrName") -> Column:
 def slice(
     col: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]
 ) -> Column:
+    start = _enum_to_value(start)
     if isinstance(start, (Column, str)):
         _start = start
     elif isinstance(start, int):
@@ -2202,6 +2226,7 @@ def slice(
             messageParameters={"arg_name": "start", "arg_type": type(start).__name__},
         )
 
+    length = _enum_to_value(length)
     if isinstance(length, (Column, str)):
         _length = length
     elif isinstance(length, int):
@@ -2415,11 +2440,13 @@ def overlay(
     pos: Union["ColumnOrName", int],
     len: Union["ColumnOrName", int] = -1,
 ) -> Column:
+    pos = _enum_to_value(pos)
     if not isinstance(pos, (int, str, Column)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_INT_OR_STR",
             messageParameters={"arg_name": "pos", "arg_type": type(pos).__name__},
         )
+    len = _enum_to_value(len)
     if len is not None and not isinstance(len, (int, str, Column)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_INT_OR_STR",
@@ -2499,6 +2526,7 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column:
 
 
 def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:
+    n = _enum_to_value(n)
     n = lit(n) if isinstance(n, int) else n
     return _invoke_function("repeat", _to_col(col), _to_col(n))
 
@@ -2511,6 +2539,7 @@ def split(
     pattern: Union[Column, str],
     limit: Union["ColumnOrName", int] = -1,
 ) -> Column:
+    limit = _enum_to_value(limit)
     limit = lit(limit) if isinstance(limit, int) else _to_col(limit)
     return _invoke_function("split", _to_col(str), lit(pattern), limit)
 
@@ -3097,6 +3126,7 @@ def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") -> Column:
 
 
 def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
+    days = _enum_to_value(days)
     days = lit(days) if isinstance(days, int) else days
     return _invoke_function_over_columns("date_add", start, days)
 
@@ -3105,6 +3135,7 @@ def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
 
 
 def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
+    days = _enum_to_value(days)
     days = lit(days) if isinstance(days, int) else days
     return _invoke_function_over_columns("dateadd", start, days)
 
@@ -3113,6 +3144,7 @@ def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
 
 
 def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
+    days = _enum_to_value(days)
     days = lit(days) if isinstance(days, int) else days
     return _invoke_function_over_columns("date_sub", start, days)
 
@@ -3142,6 +3174,7 @@ def date_from_unix_date(days: "ColumnOrName") -> Column:
 
 
 def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column:
+    months = _enum_to_value(months)
     months = lit(months) if isinstance(months, int) else months
     return _invoke_function_over_columns("add_months", start, months)
 
@@ -3449,6 +3482,7 @@ def window_time(
 
 
 def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) -> Column:
+    gapDuration = _enum_to_value(gapDuration)
     if gapDuration is None or not isinstance(gapDuration, (Column, str)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_STR",
@@ -3729,6 +3763,7 @@ def session_user() -> Column:
 
 
 def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None) -> Column:
+    errMsg = _enum_to_value(errMsg)
     if errMsg is None:
         return _invoke_function_over_columns("assert_true", col)
     if not isinstance(errMsg, (str, Column)):
@@ -3743,6 +3778,7 @@ def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None) -> Column:
 
 
 def raise_error(errMsg: Union[Column, str]) -> Column:
+    errMsg = _enum_to_value(errMsg)
     if not isinstance(errMsg, (str, Column)):
         raise PySparkTypeError(
             errorClass="NOT_COLUMN_OR_STR",