Skip to content

Commit

Permalink
Feature/expose when function (#836)
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Aug 30, 2024
1 parent 69ed7fe commit 003eea8
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@
"var",
"var_pop",
"var_samp",
"when",
"window",
]

Expand Down Expand Up @@ -364,6 +365,16 @@ def case(expr: Expr) -> CaseBuilder:
return CaseBuilder(f.case(expr.expr))


def when(when: Expr, then: Expr) -> CaseBuilder:
"""Create a case expression that has no base expression.
Create a :py:class:`~datafusion.expr.CaseBuilder` to match cases for the
expression ``expr``. See :py:class:`~datafusion.expr.CaseBuilder` for
detailed usage.
"""
return CaseBuilder(f.when(when.expr, then.expr))


def window(
name: str,
args: list[Expr],
Expand Down
19 changes: 19 additions & 0 deletions python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,25 @@ def test_case(df):
assert result.column(2) == pa.array(["Hola", "Mundo", None])


def test_when_with_no_base(df):
df.show()
df = df.select(
column("b"),
f.when(column("b") > literal(5), literal("too big"))
.when(column("b") < literal(5), literal("too small"))
.otherwise(literal("just right"))
.alias("goldilocks"),
f.when(column("a") == literal("Hello"), column("a")).end().alias("greeting"),
)
df.show()

result = df.collect()
result = result[0]
assert result.column(0) == pa.array([4, 5, 6])
assert result.column(1) == pa.array(["too small", "just right", "too big"])
assert result.column(2) == pa.array(["Hello", None, None])


def test_regr_funcs_sql(df):
# test case base on
# https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
Expand Down
9 changes: 9 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,14 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
})
}

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
#[pyfunction]
fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
Ok(PyCaseBuilder {
case_builder: datafusion_expr::when(when.expr, then.expr),
})
}

/// Helper function to find the appropriate window function.
///
/// Search procedure:
Expand Down Expand Up @@ -910,6 +918,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(char_length))?;
m.add_wrapped(wrap_pyfunction!(coalesce))?;
m.add_wrapped(wrap_pyfunction!(case))?;
m.add_wrapped(wrap_pyfunction!(when))?;
m.add_wrapped(wrap_pyfunction!(col))?;
m.add_wrapped(wrap_pyfunction!(concat_ws))?;
m.add_wrapped(wrap_pyfunction!(concat))?;
Expand Down

0 comments on commit 003eea8

Please sign in to comment.