Skip to content

Commit

Permalink
Enable multiclass svm for sparse input (#5588)
Browse files Browse the repository at this point in the history
This commit enables multiclass SVM for sparse input. Previously this was disabled because the `input_to_host_array` functionality does not support sparse arrays, yet the data has to be piped through sklearn classes, which require host data.

@dantegd, this is a local workaround that provides an `input_to_host_array` equivalent for sparse data without the complexity of implementing the whole functionality of that function. Please have a look at whether this is an acceptable solution for this use case.

FYI, @tfeher

Authors:
  - Malte Förster (https://github.com/mfoerste4)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Simon Adorf (https://github.com/csadorf)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Simon Adorf (https://github.com/csadorf)

URL: #5588
  • Loading branch information
mfoerste4 authored Nov 21, 2023
1 parent f79d40f commit 1570ed7
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 18 deletions.
12 changes: 10 additions & 2 deletions cpp/src/svm/kernelcache.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,16 @@ class BatchCache : public raft::cache::Cache<math_t> {
RAFT_CUDA_TRY(cudaMemsetAsync(tmp_buffer, 0, n_ws * 2 * sizeof(int), stream));

// Init cub buffers
cub::DeviceRadixSort::SortKeys(
NULL, d_temp_storage_size, tmp_buffer, tmp_buffer, n_ws, 0, sizeof(int) * 8, stream);
cub::DeviceRadixSort::SortPairs(NULL,
d_temp_storage_size,
tmp_buffer,
tmp_buffer,
tmp_buffer,
tmp_buffer,
n_ws,
0,
sizeof(int) * 8,
stream);
d_temp_storage.resize(d_temp_storage_size, stream);
}

Expand Down
4 changes: 3 additions & 1 deletion python/cuml/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@

from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.input_utils import input_to_host_array
from cuml.internals.input_utils import input_to_host_array_with_sparse_support

from cuml.internals.memory_utils import rmm_cupy_ary
from cuml.internals.memory_utils import set_global_output_type
Expand Down Expand Up @@ -59,6 +60,7 @@
"has_scipy",
"input_to_cuml_array",
"input_to_host_array",
"input_to_host_array_with_sparse_support",
"rmm_cupy_ary",
"set_global_output_type",
"using_device_type",
Expand Down
14 changes: 14 additions & 0 deletions python/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,20 @@ def input_to_host_array(
return out_data._replace(array=out_data.array.to_output("numpy"))


def input_to_host_array_with_sparse_support(X):
    """
    Convert X to a host-side array, with a dedicated path for sparse input.

    Dense input is delegated to `input_to_host_array` and the resulting
    ``.array`` is returned. Sparse input is dispatched on the array type
    reported by `determine_array_type_full`.

    Parameters
    ----------
    X : array-like or sparse matrix
        Input data. Accepted sparse array types are "cupy", "cuml" and
        "numpy" (as named by `determine_array_type_full`).

    Returns
    -------
    Host array for dense input. For sparse input: the result of
    ``to_output(output_type="scipy")`` for "cupy" and "cuml" types, or
    ``X`` itself, unchanged, for the "numpy" type.

    Raises
    ------
    ValueError
        If X is sparse and its array type is none of the above.
    """
    _array_type, sparse_input = determine_array_type_full(X)

    # Dense data: the regular host conversion already handles it.
    if not sparse_input:
        return input_to_host_array(X).array

    # Reported as host-side already: hand the object back untouched.
    # NOTE(review): presumably a scipy sparse matrix here - confirm.
    if _array_type == "numpy":
        return X

    # Already a cuml sparse container: convert it directly.
    if _array_type == "cuml":
        return X.to_output(output_type="scipy")

    # cupy sparse: wrap in SparseCumlArray first to reuse its conversion.
    if _array_type == "cupy":
        wrapped = SparseCumlArray(X)
        return wrapped.to_output(output_type="scipy")

    raise ValueError(f"Unsupported sparse array type: {_array_type}.")


def convert_dtype(X, to_dtype=np.float32, legacy=True, safe_dtype=True):
"""
Convert X to be of dtype `dtype`, raising a TypeError
Expand Down
19 changes: 15 additions & 4 deletions python/cuml/multiclass/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@
from cuml.internals.import_utils import has_sklearn
from cuml.internals.mixins import ClassifierMixin
from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_host_array
from cuml.common import (
input_to_host_array,
input_to_host_array_with_sparse_support,
)
from cuml.internals.input_utils import (
input_to_cupy_array,
determine_array_type_full,
)
from cuml.internals.array_sparse import SparseCumlArray
from cuml.internals import _deprecate_pos_args


Expand Down Expand Up @@ -142,7 +150,9 @@ def fit(self, X, y) -> "MulticlassClassifier":
+ ", must be one of "
'{"ovr", "ovo"}'
)
X = input_to_host_array(X).array

X = input_to_host_array_with_sparse_support(X)

y = input_to_host_array(y).array
with cuml.internals.exit_internal_api():
self.multiclass_estimator.fit(X, y)
Expand All @@ -160,7 +170,8 @@ def predict(self, X) -> CumlArray:
"""
Predict using multi class classifier.
"""
X = input_to_host_array(X).array
X = input_to_host_array_with_sparse_support(X)

with cuml.internals.exit_internal_api():
return self.multiclass_estimator.predict(X)

Expand All @@ -177,7 +188,7 @@ def decision_function(self, X) -> CumlArray:
"""
Calculate the decision function.
"""
X = input_to_host_array(X).array
X = input_to_host_array_with_sparse_support(X)
with cuml.internals.exit_internal_api():
return self.multiclass_estimator.decision_function(X)

Expand Down
10 changes: 4 additions & 6 deletions python/cuml/svm/svc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ from cuml.common.doc_utils import generate_docstring
from cuml.internals.logger import warn
from pylibraft.common.handle cimport handle_t
from pylibraft.common.interruptible import cuda_interruptible
from cuml.common import input_to_cuml_array, input_to_host_array
from cuml.common import input_to_cuml_array, input_to_host_array, input_to_host_array_with_sparse_support
from cuml.internals.input_utils import input_to_cupy_array, determine_array_type_full
from cuml.preprocessing import LabelEncoder
from libcpp cimport nullptr
Expand Down Expand Up @@ -449,7 +449,7 @@ class SVC(SVMBase,

# Currently CalibratedClassifierCV expects data on the host, see
# https://github.com/rapidsai/cuml/issues/2608
X = input_to_host_array(X).array
X = input_to_host_array_with_sparse_support(X)
y = input_to_host_array(y).array

if not has_sklearn():
Expand Down Expand Up @@ -485,8 +485,6 @@ class SVC(SVMBase,
return self._fit_proba(X, y, sample_weight)

if self.n_classes_ > 2:
if is_sparse:
raise ValueError("Multiclass SVM does not support sparse input.")
return self._fit_multiclass(X, y, sample_weight)

if is_sparse:
Expand Down Expand Up @@ -594,7 +592,7 @@ class SVC(SVMBase,
if self.probability:
self._check_is_fitted('prob_svc')

X = input_to_host_array(X).array
X = input_to_host_array_with_sparse_support(X)

with cuml.internals.exit_internal_api():
preds = self.prob_svc.predict(X)
Expand Down Expand Up @@ -628,7 +626,7 @@ class SVC(SVMBase,
if self.probability:
self._check_is_fitted('prob_svc')

X = input_to_host_array(X).array
X = input_to_host_array_with_sparse_support(X)

# Exit the internal API when calling sklearn code (forces numpy
# conversion)
Expand Down
3 changes: 0 additions & 3 deletions python/cuml/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,6 @@ def assert_second_model(pickled_model, X):
def test_svc_pickle(tmpdir, datatype, params, multiclass, sparse):
result = {}

if sparse and multiclass:
pytest.skip("Multiclass SVC does not support sparse input")

if sparse and params["probability"]:
pytest.skip("Probabilistic SVC does not support sparse input")

Expand Down
9 changes: 7 additions & 2 deletions python/cuml/tests/test_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
np = cpu_only_import("numpy")
cuda = gpu_only_import_from("numba", "cuda")


cudf = gpu_only_import("cudf")
scipy_sparse = cpu_only_import("scipy.sparse")

IS_ARM = platform.processor() == "aarch64"

Expand Down Expand Up @@ -176,13 +176,18 @@ def test_svm_skl_cmp_datasets(params, dataset, n_rows, n_cols):


@pytest.mark.parametrize("params", [{"kernel": "rbf", "C": 1, "gamma": 1}])
@pytest.mark.parametrize("sparse", [True, False])
def test_svm_skl_cmp_multiclass(
params, dataset="classification2", n_rows=100, n_cols=6
params, sparse, dataset="classification2", n_rows=100, n_cols=6
):
X_train, X_test, y_train, y_test = make_dataset(
dataset, n_rows, n_cols, n_classes=3, n_informative=6
)

if sparse:
X_train = scipy_sparse.csr_matrix(X_train)
X_test = scipy_sparse.csr_matrix(X_test)

# Default to numpy for testing
with cuml.using_output_type("numpy"):

Expand Down

0 comments on commit 1570ed7

Please sign in to comment.