Skip to content

Commit

Permalink
Update pylibcudf testing utilities (rapidsai#15772)
Browse files Browse the repository at this point in the history
Cleans up some testing utilities for pylibcudf as suggested in rapidsai#15418 (comment).

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Bradley Dice (https://github.com/bdice)

URL: rapidsai#15772
  • Loading branch information
brandon-b-miller authored May 30, 2024
1 parent e95894f commit c268fc1
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 33 deletions.
42 changes: 29 additions & 13 deletions python/cudf/cudf/pylibcudf_tests/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from typing import Optional
from typing import Optional, Union

import pyarrow as pa
import pytest
Expand All @@ -24,27 +24,43 @@ def metadata_from_arrow_array(
return metadata


def assert_column_eq(plc_column: plc.Column, pa_array: pa.Array) -> None:
"""Verify that the pylibcudf array and PyArrow array are equal."""
def assert_column_eq(
lhs: Union[pa.Array, plc.Column], rhs: Union[pa.Array, plc.Column]
) -> None:
"""Verify that a pylibcudf array and PyArrow array are equal."""
# Nested types require children metadata to be passed to the conversion function.
plc_pa = plc.interop.to_arrow(
plc_column, metadata=metadata_from_arrow_array(pa_array)
)
if isinstance(lhs, (pa.Array, pa.ChunkedArray)) and isinstance(
rhs, plc.Column
):
rhs = plc.interop.to_arrow(
rhs, metadata=metadata_from_arrow_array(lhs)
)
elif isinstance(lhs, plc.Column) and isinstance(
rhs, (pa.Array, pa.ChunkedArray)
):
lhs = plc.interop.to_arrow(
lhs, metadata=metadata_from_arrow_array(rhs)
)
else:
raise ValueError(
"One of the inputs must be a Column and the other an Array"
)

if isinstance(lhs, pa.ChunkedArray):
lhs = lhs.combine_chunks()
if isinstance(rhs, pa.ChunkedArray):
rhs = rhs.combine_chunks()

if isinstance(plc_pa, pa.ChunkedArray):
plc_pa = plc_pa.combine_chunks()
if isinstance(pa_array, pa.ChunkedArray):
pa_array = pa_array.combine_chunks()
assert plc_pa.equals(pa_array)
assert lhs.equals(rhs)


def assert_table_eq(plc_table: plc.Table, pa_table: pa.Table) -> None:
"""Verify that the pylibcudf array and PyArrow array are equal."""
"""Verify that a pylibcudf table and PyArrow table are equal."""
plc_shape = (plc_table.num_rows(), plc_table.num_columns())
assert plc_shape == pa_table.shape

for plc_col, pa_col in zip(plc_table.columns(), pa_table.columns):
assert_column_eq(plc_col, pa_col)
assert_column_eq(pa_col, plc_col)


def cudf_raises(expected_exception: BaseException, *args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def test_from_cuda_array_interface(valid_column):
)
expect = valid_column

assert_column_eq(col, expect)
assert_column_eq(expect, col)
14 changes: 7 additions & 7 deletions python/cudf/cudf/pylibcudf_tests/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def test_copy_range_in_place(
),
pa_target_column,
)
assert_column_eq(mutable_target_column, expected)
assert_column_eq(expected, mutable_target_column)


def test_copy_range_in_place_out_of_bounds(
Expand Down Expand Up @@ -480,7 +480,7 @@ def test_copy_range(
),
pa_target_column,
)
assert_column_eq(result, expected)
assert_column_eq(expected, result)
else:
with pytest.raises(TypeError):
plc.copying.copy_range(
Expand Down Expand Up @@ -528,7 +528,7 @@ def test_shift(
expected = pa.concat_arrays(
[pa.array([pa_source_scalar] * shift), pa_target_column[:-shift]]
)
assert_column_eq(result, expected)
assert_column_eq(expected, result)
else:
with pytest.raises(TypeError):
plc.copying.shift(target_column, shift, source_scalar)
Expand All @@ -550,7 +550,7 @@ def test_slice_column(target_column, pa_target_column):
lower_bounds = bounds[::2]
result = plc.copying.slice(target_column, bounds)
for lb, ub, slice_ in zip(lower_bounds, upper_bounds, result):
assert_column_eq(slice_, pa_target_column[lb:ub])
assert_column_eq(pa_target_column[lb:ub], slice_)


def test_slice_column_wrong_length(target_column):
Expand Down Expand Up @@ -582,7 +582,7 @@ def test_split_column(target_column, pa_target_column):
lower_bounds = [0] + upper_bounds[:-1]
result = plc.copying.split(target_column, upper_bounds)
for lb, ub, split in zip(lower_bounds, upper_bounds, result):
assert_column_eq(split, pa_target_column[lb:ub])
assert_column_eq(pa_target_column[lb:ub], split)


def test_split_column_decreasing(target_column):
Expand Down Expand Up @@ -622,7 +622,7 @@ def test_copy_if_else_column_column(
pa_target_column,
pa_other_column,
)
assert_column_eq(result, expected)
assert_column_eq(expected, result)


def test_copy_if_else_wrong_type(target_column, mask):
Expand Down Expand Up @@ -699,7 +699,7 @@ def test_copy_if_else_column_scalar(
pa_mask,
*pa_args,
)
assert_column_eq(result, expected)
assert_column_eq(expected, result)


def test_boolean_mask_scatter_from_table(
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/pylibcudf_tests/test_string_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ def test_to_upper(string_col):
plc_col = plc.interop.from_arrow(string_col)
got = plc.strings.case.to_upper(plc_col)
expected = pa.compute.utf8_upper(string_col)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_to_lower(string_col):
plc_col = plc.interop.from_arrow(string_col)
got = plc.strings.case.to_lower(plc_col)
expected = pa.compute.utf8_lower(string_col)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_swapcase(string_col):
plc_col = plc.interop.from_arrow(string_col)
got = plc.strings.case.swapcase(plc_col)
expected = pa.compute.utf8_swapcase(string_col)
assert_column_eq(got, expected)
assert_column_eq(expected, got)
18 changes: 9 additions & 9 deletions python/cudf/cudf/pylibcudf_tests/test_string_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_find(pa_data_col, plc_data_col, pa_target_scalar, plc_target_scalar):
type=pa.int32(),
)

assert_column_eq(got, expected)
assert_column_eq(expected, got)


def colwise_apply(pa_data_col, pa_target_col, operator):
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_find_column(pa_data_col, pa_target_col, plc_data_col, plc_target_col):
)

got = plc.strings.find.find(plc_data_col, plc_target_col, 0)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_rfind(pa_data_col, plc_data_col, pa_target_scalar, plc_target_scalar):
Expand All @@ -192,7 +192,7 @@ def test_rfind(pa_data_col, plc_data_col, pa_target_scalar, plc_target_scalar):
type=pa.int32(),
)

assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_contains(
Expand All @@ -211,7 +211,7 @@ def test_contains(
type=pa.bool_(),
)

assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_contains_column(
Expand All @@ -221,7 +221,7 @@ def test_contains_column(
pa_data_col, pa_target_col, lambda st, target: target in st
)
got = plc.strings.find.contains(plc_data_col, plc_target_col)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_starts_with(
Expand All @@ -230,7 +230,7 @@ def test_starts_with(
py_target = pa_target_scalar.as_py()
got = plc.strings.find.starts_with(plc_data_col, plc_target_scalar)
expected = pa.compute.starts_with(pa_data_col, py_target)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_starts_with_column(
Expand All @@ -240,7 +240,7 @@ def test_starts_with_column(
pa_data_col, pa_target_col, lambda st, target: st.startswith(target)
)
got = plc.strings.find.starts_with(plc_data_col, plc_target_col)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_ends_with(
Expand All @@ -249,7 +249,7 @@ def test_ends_with(
py_target = pa_target_scalar.as_py()
got = plc.strings.find.ends_with(plc_data_col, plc_target_scalar)
expected = pa.compute.ends_with(pa_data_col, py_target)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_ends_with_column(
Expand All @@ -259,4 +259,4 @@ def test_ends_with_column(
pa_data_col, pa_target_col, lambda st, target: st.endswith(target)
)
got = plc.strings.find.ends_with(plc_data_col, plc_target_col)
assert_column_eq(got, expected)
assert_column_eq(expected, got)

0 comments on commit c268fc1

Please sign in to comment.