Skip to content

Commit

Permalink
feat: make cast accept built-in Python types
Browse files Browse the repository at this point in the history
closes #753
  • Loading branch information
mesejo committed Sep 6, 2024
1 parent fe0738a commit 041de33
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
21 changes: 19 additions & 2 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
functions as functions_internal,
)
from datafusion.common import NullTreatment, RexType, DataTypeMap
from typing import Any, Optional
from typing import Any, Optional, Type
import pyarrow as pa

# The following are imported from the internal representation. We may choose to
Expand Down Expand Up @@ -372,8 +372,25 @@ def is_not_null(self) -> Expr:
"""Returns ``True`` if this expression is not null."""
return Expr(self.expr.is_not_null())

def cast(self, to: pa.DataType[Any]) -> Expr:
# Built-in Python types accepted by ``cast`` and the pyarrow data type each
# one is translated to before the cast is delegated.
_to_pyarrow_types = {
    float: pa.float64(),
    int: pa.int64(),
    str: pa.string(),
    bool: pa.bool_(),
}

def cast(
    self, to: pa.DataType[Any] | Type[float] | Type[int] | Type[str] | Type[bool]
) -> Expr:
    """Cast to a new data type.

    Args:
        to: Target type. Either a ``pyarrow.DataType`` instance, or one of the
            built-in types ``float``, ``int``, ``str`` or ``bool``, which are
            translated through ``_to_pyarrow_types``.

    Returns:
        A new ``Expr`` wrapping the result of casting this expression to
        ``to``.

    Raises:
        TypeError: If ``to`` is neither a ``pyarrow.DataType`` instance nor
            one of the supported built-in types.
    """
    if not isinstance(to, pa.DataType):
        try:
            to = self._to_pyarrow_types[to]
        except (KeyError, TypeError) as err:
            # KeyError: hashable but unsupported value (e.g. ``bytes``).
            # TypeError: unhashable value (e.g. a list), which would otherwise
            # escape with an unrelated "unhashable type" message.
            raise TypeError(
                "Expected instance of pyarrow.DataType or builtins.type"
            ) from err

    return Expr(self.expr.cast(to))

def rex_type(self) -> RexType:
Expand Down
22 changes: 21 additions & 1 deletion python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def df():
datetime(2020, 7, 2),
]
),
pa.array([False, True, True]),
],
names=["a", "b", "c", "d"],
names=["a", "b", "c", "d", "e"],
)
return ctx.create_dataframe([[batch]])

Expand Down Expand Up @@ -978,3 +979,22 @@ def test_binary_string_functions(df):
assert pa.array(result.column(1)).cast(pa.string()) == pa.array(
["Hello", "World", "!"]
)


@pytest.mark.parametrize(
    "python_datatype, name, expected",
    [
        pytest.param(bool, "e", pa.bool_(), id="bool"),
        pytest.param(int, "b", pa.int64(), id="int"),
        pytest.param(float, "b", pa.float64(), id="float"),
        pytest.param(str, "b", pa.string(), id="str"),
    ],
)
def test_cast(df, python_datatype, name: str, expected):
    """Check that casting with a built-in type matches the pyarrow cast.

    The same column is cast twice, once with the built-in Python type and
    once with the pyarrow data type given as ``expected``; both resulting
    columns must compare equal.
    """
    via_builtin = column(name).cast(python_datatype).alias("actual")
    via_pyarrow = column(name).cast(expected).alias("expected")

    # Only the first collected record batch is inspected.
    first_batch = df.select(via_builtin, via_pyarrow).collect()[0]

    assert first_batch.column(0) == first_batch.column(1)

0 comments on commit 041de33

Please sign in to comment.