diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx
index 0ad09dba717..ceae1b148aa 100644
--- a/python/cudf/cudf/_lib/lists.pyx
+++ b/python/cudf/cudf/_lib/lists.pyx
@@ -8,11 +8,9 @@ from libcpp.utility cimport move
 
 from cudf._lib.column cimport Column
 from cudf._lib.pylibcudf.libcudf.column.column cimport column
-from cudf._lib.pylibcudf.libcudf.column.column_view cimport column_view
 from cudf._lib.pylibcudf.libcudf.lists.count_elements cimport (
     count_elements as cpp_count_elements,
 )
-from cudf._lib.pylibcudf.libcudf.lists.extract cimport extract_list_element
 from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport (
     lists_column_view,
 )
@@ -116,37 +114,23 @@ def sort_lists(Column col, bool ascending, str na_position):
 
 @acquire_spill_lock()
 def extract_element_scalar(Column col, size_type index):
-    # shared_ptr required because lists_column_view has no default
-    # ctor
-    cdef shared_ptr[lists_column_view] list_view = (
-        make_shared[lists_column_view](col.view())
+    return Column.from_pylibcudf(
+        pylibcudf.lists.extract_list_element(
+            col.to_pylibcudf(mode="read"),
+            index,
+        )
     )
 
-    cdef unique_ptr[column] c_result
-
-    with nogil:
-        c_result = move(extract_list_element(list_view.get()[0], index))
-
-    result = Column.from_unique_ptr(move(c_result))
-    return result
-
 
 @acquire_spill_lock()
 def extract_element_column(Column col, Column index):
-    cdef shared_ptr[lists_column_view] list_view = (
-        make_shared[lists_column_view](col.view())
+    return Column.from_pylibcudf(
+        pylibcudf.lists.extract_list_element(
+            col.to_pylibcudf(mode="read"),
+            index.to_pylibcudf(mode="read"),
+        )
     )
 
-    cdef column_view index_view = index.view()
-
-    cdef unique_ptr[column] c_result
-
-    with nogil:
-        c_result = move(extract_list_element(list_view.get()[0], index_view))
-
-    result = Column.from_unique_ptr(move(c_result))
-    return result
-
 
 @acquire_spill_lock()
 def contains_scalar(Column col, py_search_key):
diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd
index caa12f41914..53609ba8830 100644
--- a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/extract.pxd
@@ -11,10 +11,10 @@ from cudf._lib.pylibcudf.libcudf.types cimport size_type
 
 cdef extern from "cudf/lists/extract.hpp" namespace "cudf::lists" nogil:
     cdef unique_ptr[column] extract_list_element(
-        const lists_column_view,
+        const lists_column_view&,
         size_type
     ) except +
     cdef unique_ptr[column] extract_list_element(
-        const lists_column_view,
-        column_view
+        const lists_column_view&,
+        const column_view&
     ) except +
diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pxd b/python/cudf/cudf/_lib/pylibcudf/lists.pxd
index c9c43751a43..38a479e4791 100644
--- a/python/cudf/cudf/_lib/pylibcudf/lists.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/lists.pxd
@@ -12,6 +12,10 @@ ctypedef fused ColumnOrScalar:
     Column
     Scalar
 
+ctypedef fused ColumnOrSizeType:
+    Column
+    size_type
+
 cpdef Table explode_outer(Table, size_type explode_column_idx)
 
 cpdef Column concatenate_rows(Table)
@@ -27,3 +31,5 @@ cpdef Column index_of(Column, ColumnOrScalar, bool)
 cpdef Column reverse(Column)
 
 cpdef Column segmented_gather(Column, Column)
+
+cpdef Column extract_list_element(Column, ColumnOrSizeType)
diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pyx b/python/cudf/cudf/_lib/pylibcudf/lists.pyx
index 9c56f1139c6..19c961aa014 100644
--- a/python/cudf/cudf/_lib/pylibcudf/lists.pyx
+++ b/python/cudf/cudf/_lib/pylibcudf/lists.pyx
@@ -17,9 +17,12 @@ from cudf._lib.pylibcudf.libcudf.lists.combine cimport (
     concatenate_null_policy,
     concatenate_rows as cpp_concatenate_rows,
 )
+from cudf._lib.pylibcudf.libcudf.lists.extract cimport (
+    extract_list_element as cpp_extract_list_element,
+)
 from cudf._lib.pylibcudf.libcudf.table.table cimport table
 from cudf._lib.pylibcudf.libcudf.types cimport size_type
-from cudf._lib.pylibcudf.lists cimport ColumnOrScalar
+from cudf._lib.pylibcudf.lists cimport ColumnOrScalar, ColumnOrSizeType
 
 from .column cimport Column, ListColumnView
 from .scalar cimport Scalar
@@ -264,3 +267,29 @@ cpdef Column segmented_gather(Column input, Column gather_map_list):
             list_view2.view(),
         ))
     return Column.from_libcudf(move(c_result))
+
+
+cpdef Column extract_list_element(Column input, ColumnOrSizeType index):
+    """Create a column of extracted list elements.
+
+    Parameters
+    ----------
+    input : Column
+        The input column.
+    index : Union[Column, size_type]
+        The selection index or indices.
+
+    Returns
+    -------
+    Column
+        A new Column with elements extracted.
+    """
+    cdef unique_ptr[column] c_result
+    cdef ListColumnView list_view = input.list_view()
+
+    with nogil:
+        c_result = move(cpp_extract_list_element(
+            list_view.view(),
+            index.view() if ColumnOrSizeType is Column else index,
+        ))
+    return Column.from_libcudf(move(c_result))
diff --git a/python/cudf/cudf/pylibcudf_tests/test_lists.py b/python/cudf/cudf/pylibcudf_tests/test_lists.py
index 0d95579acb3..07ecaed5012 100644
--- a/python/cudf/cudf/pylibcudf_tests/test_lists.py
+++ b/python/cudf/cudf/pylibcudf_tests/test_lists.py
@@ -160,3 +160,24 @@ def test_segmented_gather(test_data):
     expect = pa.array([[8, 9], [14], [0], [0, 0]])
 
     assert_column_eq(expect, res)
+
+
+def test_extract_list_element_scalar(test_data):
+    arr = pa.array(test_data[0][0])
+    plc_column = plc.interop.from_arrow(arr)
+
+    res = plc.lists.extract_list_element(plc_column, 0)
+    expect = pa.compute.list_element(test_data[0][0], 0)
+
+    assert_column_eq(expect, res)
+
+
+def test_extract_list_element_column(test_data):
+    arr = pa.array(test_data[0][0])
+    plc_column = plc.interop.from_arrow(arr)
+    indices = plc.interop.from_arrow(pa.array([0, 1, -4, -1]))
+
+    res = plc.lists.extract_list_element(plc_column, indices)
+    expect = pa.array([0, None, None, 7])
+
+    assert_column_eq(expect, res)