Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling CPU/GPU interop for SVM, DBSCAN and KMeans #6020

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions python/cuml/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ from cuml.internals.safe_imports import gpu_only_import
cp = gpu_only_import('cupy')

from cuml.internals.array import CumlArray
from cuml.internals.base import Base
from cuml.internals.base import UniversalBase
from cuml.common.doc_utils import generate_docstring
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.mixins import ClusterMixin
Expand Down Expand Up @@ -106,7 +106,7 @@ IF GPUBUILD == 1:
bool opg) except +


class DBSCAN(Base,
class DBSCAN(UniversalBase,
ClusterMixin,
CMajorInputTagMixin):
"""
Expand Down Expand Up @@ -222,8 +222,8 @@ class DBSCAN(Base,
"""

_cpu_estimator_import_path = 'sklearn.cluster.DBSCAN'
labels_ = CumlArrayDescriptor()
core_sample_indices_ = CumlArrayDescriptor()
core_sample_indices_ = CumlArrayDescriptor(order="C")
labels_ = CumlArrayDescriptor(order="C")

@device_interop_preparation
def __init__(self, *,
Expand Down Expand Up @@ -268,7 +268,7 @@ class DBSCAN(Base,
"np.int32, np.int64}")

IF GPUBUILD == 1:
X_m, n_rows, n_cols, self.dtype = \
X_m, n_rows, self.n_features_in_, self.dtype = \
input_to_cuml_array(
X,
order='C',
Expand Down Expand Up @@ -338,7 +338,7 @@ class DBSCAN(Base,
fit(handle_[0],
<float*>input_ptr,
<int> n_rows,
<int> n_cols,
<int> self.n_features_in_,
<float> self.eps,
<int> self.min_samples,
<DistanceType> metric,
Expand All @@ -353,7 +353,7 @@ class DBSCAN(Base,
fit(handle_[0],
<float*>input_ptr,
<int64_t> n_rows,
<int64_t> n_cols,
<int64_t> self.n_features_in_,
<float> self.eps,
<int> self.min_samples,
<DistanceType> metric,
Expand All @@ -370,7 +370,7 @@ class DBSCAN(Base,
fit(handle_[0],
<double*>input_ptr,
<int> n_rows,
<int> n_cols,
<int> self.n_features_in_,
<double> self.eps,
<int> self.min_samples,
<DistanceType> metric,
Expand All @@ -385,7 +385,7 @@ class DBSCAN(Base,
fit(handle_[0],
<double*>input_ptr,
<int64_t> n_rows,
<int64_t> n_cols,
<int64_t> self.n_features_in_,
<double> self.eps,
<int> self.min_samples,
<DistanceType> metric,
Expand Down Expand Up @@ -475,3 +475,6 @@ class DBSCAN(Base,
"metric",
"algorithm",
]

def get_attr_names(self):
return ["core_sample_indices_", "labels_", "n_features_in_"]
57 changes: 34 additions & 23 deletions python/cuml/cuml/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,18 @@ IF GPUBUILD == 1:

from cuml.internals.array import CumlArray
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.base import Base
from cuml.internals.base import UniversalBase
from cuml.common.doc_utils import generate_docstring
from cuml.internals.mixins import ClusterMixin
from cuml.internals.mixins import CMajorInputTagMixin
from cuml.common import input_to_cuml_array
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop

from sklearn.utils._openmp_helpers import _openmp_effective_n_threads

class KMeans(Base,

class KMeans(UniversalBase,
ClusterMixin,
CMajorInputTagMixin):

Expand Down Expand Up @@ -188,8 +190,8 @@ class KMeans(Base,
"""

_cpu_estimator_import_path = 'sklearn.cluster.KMeans'
labels_ = CumlArrayDescriptor()
cluster_centers_ = CumlArrayDescriptor()
labels_ = CumlArrayDescriptor(order='C')
cluster_centers_ = CumlArrayDescriptor(order='C')

def _get_kmeans_params(self):
IF GPUBUILD == 1:
Expand Down Expand Up @@ -232,6 +234,9 @@ class KMeans(Base,
self.labels_ = None
self.cluster_centers_ = None

# For sklearn interoperability
self._n_threads = _openmp_effective_n_threads()

# cuPy does not allow comparing with string. See issue #2372
init_str = init if isinstance(init, str) else None

Expand All @@ -258,7 +263,7 @@ class KMeans(Base,

IF GPUBUILD == 1:
self._params_init = Array
self.cluster_centers_, _n_rows, self.n_cols, self.dtype = \
self.cluster_centers_, _n_rows, self.n_features_in_, self.dtype = \
input_to_cuml_array(
init, order='C',
convert_to_dtype=(np.float32 if convert_dtype
Expand All @@ -274,15 +279,15 @@ class KMeans(Base,

"""
if self.init == 'preset':
check_cols = self.n_cols
check_cols = self.n_features_in_
check_dtype = self.dtype
target_dtype = self.dtype
else:
check_cols = False
check_dtype = [np.float32, np.float64]
target_dtype = np.float32

_X_m, _n_rows, self.n_cols, self.dtype = \
_X_m, _n_rows, self.n_features_in_, self.dtype = \
input_to_cuml_array(X,
order='C',
check_cols=check_cols,
Expand All @@ -306,14 +311,14 @@ class KMeans(Base,

cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr

int_dtype = np.int32 if np.int64(_n_rows) * np.int64(self.n_cols) < 2**31-1 else np.int64
int_dtype = np.int32 if np.int64(_n_rows) * np.int64(self.n_features_in_) < 2**31-1 else np.int64

self.labels_ = CumlArray.zeros(shape=_n_rows, dtype=int_dtype)
cdef uintptr_t labels_ptr = self.labels_.ptr

if (self.init in ['scalable-k-means++', 'k-means||', 'random']):
self.cluster_centers_ = \
CumlArray.zeros(shape=(self.n_clusters, self.n_cols),
CumlArray.zeros(shape=(self.n_clusters, self.n_features_in_),
dtype=self.dtype, order='C')

cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr
Expand All @@ -334,7 +339,7 @@ class KMeans(Base,
<KMeansParams> deref(params),
<const float*> input_ptr,
<int> _n_rows,
<int> self.n_cols,
<int> self.n_features_in_,
<const float *>sample_weight_ptr,
<float*> cluster_centers_ptr,
<int*> labels_ptr,
Expand All @@ -347,7 +352,7 @@ class KMeans(Base,
<KMeansParams> deref(params),
<const float*> input_ptr,
<int64_t> _n_rows,
<int64_t> self.n_cols,
<int64_t> self.n_features_in_,
<const float *>sample_weight_ptr,
<float*> cluster_centers_ptr,
<int64_t*> labels_ptr,
Expand All @@ -364,7 +369,7 @@ class KMeans(Base,
<KMeansParams> deref(params),
<const double*> input_ptr,
<int> _n_rows,
<int> self.n_cols,
<int> self.n_features_in_,
<const double *>sample_weight_ptr,
<double*> cluster_centers_ptr,
<int*> labels_ptr,
Expand All @@ -378,7 +383,7 @@ class KMeans(Base,
<KMeansParams> deref(params),
<const double*> input_ptr,
<int64_t> _n_rows,
<int64_t> self.n_cols,
<int64_t> self.n_features_in_,
<const double *>sample_weight_ptr,
<double*> cluster_centers_ptr,
<int64_t*> labels_ptr,
Expand Down Expand Up @@ -442,11 +447,13 @@ class KMeans(Base,
Sum of squared distances of samples to their closest cluster center.
"""

self.dtype = self.cluster_centers_.dtype

_X_m, _n_rows, _n_cols, _ = \
input_to_cuml_array(X, order='C', check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
else None),
check_cols=self.n_cols)
check_cols=self.n_features_in_)

IF GPUBUILD == 1:
cdef uintptr_t input_ptr = _X_m.ptr
Expand Down Expand Up @@ -486,7 +493,7 @@ class KMeans(Base,
<float*> cluster_centers_ptr,
<float*> input_ptr,
<size_t> _n_rows,
<size_t> self.n_cols,
<size_t> self.n_features_in_,
<float *>sample_weight_ptr,
<bool> normalize_weights,
<int*> labels_ptr,
Expand All @@ -498,7 +505,7 @@ class KMeans(Base,
<float*> cluster_centers_ptr,
<float*> input_ptr,
<int64_t> _n_rows,
<int64_t> self.n_cols,
<int64_t> self.n_features_in_,
<float *>sample_weight_ptr,
<bool> normalize_weights,
<int64_t*> labels_ptr,
Expand All @@ -513,7 +520,7 @@ class KMeans(Base,
<double*> cluster_centers_ptr,
<double*> input_ptr,
<size_t> _n_rows,
<size_t> self.n_cols,
<size_t> self.n_features_in_,
<double *>sample_weight_ptr,
<bool> normalize_weights,
<int*> labels_ptr,
Expand All @@ -525,7 +532,7 @@ class KMeans(Base,
<double*> cluster_centers_ptr,
<double*> input_ptr,
<int64_t> _n_rows,
<int64_t> self.n_cols,
<int64_t> self.n_features_in_,
<double *>sample_weight_ptr,
<bool> normalize_weights,
<int64_t*> labels_ptr,
Expand Down Expand Up @@ -578,7 +585,7 @@ class KMeans(Base,
input_to_cuml_array(X, order='C', check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
else None),
check_cols=self.n_cols)
check_cols=self.n_features_in_)
IF GPUBUILD == 1:
cdef uintptr_t input_ptr = _X_m.ptr

Expand Down Expand Up @@ -607,7 +614,7 @@ class KMeans(Base,
<float*> cluster_centers_ptr,
<float*> input_ptr,
<int> _n_rows,
<int> self.n_cols,
<int> self.n_features_in_,
<float*> preds_ptr)
else:
cpp_transform(
Expand All @@ -616,7 +623,7 @@ class KMeans(Base,
<float*> cluster_centers_ptr,
<float*> input_ptr,
<int64_t> _n_rows,
<int64_t> self.n_cols,
<int64_t> self.n_features_in_,
<float*> preds_ptr)

elif self.dtype == np.float64:
Expand All @@ -627,7 +634,7 @@ class KMeans(Base,
<double*> cluster_centers_ptr,
<double*> input_ptr,
<int> _n_rows,
<int> self.n_cols,
<int> self.n_features_in_,
<double*> preds_ptr)
else:
cpp_transform(
Expand All @@ -636,7 +643,7 @@ class KMeans(Base,
<double*> cluster_centers_ptr,
<double*> input_ptr,
<int64_t> _n_rows,
<int64_t> self.n_cols,
<int64_t> self.n_features_in_,
<double*> preds_ptr)

else:
Expand Down Expand Up @@ -685,3 +692,7 @@ class KMeans(Base,
['n_init', 'oversampling_factor', 'max_samples_per_batch',
'init', 'max_iter', 'n_clusters', 'random_state',
'tol', "convert_dtype"]

def get_attr_names(self):
return ['cluster_centers_', 'labels_', 'inertia_',
'n_iter_', 'n_features_in_', '_n_threads']
51 changes: 50 additions & 1 deletion python/cuml/cuml/tests/test_device_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
from cuml.internals.memory_utils import using_memory_type
from cuml.internals.mem_type import MemoryType
from cuml.decomposition import PCA, TruncatedSVD
from cuml.cluster import KMeans
from cuml.cluster import DBSCAN
from cuml.common.device_selection import DeviceType, using_device_type
from cuml.testing.utils import assert_dbscan_equal
from hdbscan import HDBSCAN as refHDBSCAN
from sklearn.neighbors import NearestNeighbors as skNearestNeighbors
from sklearn.linear_model import Ridge as skRidge
Expand All @@ -42,6 +45,8 @@
from sklearn.linear_model import LinearRegression as skLinearRegression
from sklearn.decomposition import PCA as skPCA
from sklearn.decomposition import TruncatedSVD as skTruncatedSVD
from sklearn.cluster import KMeans as skKMeans
from sklearn.cluster import DBSCAN as skDBSCAN
from sklearn.datasets import make_regression, make_blobs
from pytest_cases import fixture_union, fixture
from importlib import import_module
Expand Down Expand Up @@ -136,7 +141,11 @@ def make_reg_dataset():

def make_blob_dataset():
X, y = make_blobs(
n_samples=2000, n_features=20, centers=20, random_state=0
n_samples=2000,
n_features=20,
centers=20,
random_state=0,
cluster_std=1.0,
)
X_train, X_test = X[:1800], X[1800:]
y_train, _ = y[:1800], y[1800:]
Expand Down Expand Up @@ -948,3 +957,43 @@ def test_hdbscan_methods(train_device, infer_device):
assert_membership_vectors(membership, ref_membership)
assert adjusted_rand_score(labels, ref_labels) >= 0.98
assert array_equal(probs, ref_probs, unit_tol=0.001, total_tol=0.006)


@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_kmeans_methods(train_device, infer_device):
n_clusters = 20
ref_model = skKMeans(n_clusters=n_clusters)
ref_model.fit(X_train_blob)
ref_output = ref_model.predict(X_test_blob)

model = KMeans(n_clusters=n_clusters)
with using_device_type(train_device):
model.fit(X_train_blob)
with using_device_type(infer_device):
output = model.predict(X_test_blob)

assert adjusted_rand_score(ref_output, output) >= 0.9


@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_dbscan_methods(train_device, infer_device):
eps = 8.0
ref_model = skDBSCAN(eps=eps)
ref_model.fit(X_train_blob)
ref_output = ref_model.fit_predict(X_train_blob)

model = DBSCAN(eps=eps)
with using_device_type(train_device):
model.fit(X_train_blob)
with using_device_type(infer_device):
output = model.fit_predict(X_train_blob)

assert array_equal(
ref_model.core_sample_indices_, ref_model.core_sample_indices_
)
assert adjusted_rand_score(ref_output, output) >= 0.95
assert_dbscan_equal(
ref_output, output, X_train_blob, model.core_sample_indices_, eps
)
Loading