diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx index 0ad09dba717..ceae1b148aa 100644 --- a/python/cudf/cudf/_lib/lists.pyx +++ b/python/cudf/cudf/_lib/lists.pyx @@ -8,11 +8,9 @@ from libcpp.utility cimport move from cudf._lib.column cimport Column from cudf._lib.pylibcudf.libcudf.column.column cimport column -from cudf._lib.pylibcudf.libcudf.column.column_view cimport column_view from cudf._lib.pylibcudf.libcudf.lists.count_elements cimport ( count_elements as cpp_count_elements, ) -from cudf._lib.pylibcudf.libcudf.lists.extract cimport extract_list_element from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport ( lists_column_view, ) @@ -116,37 +114,23 @@ def sort_lists(Column col, bool ascending, str na_position): @acquire_spill_lock() def extract_element_scalar(Column col, size_type index): - # shared_ptr required because lists_column_view has no default - # ctor - cdef shared_ptr[lists_column_view] list_view = ( - make_shared[lists_column_view](col.view()) + return Column.from_pylibcudf( + pylibcudf.lists.extract_list_element( + col.to_pylibcudf(mode="read"), + index, + ) ) - cdef unique_ptr[column] c_result - - with nogil: - c_result = move(extract_list_element(list_view.get()[0], index)) - - result = Column.from_unique_ptr(move(c_result)) - return result - @acquire_spill_lock() def extract_element_column(Column col, Column index): - cdef shared_ptr[lists_column_view] list_view = ( - make_shared[lists_column_view](col.view()) + return Column.from_pylibcudf( + pylibcudf.lists.extract_list_element( + col.to_pylibcudf(mode="read"), + index.to_pylibcudf(mode="read"), + ) ) - cdef column_view index_view = index.view() - - cdef unique_ptr[column] c_result - - with nogil: - c_result = move(extract_list_element(list_view.get()[0], index_view)) - - result = Column.from_unique_ptr(move(c_result)) - return result - @acquire_spill_lock() def contains_scalar(Column col, py_search_key): diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd index caa12f41914..53609ba8830 100644 --- a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd @@ -11,10 +11,10 @@ from cudf._lib.pylibcudf.libcudf.types cimport size_type cdef extern from "cudf/lists/extract.hpp" namespace "cudf::lists" nogil: cdef unique_ptr[column] extract_list_element( - const lists_column_view, + const lists_column_view&, size_type ) except + cdef unique_ptr[column] extract_list_element( - const lists_column_view, - column_view + const lists_column_view&, + const column_view& ) except + diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pxd b/python/cudf/cudf/_lib/pylibcudf/lists.pxd index c9c43751a43..38a479e4791 100644 --- a/python/cudf/cudf/_lib/pylibcudf/lists.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/lists.pxd @@ -12,6 +12,10 @@ ctypedef fused ColumnOrScalar: Column Scalar +ctypedef fused ColumnOrSizeType: + Column + size_type + cpdef Table explode_outer(Table, size_type explode_column_idx) cpdef Column concatenate_rows(Table) @@ -27,3 +31,5 @@ cpdef Column index_of(Column, ColumnOrScalar, bool) cpdef Column reverse(Column) cpdef Column segmented_gather(Column, Column) + +cpdef Column extract_list_element(Column, ColumnOrSizeType) diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pyx b/python/cudf/cudf/_lib/pylibcudf/lists.pyx index 9c56f1139c6..19c961aa014 100644 --- a/python/cudf/cudf/_lib/pylibcudf/lists.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/lists.pyx @@ -17,9 +17,12 @@ from cudf._lib.pylibcudf.libcudf.lists.combine cimport ( concatenate_null_policy, concatenate_rows as cpp_concatenate_rows, ) +from cudf._lib.pylibcudf.libcudf.lists.extract cimport ( + extract_list_element as cpp_extract_list_element, +) from cudf._lib.pylibcudf.libcudf.table.table cimport table from cudf._lib.pylibcudf.libcudf.types cimport size_type -from cudf._lib.pylibcudf.lists cimport ColumnOrScalar +from cudf._lib.pylibcudf.lists cimport ColumnOrScalar, ColumnOrSizeType from .column cimport Column, ListColumnView from .scalar cimport Scalar @@ -264,3 +267,29 @@ cpdef Column segmented_gather(Column input, Column gather_map_list): list_view2.view(), )) return Column.from_libcudf(move(c_result)) + + +cpdef Column extract_list_element(Column input, ColumnOrSizeType index): + """Create a column of extracted list elements. + + Parameters + ---------- + input : Column + The input column. + index : Union[Column, size_type] + The selection index or indices. + + Returns + ------- + Column + A new Column with elements extracted. + """ + cdef unique_ptr[column] c_result + cdef ListColumnView list_view = input.list_view() + + with nogil: + c_result = move(cpp_extract_list_element( + list_view.view(), + index.view() if ColumnOrSizeType is Column else index, + )) + return Column.from_libcudf(move(c_result)) diff --git a/python/cudf/cudf/pylibcudf_tests/test_lists.py b/python/cudf/cudf/pylibcudf_tests/test_lists.py index 0d95579acb3..07ecaed5012 100644 --- a/python/cudf/cudf/pylibcudf_tests/test_lists.py +++ b/python/cudf/cudf/pylibcudf_tests/test_lists.py @@ -160,3 +160,24 @@ def test_segmented_gather(test_data): expect = pa.array([[8, 9], [14], [0], [0, 0]]) assert_column_eq(expect, res) + + +def test_extract_list_element_scalar(test_data): + arr = pa.array(test_data[0][0]) + plc_column = plc.interop.from_arrow(arr) + + res = plc.lists.extract_list_element(plc_column, 0) + expect = pa.compute.list_element(test_data[0][0], 0) + + assert_column_eq(expect, res) + + +def test_extract_list_element_column(test_data): + arr = pa.array(test_data[0][0]) + plc_column = plc.interop.from_arrow(arr) + indices = plc.interop.from_arrow(pa.array([0, 1, -4, -1])) + + res = plc.lists.extract_list_element(plc_column, indices) + expect = pa.array([0, None, None, 7]) + + assert_column_eq(expect, res)