Allow estimators to accept any dtype #5888

Merged (19 commits) on Jul 30, 2024
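Every hunk in this PR follows the same pattern: each `input_to_cuml_array` / `input_to_cupy_array` call gains a `convert_to_dtype` argument gated on a new `convert_dtype` flag. A minimal sketch of the pattern, assuming a cuML build with this PR; the `coerce_input` wrapper is a hypothetical stand-in for the estimator-internal calls:

import numpy as np
from cuml.common import input_to_cuml_array

def coerce_input(X, convert_dtype=True):
    # convert_dtype=True: any input dtype is copied/converted to float32.
    # convert_dtype=False: convert_to_dtype is None, so input that is not
    # already float32/float64 fails the check_dtype validation and raises.
    X_m, n_rows, n_cols, dtype = input_to_cuml_array(
        X,
        order='C',
        convert_to_dtype=(np.float32 if convert_dtype else None),
        check_dtype=[np.float32, np.float64])
    return X_m, n_rows, n_cols, dtype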
24 changes: 18 additions & 6 deletions python/cuml/cuml/cluster/dbscan.pyx
@@ -256,7 +256,8 @@ class DBSCAN(Base,
         if self.max_mbytes_per_batch is None:
             self.max_mbytes_per_batch = 0

-    def _fit(self, X, out_dtype, opg, sample_weight) -> "DBSCAN":
+    def _fit(self, X, out_dtype, opg, sample_weight,
+             convert_dtype=True) -> "DBSCAN":
         """
         Protected auxiliary function for `fit`. Takes an additional parameter
         opg that is set to `False` for SG, `True` for OPG (multi-GPU)
@@ -268,8 +269,13 @@ class DBSCAN(Base,

         IF GPUBUILD == 1:
             X_m, n_rows, n_cols, self.dtype = \
-                input_to_cuml_array(X, order='C',
-                                    check_dtype=[np.float32, np.float64])
+                input_to_cuml_array(
+                    X,
+                    order='C',
+                    convert_to_dtype=(np.float32 if convert_dtype
+                                      else None),
+                    check_dtype=[np.float32, np.float64]
+                )

             if n_rows == 0:
                 raise ValueError("No rows in the input array. DBScan cannot be "
@@ -280,8 +286,13 @@ class DBSCAN(Base,
             cdef uintptr_t sample_weight_ptr = <uintptr_t> NULL
             if sample_weight is not None:
                 sample_weight_m, _, _, _ = \
-                    input_to_cuml_array(sample_weight, check_dtype=self.dtype,
-                                        check_rows=n_rows, check_cols=1)
+                    input_to_cuml_array(
+                        sample_weight,
+                        convert_to_dtype=(self.dtype if convert_dtype
+                                          else None),
+                        check_dtype=self.dtype,
+                        check_rows=n_rows,
+                        check_cols=1)
                 sample_weight_ptr = sample_weight_m.ptr

             cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
@@ -411,7 +422,8 @@ class DBSCAN(Base,

     @generate_docstring(skip_parameters_heading=True)
     @enable_device_interop
-    def fit(self, X, out_dtype="int32", sample_weight=None) -> "DBSCAN":
+    def fit(self, X, out_dtype="int32", sample_weight=None,
+            convert_dtype=True) -> "DBSCAN":
         """
         Perform DBSCAN clustering from features.

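Usage sketch of the new DBSCAN behavior (hypothetical data; assumes a cuML build that includes this change):

import numpy as np
from cuml.cluster import DBSCAN

X = np.random.randint(0, 10, size=(100, 2))   # int64 features

db = DBSCAN(eps=1.0, min_samples=5)
db.fit(X)                         # converted to float32 internally (extra copy)
# db.fit(X, convert_dtype=False)  # would raise: conversion disabled, dtype check fails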
30 changes: 20 additions & 10 deletions python/cuml/cuml/cluster/kmeans.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -213,7 +213,8 @@ class KMeans(Base,
     def __init__(self, *, handle=None, n_clusters=8, max_iter=300, tol=1e-4,
                  verbose=False, random_state=1,
                  init='scalable-k-means++', n_init=1, oversampling_factor=2.0,
-                 max_samples_per_batch=1<<15, output_type=None):
+                 max_samples_per_batch=1<<15, convert_dtype=True,
+                 output_type=None):
         super().__init__(handle=handle,
                          verbose=verbose,
                          output_type=output_type)
@@ -258,26 +259,35 @@ class KMeans(Base,
         IF GPUBUILD == 1:
             self._params_init = Array
             self.cluster_centers_, _n_rows, self.n_cols, self.dtype = \
-                input_to_cuml_array(init, order='C',
-                                    check_dtype=[np.float32, np.float64])
+                input_to_cuml_array(
+                    init, order='C',
+                    convert_to_dtype=(np.float32 if convert_dtype
+                                      else None),
+                    check_dtype=[np.float32, np.float64]
+                )

     @generate_docstring()
     @enable_device_interop
-    def fit(self, X, sample_weight=None) -> "KMeans":
+    def fit(self, X, sample_weight=None, convert_dtype=True) -> "KMeans":
         """
         Compute k-means clustering with X.

         """
         if self.init == 'preset':
             check_cols = self.n_cols
             check_dtype = self.dtype
+            target_dtype = self.dtype
         else:
             check_cols = False
             check_dtype = [np.float32, np.float64]
+            target_dtype = np.float32

         _X_m, _n_rows, self.n_cols, self.dtype = \
-            input_to_cuml_array(X, order='C',
+            input_to_cuml_array(X,
+                                order='C',
                                 check_cols=check_cols,
+                                convert_to_dtype=(target_dtype if convert_dtype
+                                                  else None),
                                 check_dtype=check_dtype)

         IF GPUBUILD == 1:
@@ -400,7 +410,7 @@ class KMeans(Base,
         """
         return self.fit(X, sample_weight=sample_weight).labels_

-    def _predict_labels_inertia(self, X, convert_dtype=False,
+    def _predict_labels_inertia(self, X, convert_dtype=True,
                                 sample_weight=None,
                                 normalize_weights=True
                                 ) -> typing.Tuple[CumlArray, float]:
@@ -539,7 +549,7 @@ class KMeans(Base,
                                        'description': 'Cluster indexes',
                                        'shape': '(n_samples, 1)'})
     @enable_device_interop
-    def predict(self, X, convert_dtype=False, sample_weight=None,
+    def predict(self, X, convert_dtype=True, sample_weight=None,
                 normalize_weights=True) -> CumlArray:
         """
         Predict the closest cluster each sample in X belongs to.
@@ -558,7 +568,7 @@ class KMeans(Base,
                                        'description': 'Transformed data',
                                        'shape': '(n_samples, n_clusters)'})
     @enable_device_interop
-    def transform(self, X, convert_dtype=False) -> CumlArray:
+    def transform(self, X, convert_dtype=True) -> CumlArray:
         """
         Transform X to a cluster-distance space.

@@ -674,4 +684,4 @@ class KMeans(Base,
         return super().get_param_names() + \
             ['n_init', 'oversampling_factor', 'max_samples_per_batch',
              'init', 'max_iter', 'n_clusters', 'random_state',
-             'tol']
+             'tol', "convert_dtype"]
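For KMeans the flag also becomes a constructor parameter (so it is reported by get_param_names and survives cloning), and the predict/transform defaults flip from False to True. A usage sketch with hypothetical data, assuming a build with this PR:

import numpy as np
from cuml.cluster import KMeans

X = np.random.randint(0, 100, size=(200, 3))  # integer features

km = KMeans(n_clusters=4, random_state=1)     # convert_dtype=True by default
km.fit(X)                                     # X converted to float32
labels = km.predict(X)                        # predict now also converts by default
dists = km.transform(X)                       # likewise for transform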
7 changes: 5 additions & 2 deletions python/cuml/cuml/common/doc_utils.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,7 +44,10 @@

 _parameters_docstrings = {
     "dense": "{name} : array-like (device or host) shape = {shape}\n"
-    "    Dense matrix containing floats or doubles.\n"
+    "    Dense matrix. If the datatype is other than floats or doubles,\n"
+    "    the data will be converted to float, which increases memory\n"
+    "    utilization. Set the parameter convert_dtype to False to avoid\n"
+    "    this; the method will then throw an error instead.\n"
     "    Acceptable formats: CUDA array interface compliant objects like\n"
     "    CuPy, cuDF DataFrame/Series, NumPy ndarray and Pandas\n"
     "    DataFrame/Series.",
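The edited "dense" entry is a template that generate_docstring interpolates into estimator docstrings. A rough sketch of how it is consumed, assuming the decorator can be applied to any method that uses the conventional parameter names (as the estimators in this PR do); the Demo class is hypothetical:

from cuml.common.doc_utils import generate_docstring

class Demo:
    @generate_docstring()
    def fit(self, X, y=None, convert_dtype=True) -> "Demo":
        """
        Fit the model with X.

        """
        return self

# help(Demo().fit) should now render the "dense" text for X, including the
# new note about float conversion and the convert_dtype=False escape hatch.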
9 changes: 6 additions & 3 deletions python/cuml/cuml/decomposition/incremental_pca.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -219,7 +219,7 @@ def __init__(
         self._hyperparams = ["n_components", "whiten", "copy", "batch_size"]
         self._sparse_model = True

-    def fit(self, X, y=None) -> "IncrementalPCA":
+    def fit(self, X, y=None, convert_dtype=True) -> "IncrementalPCA":
         """
         Fit the model with X, using minibatches of size batch_size.

@@ -250,7 +250,10 @@ def fit(self, X, y=None) -> "IncrementalPCA":
         # transform and inverse transform convert the output to the
         # required type.
         X, n_samples, n_features, self.dtype = input_to_cupy_array(
-            X, order="K", check_dtype=[cp.float32, cp.float64]
+            X,
+            order="K",
+            convert_to_dtype=(cp.float32 if convert_dtype else None),
+            check_dtype=[cp.float32, cp.float64],
        )

         n_samples, n_features = X.shape
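Usage sketch (hypothetical data; CuPy input, since this path goes through input_to_cupy_array):

import cupy as cp
from cuml.decomposition import IncrementalPCA

X = cp.random.randint(0, 255, size=(1000, 20))     # integer feature matrix

ipca = IncrementalPCA(n_components=5, batch_size=200)
ipca.fit(X)        # whole input coerced to float32 up front, then batched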
11 changes: 8 additions & 3 deletions python/cuml/cuml/decomposition/pca.pyx
@@ -407,7 +407,7 @@ class PCA(UniversalBase,

     @generate_docstring(X='dense_sparse')
     @enable_device_interop
-    def fit(self, X, y=None) -> "PCA":
+    def fit(self, X, y=None, convert_dtype=True) -> "PCA":
         """
         Fit the model with X. y is currently ignored.

@@ -431,7 +431,10 @@ class PCA(UniversalBase,
             return self._sparse_fit(X)

         X_m, self.n_samples_, self.n_features_in_, self.dtype = \
-            input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
+            input_to_cuml_array(X,
+                                convert_to_dtype=(np.float32 if convert_dtype
+                                                  else None),
+                                check_dtype=[np.float32, np.float64])
         cdef uintptr_t _input_ptr = X_m.ptr
         self.feature_names_in_ = X_m.index

@@ -648,7 +651,7 @@ class PCA(UniversalBase,
                                        'description': 'Transformed values',
                                        'shape': '(n_samples, n_components)'})
     @enable_device_interop
-    def transform(self, X, convert_dtype=False) -> CumlArray:
+    def transform(self, X, convert_dtype=True) -> CumlArray:
         """
         Apply dimensionality reduction to X.

@@ -667,6 +670,8 @@ class PCA(UniversalBase,
         elif self._sparse_model:
             X, _, _, _ = \
                 input_to_cupy_array(X, order='K',
+                                    convert_to_dtype=(dtype if convert_dtype
+                                                      else None),
                                     check_dtype=[cp.float32, cp.float64])
             return self._sparse_transform(X)
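Usage sketch (hypothetical data; assumes a build with this PR):

import numpy as np
from cuml.decomposition import PCA

X = np.random.randint(0, 10, size=(50, 5))

pca = PCA(n_components=2)
pca.fit(X)                  # integer X accepted, converted to float32
X_red = pca.transform(X)    # transform's default also flipped to True
# pca.fit(X, convert_dtype=False)  # would raise instead of copying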
11 changes: 7 additions & 4 deletions python/cuml/cuml/decomposition/tsvd.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -316,14 +316,17 @@ class TruncatedSVD(UniversalBase,
                                        'description': 'Reduced version of X',
                                        'shape': '(n_samples, n_components)'})
     @enable_device_interop
-    def fit_transform(self, X, y=None) -> CumlArray:
+    def fit_transform(self, X, y=None, convert_dtype=True) -> CumlArray:
         """
         Fit LSI model to X and perform dimensionality reduction on X.
         y is currently ignored.

         """
         X_m, self.n_rows, self.n_features_in_, self.dtype = \
-            input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
+            input_to_cuml_array(X,
+                                convert_to_dtype=(np.float32 if convert_dtype
+                                                  else None),
+                                check_dtype=[np.float32, np.float64])
         cdef uintptr_t _input_ptr = X_m.ptr

         self._initialize_arrays(self.n_components, self.n_rows,
@@ -431,7 +434,7 @@ class TruncatedSVD(UniversalBase,
                                        'description': 'Reduced version of X',
                                        'shape': '(n_samples, n_components)'})
     @enable_device_interop
-    def transform(self, X, convert_dtype=False) -> CumlArray:
+    def transform(self, X, convert_dtype=True) -> CumlArray:
         """
         Perform dimensionality reduction on X.
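Usage sketch (hypothetical data; assumes a build with this PR):

import numpy as np
from cuml.decomposition import TruncatedSVD

X = np.random.randint(0, 100, size=(60, 8))

tsvd = TruncatedSVD(n_components=3)
X_red = tsvd.fit_transform(X)   # converts to float32 before the solver
X_red2 = tsvd.transform(X)      # transform now converts by default as well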
5 changes: 4 additions & 1 deletion python/cuml/cuml/ensemble/randomforest_common.pyx
@@ -278,7 +278,10 @@ class BaseRandomForestModel(Base):
         self._reset_forest_data()

         X_m, self.n_rows, self.n_cols, self.dtype = \
-            input_to_cuml_array(X, check_dtype=[np.float32, np.float64],
+            input_to_cuml_array(X,
+                                convert_to_dtype=(np.float32 if convert_dtype
+                                                  else None),
+                                check_dtype=[np.float32, np.float64],
                                 order='F')
         if self.n_bins > self.n_rows:
             warnings.warn("The number of bins, `n_bins` is greater than "
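Here convert_dtype comes from the existing fit(..., convert_dtype=True) signature of the random forest estimators. A usage sketch with hypothetical data:

import numpy as np
from cuml.ensemble import RandomForestClassifier

X = np.random.randint(0, 5, size=(120, 4))            # integer features
y = np.random.randint(0, 2, size=120).astype(np.int32)

rf = RandomForestClassifier(n_estimators=10, n_bins=8)
rf.fit(X, y)       # X converted to float32 in F-order before training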
17 changes: 12 additions & 5 deletions python/cuml/cuml/experimental/linear_model/lars.pyx
@@ -225,15 +225,22 @@ class Lars(Base, RegressorMixin):
                           " Proceeding without it.")
         return Gram

-    def _fit_cpp(self, X, y, Gram, x_scale):
+    def _fit_cpp(self, X, y, Gram, x_scale, convert_dtype):
         """ Fit lars model using cpp solver"""
         cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
-        X_m, _, _, _ = input_to_cuml_array(X, check_dtype=self.dtype,
-                                           order='F')
+        X_m, _, _, _ = \
+            input_to_cuml_array(X,
+                                convert_to_dtype=(np.float32 if convert_dtype
+                                                  else None),
+                                check_dtype=self.dtype,
+                                order='F')
         cdef uintptr_t X_ptr = X_m.ptr
         cdef int n_rows = X.shape[0]
         cdef uintptr_t y_ptr = \
-            input_to_cuml_array(y, check_dtype=self.dtype).array.ptr
+            input_to_cuml_array(y,
+                                convert_to_dtype=(self.dtype if convert_dtype
+                                                  else None),
+                                check_dtype=self.dtype).array.ptr
         cdef int max_iter = self.n_nonzero_coefs
         self.beta_ = CumlArray.zeros(max_iter, dtype=self.dtype)
         cdef uintptr_t beta_ptr = self.beta_.ptr
@@ -324,7 +331,7 @@ class Lars(Base, RegressorMixin):
         if self.eps is None:
             self.eps = np.finfo(float).eps

-        self._fit_cpp(X, y, Gram, x_scale)
+        self._fit_cpp(X, y, Gram, x_scale, convert_dtype)

         self._set_intercept(x_mean, x_scale, y_scale)
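Usage sketch (hypothetical data; assumes fit exposes convert_dtype, consistent with the _fit_cpp(..., convert_dtype) call above):

import numpy as np
from cuml.experimental.linear_model import Lars

X = np.random.randint(0, 10, size=(100, 6))
y = np.random.rand(100)

lars = Lars(n_nonzero_coefs=3)
lars.fit(X, y)     # X coerced to float32, y coerced to the matching dtype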
29 changes: 19 additions & 10 deletions python/cuml/cuml/explainer/tree_shap.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -208,9 +208,9 @@ cdef class TreeExplainer:
     cdef object num_class
     cdef object data

-    def __init__(self, *, model, data=None):
+    def __init__(self, *, model, data=None, convert_dtype=True):
         if data is not None:
-            self.data, _, _, _ = self._prepare_input(data)
+            self.data, _, _, _ = self._prepare_input(data, convert_dtype)
         else:
             self.data = None

@@ -260,22 +260,28 @@ cdef class TreeExplainer:
         if len(self.num_class) > 1:
             raise NotImplementedError("TreeExplainer does not support multi-target models")

-    def _prepare_input(self, X):
+    def _prepare_input(self, X, convert_dtype):
         try:
             return input_to_cuml_array(
-                X, order='C', check_dtype=[np.float32, np.float64])
+                X,
+                order='C',
+                convert_to_dtype=(np.float32 if convert_dtype
+                                  else None),
+                check_dtype=[np.float32, np.float64])
         except ValueError:
             # input can be a DataFrame with mixed types
             # in this case coerce to 64-bit
             return input_to_cuml_array(
-                X, order='C', convert_to_dtype=np.float64)
+                X,
+                order='C',
+                convert_to_dtype=np.float64)

     def _determine_output_type(self, X):
         X_type = determine_array_type(X)
         # Coerce to CuPy / NumPy because we may need to return 3D array
         return 'numpy' if X_type == 'numpy' else 'cupy'

-    def shap_values(self, X) -> CumlArray:
+    def shap_values(self, X, convert_dtype=True) -> CumlArray:
         """
         Estimate the SHAP values for a set of samples. For a given row, the
         SHAP values plus the `expected_value` attribute sum up to the raw
@@ -296,7 +302,7 @@ cdef class TreeExplainer:
         Returns a matrix of SHAP values of shape
         (# classes x # samples x # features).
         """
-        X_m, n_rows, n_cols, dtype = self._prepare_input(X)
+        X_m, n_rows, n_cols, dtype = self._prepare_input(X, convert_dtype)
         # Storing a C-order 3D array in a CumlArray leads to cryptic error
         # ValueError: len(shape) != len(strides)
         # So we use 2D array here
@@ -338,7 +344,10 @@ cdef class TreeExplainer:
         return preds[:, :-1]

     def shap_interaction_values(
-            self, X, method='shapley-interactions') -> CumlArray:
+            self,
+            X,
+            method='shapley-interactions',
+            convert_dtype=True) -> CumlArray:
         """
         Estimate the SHAP interaction values for a set of samples. For a
         given row, the SHAP values plus the `expected_value` attribute sum
@@ -363,7 +372,7 @@ cdef class TreeExplainer:
         Returns a matrix of SHAP values of shape
         (# classes x # samples x # features x # features).
         """
-        X_m, n_rows, n_cols, dtype = self._prepare_input(X)
+        X_m, n_rows, n_cols, dtype = self._prepare_input(X, convert_dtype)

         # Storing a C-order 3D array in a CumlArray leads to cryptic error
         # ValueError: len(shape) != len(strides)
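Usage sketch (hypothetical data; assumes a GPU build where TreeExplainer accepts the cuML random forest model):

import numpy as np
from cuml.ensemble import RandomForestRegressor
from cuml.explainer import TreeExplainer

X = np.random.randint(0, 10, size=(80, 5))
y = np.random.rand(80).astype(np.float32)

model = RandomForestRegressor(n_estimators=5, n_bins=8).fit(X.astype(np.float32), y)

expl = TreeExplainer(model=model)        # convert_dtype=True by default
shap_vals = expl.shap_values(X)          # integer X converted to float32
inter = expl.shap_interaction_values(X)  # same conversion for interactions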