diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py index fc775aeb..fa9a495d 100644 --- a/datafusion/tests/test_functions.py +++ b/datafusion/tests/test_functions.py @@ -220,6 +220,15 @@ def py_arr_replace(arr, from_, to, n=None): return new_arr + def py_arr_resize(arr, size, value): + arr = np.asarray(arr) + return np.pad( + arr, + [(0, size - arr.shape[0])], + "constant", + constant_values=value, + ) + def py_flatten(arr): result = [] for elem in arr: @@ -447,6 +456,14 @@ def py_flatten(arr): f.list_except(col, literal([3.0])), lambda: [np.setdiff1d(arr, [3.0]) for arr in data], ], + [ + f.array_resize(col, literal(10), literal(0.0)), + lambda: [py_arr_resize(arr, 10, 0.0) for arr in data], + ], + [ + f.list_resize(col, literal(10), literal(0.0)), + lambda: [py_arr_resize(arr, 10, 0.0) for arr in data], + ], [f.flatten(literal(data)), lambda: [py_flatten(data)]], [ f.range(literal(1), literal(5), literal(2)), diff --git a/src/functions.rs b/src/functions.rs index ed1b5be0..666e1ec3 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -453,6 +453,8 @@ scalar_function!(array_union, ArrayUnion); scalar_function!(list_union, ArrayUnion); scalar_function!(array_except, ArrayExcept); scalar_function!(list_except, ArrayExcept); +scalar_function!(array_resize, ArrayResize); +scalar_function!(list_resize, ArrayResize); scalar_function!(flatten, Flatten); aggregate_function!(approx_distinct, ApproxDistinct); @@ -679,6 +681,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(list_union))?; m.add_wrapped(wrap_pyfunction!(array_except))?; m.add_wrapped(wrap_pyfunction!(list_except))?; + m.add_wrapped(wrap_pyfunction!(array_resize))?; + m.add_wrapped(wrap_pyfunction!(list_resize))?; m.add_wrapped(wrap_pyfunction!(array_join))?; m.add_wrapped(wrap_pyfunction!(list_to_string))?; m.add_wrapped(wrap_pyfunction!(list_join))?;