-
Notifications
You must be signed in to change notification settings - Fork 921
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement handlers for series literal in cudf-polars (#16113)
A query plan can contain a "literal" polars Series. Often, for example, when calling a contains-like function. To translate these, introduce a new `LiteralColumn` node to capture the concept and add an evaluation rule (converting from arrow). Since list-dtype Series need the same casting treatment as in dataframe scan case, factor the casting out into a utility, and take the opportunity to handled casting of nested lists correctly. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Thomas Li (https://github.com/lithomas1) - Vyas Ramasubramani (https://github.com/vyasr) URL: #16113
- Loading branch information
Showing
6 changed files
with
239 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
import polars as pl | ||
|
||
from cudf_polars.testing.asserts import ( | ||
assert_gpu_result_equal, | ||
assert_ir_translation_raises, | ||
) | ||
from cudf_polars.utils import dtypes | ||
|
||
|
||
@pytest.fixture( | ||
params=[ | ||
None, | ||
pl.Int8(), | ||
pl.Int16(), | ||
pl.Int32(), | ||
pl.Int64(), | ||
pl.UInt8(), | ||
pl.UInt16(), | ||
pl.UInt32(), | ||
pl.UInt64(), | ||
] | ||
) | ||
def integer(request): | ||
return pl.lit(10, dtype=request.param) | ||
|
||
|
||
@pytest.fixture(params=[None, pl.Float32(), pl.Float64()]) | ||
def float(request): | ||
return pl.lit(1.0, dtype=request.param) | ||
|
||
|
||
def test_numeric_literal(integer, float): | ||
df = pl.LazyFrame({}) | ||
|
||
q = df.select(integer=integer, float_=float, sum_=integer + float) | ||
|
||
assert_gpu_result_equal(q) | ||
|
||
|
||
@pytest.fixture( | ||
params=[pl.Date(), pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")] | ||
) | ||
def timestamp(request): | ||
return pl.lit(10_000, dtype=request.param) | ||
|
||
|
||
@pytest.fixture(params=[pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")]) | ||
def timedelta(request): | ||
return pl.lit(9_000, dtype=request.param) | ||
|
||
|
||
def test_timelike_literal(timestamp, timedelta): | ||
df = pl.LazyFrame({}) | ||
|
||
q = df.select( | ||
time=timestamp, | ||
delta=timedelta, | ||
adjusted=timestamp + timedelta, | ||
two_delta=timedelta + timedelta, | ||
) | ||
schema = q.collect_schema() | ||
time_type = schema["time"] | ||
delta_type = schema["delta"] | ||
if dtypes.have_compatible_resolution( | ||
dtypes.from_polars(time_type).id(), dtypes.from_polars(delta_type).id() | ||
): | ||
assert_gpu_result_equal(q) | ||
else: | ||
assert_ir_translation_raises(q, NotImplementedError) | ||
|
||
|
||
def test_select_literal_series(): | ||
df = pl.LazyFrame({}) | ||
|
||
q = df.select( | ||
a=pl.Series(["a", "b", "c"], dtype=pl.String()), | ||
b=pl.Series([[1, 2], [3], None], dtype=pl.List(pl.UInt16())), | ||
c=pl.Series([[[1]], [], [[1, 2, 3, 4]]], dtype=pl.List(pl.List(pl.Float32()))), | ||
) | ||
|
||
assert_gpu_result_equal(q) | ||
|
||
|
||
@pytest.mark.parametrize("expr", [pl.lit(None), pl.lit(10, dtype=pl.Decimal())]) | ||
def test_unsupported_literal_raises(expr): | ||
df = pl.LazyFrame({}) | ||
|
||
q = df.select(expr) | ||
|
||
assert_ir_translation_raises(q, NotImplementedError) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters