Skip to content

Commit

Permalink
Implement Ternary copy_if_else (rapidsai#16114)
Browse files Browse the repository at this point in the history
A straightforward evaluation using `copy_if_else`.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - https://github.com/brandon-b-miller

URL: rapidsai#16114
  • Loading branch information
wence- authored Jun 28, 2024
1 parent c847b98 commit e35da6b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"GroupedRollingWindow",
"Cast",
"Agg",
"Ternary",
"BinOp",
]

Expand Down Expand Up @@ -1112,6 +1113,34 @@ def do_evaluate(
return self.op(child.evaluate(df, context=context, mapping=mapping))


class Ternary(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr, Expr, Expr]

def __init__(
self, dtype: plc.DataType, when: Expr, then: Expr, otherwise: Expr
) -> None:
super().__init__(dtype)
self.children = (when, then, otherwise)

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
when, then, otherwise = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
then_obj = then.obj_scalar if then.is_scalar else then.obj
otherwise_obj = otherwise.obj_scalar if otherwise.is_scalar else otherwise.obj
return Column(plc.copying.copy_if_else(then_obj, otherwise_obj, when.obj))


class BinOp(Expr):
__slots__ = ("op", "children")
_non_child = ("dtype", "op")
Expand Down
10 changes: 10 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,16 @@ def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Ex
)


@_translate_expr.register
def _(node: pl_expr.Ternary, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
return expr.Ternary(
dtype,
translate_expr(visitor, n=node.predicate),
translate_expr(visitor, n=node.truthy),
translate_expr(visitor, n=node.falsy),
)


@_translate_expr.register
def _(
node: pl_expr.BinaryExpr, visitor: NodeTraverser, dtype: plc.DataType
Expand Down
27 changes: 27 additions & 0 deletions python/cudf_polars/tests/expressions/test_when_then.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("then_scalar", [False, True])
@pytest.mark.parametrize("otherwise_scalar", [False, True])
@pytest.mark.parametrize("expr", [pl.col("c"), pl.col("c").is_not_null()])
def test_when_then(then_scalar, otherwise_scalar, expr):
ldf = pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
"b": [10, 13, 11, 15, 16, 11, 10],
"c": [None, True, False, False, True, True, False],
}
)

then = pl.lit(10) if then_scalar else pl.col("a")
otherwise = pl.lit(-2) if otherwise_scalar else pl.col("b")
q = ldf.select(pl.when(expr).then(then).otherwise(otherwise))
assert_gpu_result_equal(q)

0 comments on commit e35da6b

Please sign in to comment.