Skip to content

Commit

Permalink
Add dedicated handling for cudf.pandas wrapped Numpy arrays (#5861)
Browse files Browse the repository at this point in the history
Without the dedicated if branch the wrapped arrays end up being treated like a cupy array because they define `__cuda_array_interface__`.

Is this the right place to change this? I don't know anything about the input handling cuml does, but it seems like adding a special case here leads to the desired result.

Is returning a real numpy array what we want to do? Returning a cudf.pands wrapped numpy array feels like the even better thing to do? But I couldn't work out how to do that :-/

A downside of this approach is that we import a private part of cudf. However, I don't know if there is a way to avoid this if we want to use `isinstance`. We could use something like `hasattr(X, "_fsproxy_fast_type")` (or some other special method `cudf.pandas` defines).

A pro for using the special method check would be that it has no side effects. Importing `cudf.pandas._wrappers` has the side effect of turning on pandas compatibility mode. This leads to a test failure in `LabelEncoder`. Thinking about it, I think as a user of cuml I'd be surprised that it messes with the cudf settings for me. So I think I will switch to using `hasattr`.

Does someone know if there is an officially supported way to test "is this a cudf.pandas wrapped object?"? I had a quick look at the cudf docs but they are sparse when it comes to developer topics like this. An official way to check this might be even better than relying on special methods existing.

Closes #5784

* [x] add non regression test
* [ ] add second run of `pytest` with `pytest -p cudf.pandas ...` to execute regression test
  * As a first pass I think it would be neat to be able to run the whole test suite with cudf.pandas enabled
  * this is a bit tricky because other tests might start failing when executed with cudf.pandas enabled

Authors:
  - Tim Head (https://github.com/betatim)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #5861
  • Loading branch information
betatim authored May 2, 2024
1 parent 48d9f3b commit 5754ec4
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
6 changes: 6 additions & 0 deletions python/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def get_supported_input_type(X):
if isinstance(X, CudfIndex):
return CudfIndex

# A cudf.pandas wrapped Numpy array defines `__cuda_array_interface__`
# which means without this we'd always return a cupy array. We don't want
# to match wrapped cupy arrays, they get dealt with later
if getattr(X, "_fsproxy_slow_type", None) is np.ndarray:
return np.ndarray

try:
if numba_cuda.devicearray.is_cuda_ndarray(X):
return numba_cuda.devicearray.DeviceNDArrayBase
Expand Down
17 changes: 16 additions & 1 deletion python/cuml/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,6 +30,7 @@
import os
import subprocess
import pandas as pd
import cudf.pandas

from cuml.internals.safe_imports import cpu_only_import

Expand Down Expand Up @@ -169,6 +170,10 @@ def pytest_collection_modifyitems(config, items):


def pytest_configure(config):
config.addinivalue_line(
"markers",
"cudf_pandas: mark test as requiring the cudf.pandas wrapper",
)
cp.cuda.set_allocator(None)
# max_gpu_memory: Capacity of the GPU memory in GB
pytest.max_gpu_memory = get_gpu_memory()
Expand All @@ -186,6 +191,16 @@ def pytest_configure(config):
hypothesis.settings.load_profile("unit")


def pytest_pyfunc_call(pyfuncitem):
"""Skip tests that require the cudf.pandas accelerator
Tests marked with `@pytest.mark.cudf_pandas` will only be run if the
cudf.pandas accelerator is enabled via the `cudf.pandas` plugin.
"""
if "cudf_pandas" in pyfuncitem.keywords and not cudf.pandas.LOADED:
pytest.skip("Test requires cudf.pandas accelerator")


@pytest.fixture(scope="module")
def nlp_20news():
try:
Expand Down
20 changes: 19 additions & 1 deletion python/cuml/tests/test_input_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,9 @@
# limitations under the License.
#

import numpy as np
from pandas import Series as pdSeries
from cuml.manifold import umap
from cuml.internals.safe_imports import cpu_only_import_from
from cuml.internals.safe_imports import gpu_only_import_from
from cuml.internals.input_utils import convert_dtype
Expand All @@ -24,6 +26,7 @@
from cuml.common import input_to_cuml_array, CumlArray
from cuml.internals.safe_imports import cpu_only_import
import pytest
import pandas as pd

from cuml.internals.safe_imports import gpu_only_import

Expand Down Expand Up @@ -442,3 +445,18 @@ def test_tocupy_missing_values_handling():
array, n_rows, n_cols, dtype = input_to_cupy_array(
df, fail_on_null=True
)


@pytest.mark.cudf_pandas
def test_numpy_output():
# Check that a Numpy array is used as output when a cudf.pandas wrapped
# Numpy array is passed in.
# Non regression test for issue #5784
df = pd.DataFrame({"a": range(5), "b": range(5)})
X = df.values

reducer = umap.UMAP()

# Check that this is a cudf.pandas wrapped array
assert hasattr(X, "_fsproxy_fast_type")
assert isinstance(reducer.fit_transform(X), np.ndarray)

0 comments on commit 5754ec4

Please sign in to comment.