diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py index d0514f89..7e772589 100644 --- a/datafusion/tests/test_functions.py +++ b/datafusion/tests/test_functions.py @@ -200,19 +200,62 @@ def test_math_functions(): def test_array_functions(): - data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]] + data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]] ctx = SessionContext() batch = pa.RecordBatch.from_arrays( [np.array(data, dtype=object)], names=["arr"] ) df = ctx.create_dataframe([[batch]]) + def py_indexof(arr, v): + try: + return arr.index(v) + 1 + except ValueError: + return np.nan + + def py_arr_remove(arr, v, n=None): + new_arr = arr[:] + found = 0 + while found != n: + try: + new_arr.remove(v) + found += 1 + except ValueError: + break + + return new_arr + + def py_arr_replace(arr, from_, to, n=None): + new_arr = arr[:] + found = 0 + while found != n: + try: + idx = new_arr.index(from_) + new_arr[idx] = to + found += 1 + except ValueError: + break + + return new_arr + col = column("arr") test_items = [ [ f.array_append(col, literal(99.0)), lambda: [np.append(arr, 99.0) for arr in data], ], + [ + f.array_push_back(col, literal(99.0)), + lambda: [np.append(arr, 99.0) for arr in data], + ], + [ + f.list_append(col, literal(99.0)), + lambda: [np.append(arr, 99.0) for arr in data], + ], + [ + f.list_push_back(col, literal(99.0)), + lambda: [np.append(arr, 99.0) for arr in data], + ], [ f.array_concat(col, col), lambda: [np.concatenate([arr, arr]) for arr in data], @@ -253,12 +296,174 @@ def test_array_functions(): f.list_length(col), lambda: [len(r) for r in data], ], + [ + f.array_has(col, literal(1.0)), + lambda: [1.0 in r for r in data], + ], + [ + f.array_has_all( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ], + [ + f.array_has_any( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ], + [ + f.array_position(col, literal(1.0)), + lambda: [py_indexof(r, 1.0) for r in data], + ], + [ + f.array_indexof(col, literal(1.0)), + lambda: [py_indexof(r, 1.0) for r in data], + ], + [ + f.list_position(col, literal(1.0)), + lambda: [py_indexof(r, 1.0) for r in data], + ], + [ + f.list_indexof(col, literal(1.0)), + lambda: [py_indexof(r, 1.0) for r in data], + ], + [ + f.array_positions(col, literal(1.0)), + lambda: [ + [i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data + ], + ], + [ + f.list_positions(col, literal(1.0)), + lambda: [ + [i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data + ], + ], + [ + f.array_ndims(col), + lambda: [np.array(r).ndim for r in data], + ], + [ + f.list_ndims(col), + lambda: [np.array(r).ndim for r in data], + ], + [ + f.array_prepend(literal(99.0), col), + lambda: [np.insert(arr, 0, 99.0) for arr in data], + ], + [ + f.array_push_front(literal(99.0), col), + lambda: [np.insert(arr, 0, 99.0) for arr in data], + ], + [ + f.list_prepend(literal(99.0), col), + lambda: [np.insert(arr, 0, 99.0) for arr in data], + ], + [ + f.list_push_front(literal(99.0), col), + lambda: [np.insert(arr, 0, 99.0) for arr in data], + ], + [ + f.array_pop_back(col), + lambda: [arr[:-1] for arr in data], + ], + [ + f.array_pop_front(col), + lambda: [arr[1:] for arr in data], + ], + [ + f.array_remove(col, literal(3.0)), + lambda: [py_arr_remove(arr, 3.0, 1) for arr in data], + ], + [ + f.list_remove(col, literal(3.0)), + lambda: [py_arr_remove(arr, 3.0, 1) for arr in data], + ], + [ + f.array_remove_n(col, literal(3.0), literal(2)), + lambda: [py_arr_remove(arr, 3.0, 2) for arr in data], + ], + [ + f.list_remove_n(col, literal(3.0), literal(2)), + lambda: [py_arr_remove(arr, 3.0, 2) for arr in data], + ], + [ + f.array_remove_all(col, literal(3.0)), + lambda: [py_arr_remove(arr, 3.0) for arr in data], + ], + [ + f.list_remove_all(col, literal(3.0)), + lambda: [py_arr_remove(arr, 3.0) for arr in data], + ], + [ + f.array_repeat(col, literal(2)), + lambda: [[arr] * 2 for arr in data], + ], + [ + f.array_replace(col, literal(3.0), literal(4.0)), + lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], + ], + [ + f.list_replace(col, literal(3.0), literal(4.0)), + lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], + ], + [ + f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)), + lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], + ], + [ + f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)), + lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data], + ], + [ + f.array_replace_all(col, literal(3.0), literal(4.0)), + lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data], + ], + [ + f.list_replace_all(col, literal(3.0), literal(4.0)), + lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data], + ], + [ + f.array_slice(col, literal(2), literal(4)), + lambda: [arr[1:4] for arr in data], + ], + [ + f.list_slice(col, literal(-1), literal(2)), + lambda: [arr[-1:2] for arr in data], + ], ] for stmt, py_expr in test_items: - query_result = df.select(stmt).collect()[0].column(0).tolist() + query_result = df.select(stmt).collect()[0].column(0) + for a, b in zip(query_result, py_expr()): + np.testing.assert_array_almost_equal( + np.array(a.as_py(), dtype=float), np.array(b, dtype=float) + ) + + obj_test_items = [ + [ + f.array_to_string(col, literal(",")), + lambda: [",".join([str(int(v)) for v in r]) for r in data], + ], + [ + f.array_join(col, literal(",")), + lambda: [",".join([str(int(v)) for v in r]) for r in data], + ], + [ + f.list_to_string(col, literal(",")), + lambda: [",".join([str(int(v)) for v in r]) for r in data], + ], + [ + f.list_join(col, literal(",")), + lambda: [",".join([str(int(v)) for v in r]) for r in data], + ], + ] + + for stmt, py_expr in obj_test_items: + query_result = np.array(df.select(stmt).collect()[0].column(0)) for a, b in zip(query_result, py_expr()): - np.testing.assert_array_almost_equal(a, b) + assert a == b def test_string_functions(df): diff --git a/src/functions.rs b/src/functions.rs index 3dc5322a..e3c485a3 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -359,6 +359,9 @@ scalar_function!(decode, Decode); // Array Functions scalar_function!(array_append, ArrayAppend); +scalar_function!(array_push_back, ArrayAppend); +scalar_function!(list_append, ArrayAppend); +scalar_function!(list_push_back, ArrayAppend); scalar_function!(array_concat, ArrayConcat); scalar_function!(array_cat, ArrayConcat); scalar_function!(array_dims, ArrayDims); @@ -369,6 +372,42 @@ scalar_function!(list_element, ArrayElement); scalar_function!(list_extract, ArrayElement); scalar_function!(array_length, ArrayLength); scalar_function!(list_length, ArrayLength); +scalar_function!(array_has, ArrayHas); +scalar_function!(array_has_all, ArrayHasAll); +scalar_function!(array_has_any, ArrayHasAny); +scalar_function!(array_position, ArrayPosition); +scalar_function!(array_indexof, ArrayPosition); +scalar_function!(list_position, ArrayPosition); +scalar_function!(list_indexof, ArrayPosition); +scalar_function!(array_positions, ArrayPositions); +scalar_function!(list_positions, ArrayPositions); +scalar_function!(array_to_string, ArrayToString); +scalar_function!(array_join, ArrayToString); +scalar_function!(list_to_string, ArrayToString); +scalar_function!(list_join, ArrayToString); +scalar_function!(array_ndims, ArrayNdims); +scalar_function!(list_ndims, ArrayNdims); +scalar_function!(array_prepend, ArrayPrepend); +scalar_function!(array_push_front, ArrayPrepend); +scalar_function!(list_prepend, ArrayPrepend); +scalar_function!(list_push_front, ArrayPrepend); +scalar_function!(array_pop_back, ArrayPopBack); +scalar_function!(array_pop_front, ArrayPopFront); +scalar_function!(array_remove, ArrayRemove); +scalar_function!(list_remove, ArrayRemove); +scalar_function!(array_remove_n, ArrayRemoveN); +scalar_function!(list_remove_n, ArrayRemoveN); +scalar_function!(array_remove_all, ArrayRemoveAll); +scalar_function!(list_remove_all, ArrayRemoveAll); +scalar_function!(array_repeat, ArrayRepeat); +scalar_function!(array_replace, ArrayReplace); +scalar_function!(list_replace, ArrayReplace); +scalar_function!(array_replace_n, ArrayReplaceN); +scalar_function!(list_replace_n, ArrayReplaceN); +scalar_function!(array_replace_all, ArrayReplaceAll); +scalar_function!(list_replace_all, ArrayReplaceAll); +scalar_function!(array_slice, ArraySlice); +scalar_function!(list_slice, ArraySlice); aggregate_function!(approx_distinct, ApproxDistinct); aggregate_function!(approx_median, ApproxMedian); @@ -562,6 +601,9 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { // Array Functions m.add_wrapped(wrap_pyfunction!(array_append))?; + m.add_wrapped(wrap_pyfunction!(array_push_back))?; + m.add_wrapped(wrap_pyfunction!(list_append))?; + m.add_wrapped(wrap_pyfunction!(list_push_back))?; m.add_wrapped(wrap_pyfunction!(array_concat))?; m.add_wrapped(wrap_pyfunction!(array_cat))?; m.add_wrapped(wrap_pyfunction!(array_dims))?; @@ -572,6 +614,42 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(list_extract))?; m.add_wrapped(wrap_pyfunction!(array_length))?; m.add_wrapped(wrap_pyfunction!(list_length))?; + m.add_wrapped(wrap_pyfunction!(array_has))?; + m.add_wrapped(wrap_pyfunction!(array_has_all))?; + m.add_wrapped(wrap_pyfunction!(array_has_any))?; + m.add_wrapped(wrap_pyfunction!(array_position))?; + m.add_wrapped(wrap_pyfunction!(array_indexof))?; + m.add_wrapped(wrap_pyfunction!(list_position))?; + m.add_wrapped(wrap_pyfunction!(list_indexof))?; + m.add_wrapped(wrap_pyfunction!(array_positions))?; + m.add_wrapped(wrap_pyfunction!(list_positions))?; + m.add_wrapped(wrap_pyfunction!(array_to_string))?; + m.add_wrapped(wrap_pyfunction!(array_join))?; + m.add_wrapped(wrap_pyfunction!(list_to_string))?; + m.add_wrapped(wrap_pyfunction!(list_join))?; + m.add_wrapped(wrap_pyfunction!(array_ndims))?; + m.add_wrapped(wrap_pyfunction!(list_ndims))?; + m.add_wrapped(wrap_pyfunction!(array_prepend))?; + m.add_wrapped(wrap_pyfunction!(array_push_front))?; + m.add_wrapped(wrap_pyfunction!(list_prepend))?; + m.add_wrapped(wrap_pyfunction!(list_push_front))?; + m.add_wrapped(wrap_pyfunction!(array_pop_back))?; + m.add_wrapped(wrap_pyfunction!(array_pop_front))?; + m.add_wrapped(wrap_pyfunction!(array_remove))?; + m.add_wrapped(wrap_pyfunction!(list_remove))?; + m.add_wrapped(wrap_pyfunction!(array_remove_n))?; + m.add_wrapped(wrap_pyfunction!(list_remove_n))?; + m.add_wrapped(wrap_pyfunction!(array_remove_all))?; + m.add_wrapped(wrap_pyfunction!(list_remove_all))?; + m.add_wrapped(wrap_pyfunction!(array_repeat))?; + m.add_wrapped(wrap_pyfunction!(array_replace))?; + m.add_wrapped(wrap_pyfunction!(list_replace))?; + m.add_wrapped(wrap_pyfunction!(array_replace_n))?; + m.add_wrapped(wrap_pyfunction!(list_replace_n))?; + m.add_wrapped(wrap_pyfunction!(array_replace_all))?; + m.add_wrapped(wrap_pyfunction!(list_replace_all))?; + m.add_wrapped(wrap_pyfunction!(array_slice))?; + m.add_wrapped(wrap_pyfunction!(list_slice))?; Ok(()) }