Skip to content

Commit

Permalink
Make first_value and last_value identical in the interface
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Aug 31, 2024
1 parent 0fc0895 commit 616a748
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 68 deletions.
30 changes: 24 additions & 6 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,29 +1699,47 @@ def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
def first_value(
arg: Expr,
distinct: bool = False,
filter: bool = None,
order_by: Expr | None = None,
null_treatment: common.NullTreatment | None = None,
filter: Optional[bool] = None,
order_by: Optional[list[Expr]] = None,
null_treatment: Optional[common.NullTreatment] = None,
) -> Expr:
"""Returns the first value in a group of values."""
order_by_cols = [e.expr for e in order_by] if order_by is not None else None

return Expr(
f.first_value(
arg.expr,
distinct=distinct,
filter=filter,
order_by=order_by,
order_by=order_by_cols,
null_treatment=null_treatment,
)
)


def last_value(arg: Expr) -> Expr:
def last_value(
arg: Expr,
distinct: bool = False,
filter: Optional[bool] = None,
order_by: Optional[list[Expr]] = None,
null_treatment: Optional[common.NullTreatment] = None,
) -> Expr:
"""Returns the last value in a group of values.
To set parameters on this expression, use ``.order_by()``, ``.distinct()``,
``.filter()``, or ``.null_treatment()``.
"""
return Expr(f.last_value(arg.expr))
order_by_cols = [e.expr for e in order_by] if order_by is not None else None

return Expr(
f.last_value(
arg.expr,
distinct=distinct,
filter=filter,
order_by=order_by_cols,
null_treatment=null_treatment,
)
)


def bit_and(arg: Expr, distinct: bool = False) -> Expr:
Expand Down
152 changes: 99 additions & 53 deletions python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,45 +567,86 @@ def test_array_function_obj_tests(stmt, py_expr):
assert a == b


@pytest.mark.parametrize("function, expected_result", [
(f.ascii(column("a")), pa.array([72, 87, 33], type=pa.int32())), # H = 72; W = 87; ! = 33
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
(f.concat_ws("-", column("a"), literal("test")), pa.array(["Hello-test", "World-test", "!-test"])),
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
(f.md5(column("a")), pa.array([
"8b1a9953c4611296a827abf8c47804d7",
"f5a7924e621e84c9280a9a27e1bcb7f6",
"9033e0e305f247c0c3c80d0c7848c8b3",
])),
(f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.repeat(column("a"), literal(2)), pa.array(["HelloHello", "WorldWorld", "!!"])),
(f.replace(column("a"), literal("l"), literal("?")), pa.array(["He??o", "Wor?d", "!"])),
(f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
(f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
(f.rpad(column("a"), literal(8)), pa.array(["Hello ", "World ", "! "])),
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
(f.split_part(column("a"), literal("l"), literal(1)), pa.array(["He", "Wor", "!"])),
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
(f.translate(column("a"), literal("or"), literal("ld")), pa.array(["Helll", "Wldld", "!"])),
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
(f.overlay(column("a"), literal("--"), literal(2)), pa.array(["H--lo", "W--ld", "--"])),
(f.regexp_like(column("a"), literal("(ell|orl)")), pa.array([True, True, False])),
(f.regexp_match(column("a"), literal("(ell|orl)")), pa.array([["ell"], ["orl"], None])),
(f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")), pa.array(["H-o", "W-d", "!"])),
])
@pytest.mark.parametrize(
"function, expected_result",
[
(
f.ascii(column("a")),
pa.array([72, 87, 33], type=pa.int32()),
), # H = 72; W = 87; ! = 33
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
(
f.concat_ws("-", column("a"), literal("test")),
pa.array(["Hello-test", "World-test", "!-test"]),
),
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
(
f.md5(column("a")),
pa.array(
[
"8b1a9953c4611296a827abf8c47804d7",
"f5a7924e621e84c9280a9a27e1bcb7f6",
"9033e0e305f247c0c3c80d0c7848c8b3",
]
),
),
(f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(
f.repeat(column("a"), literal(2)),
pa.array(["HelloHello", "WorldWorld", "!!"]),
),
(
f.replace(column("a"), literal("l"), literal("?")),
pa.array(["He??o", "Wor?d", "!"]),
),
(f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
(f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
(
f.rpad(column("a"), literal(8)),
pa.array(["Hello ", "World ", "! "]),
),
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
(
f.split_part(column("a"), literal("l"), literal(1)),
pa.array(["He", "Wor", "!"]),
),
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
(
f.translate(column("a"), literal("or"), literal("ld")),
pa.array(["Helll", "Wldld", "!"]),
),
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
(
f.overlay(column("a"), literal("--"), literal(2)),
pa.array(["H--lo", "W--ld", "--"]),
),
(
f.regexp_like(column("a"), literal("(ell|orl)")),
pa.array([True, True, False]),
),
(
f.regexp_match(column("a"), literal("(ell|orl)")),
pa.array([["ell"], ["orl"], None]),
),
(
f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
pa.array(["H-o", "W-d", "!"]),
),
],
)
def test_string_functions(df, function, expected_result):
df = df.select(function)
result = df.collect()
Expand Down Expand Up @@ -849,27 +890,30 @@ def test_regr_funcs_sql_2():
assert result_sql[0].column(8) == pa.array([4], type=pa.float64())


@pytest.mark.parametrize("func, expected", [
pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"),
pytest.param(f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"),
pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"),
pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"),
pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"),
pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"),
pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"),
pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"),
pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy")
])
@pytest.mark.parametrize(
"func, expected",
[
pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"),
pytest.param(
f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"
),
pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"),
pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"),
pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"),
pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"),
pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"),
pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"),
pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy"),
],
)
def test_regr_funcs_df(func, expected):

# test case based on `regr_*() basic tests
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1


ctx = SessionContext()

# Create a DataFrame
data = {'column1': [1, 2, 3], 'column2': [2, 4, 6]}
data = {"column1": [1, 2, 3], "column2": [2, 4, 6]}
df = ctx.from_pydict(data, name="test_table")

# Perform the regression function using DataFrame API
Expand Down Expand Up @@ -900,6 +944,8 @@ def test_first_last_value(df):
assert result.column(3) == pa.array(["!"])
assert result.column(4) == pa.array([6])
assert result.column(5) == pa.array([datetime(2020, 7, 2)])
df.show()
assert False


def test_binary_string_functions(df):
Expand Down
37 changes: 28 additions & 9 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,18 +319,15 @@ pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyEx
}
}

#[pyfunction]
pub fn first_value(
expr: PyExpr,
fn add_builder_fns_to_aggregate(
agg_fn: Expr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
// If we initialize the UDAF with order_by directly, then it gets over-written by the builder
let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);

// luckily, I can guarantee initializing a builder with an `order_by` default of empty vec
// Since ExprFuncBuilder::new() is private, we can guarantee initializing
// a builder with an `order_by` default of empty vec
let order_by = order_by
.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>())
.unwrap_or_default();
Expand All @@ -351,8 +348,30 @@ pub fn first_value(
}

#[pyfunction]
pub fn last_value(expr: PyExpr) -> PyExpr {
functions_aggregate::expr_fn::last_value(vec![expr.expr]).into()
pub fn first_value(
expr: PyExpr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
// If we initialize the UDAF with order_by directly, then it gets over-written by the builder
let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);

add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}

#[pyfunction]
pub fn last_value(
expr: PyExpr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
let agg_fn = functions_aggregate::expr_fn::last_value(vec![expr.expr]);

add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}

#[pyfunction]
Expand Down

0 comments on commit 616a748

Please sign in to comment.