diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py index 7e772589..eb37692b 100644 --- a/datafusion/tests/test_functions.py +++ b/datafusion/tests/test_functions.py @@ -238,6 +238,15 @@ def py_arr_replace(arr, from_, to, n=None): return new_arr + def py_flatten(arr): + result = [] + for elem in arr: + if isinstance(elem, list): + result.extend(py_flatten(elem)) + else: + result.append(elem) + return result + col = column("arr") test_items = [ [ @@ -432,6 +441,7 @@ def py_arr_replace(arr, from_, to, n=None): f.list_slice(col, literal(-1), literal(2)), lambda: [arr[-1:2] for arr in data], ], + [f.flatten(literal(data)), lambda: [py_flatten(data)]], ] for stmt, py_expr in test_items: diff --git a/src/functions.rs b/src/functions.rs index bb204c35..5a558de0 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -409,6 +409,7 @@ scalar_function!(array_replace_all, ArrayReplaceAll); scalar_function!(list_replace_all, ArrayReplaceAll); scalar_function!(array_slice, ArraySlice); scalar_function!(list_slice, ArraySlice); +scalar_function!(flatten, Flatten); aggregate_function!(approx_distinct, ApproxDistinct); aggregate_function!(approx_median, ApproxMedian); @@ -651,6 +652,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(list_replace_all))?; m.add_wrapped(wrap_pyfunction!(array_slice))?; m.add_wrapped(wrap_pyfunction!(list_slice))?; + m.add_wrapped(wrap_pyfunction!(flatten))?; Ok(()) }