Prepare for n_init=auto in KMeans #6142

Open · wants to merge 14 commits into base: branch-25.02
32 changes: 28 additions & 4 deletions python/cuml/cuml/cluster/kmeans.pyx
@@ -16,6 +16,8 @@

# distutils: language = c++

import warnings

from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
from cuml.internals.safe_imports import gpu_only_import
@@ -92,7 +94,7 @@ class KMeans(UniversalBase,
3 4.0 3.0
>>>
>>> # Calling fit
>>> kmeans_float = KMeans(n_clusters=2)
>>> kmeans_float = KMeans(n_clusters=2, n_init="auto")
>>> kmeans_float.fit(b)
KMeans()
>>>
@@ -140,10 +142,17 @@ class KMeans(UniversalBase,
- If an ndarray is passed, it should be of
shape (`n_clusters`, `n_features`) and gives the initial centers.

n_init: int (default = 1)
n_init: 'auto' or int (default = 1)
Number of times the k-means algorithm will be run with
different centroid seeds. The final result will come from the run
that produces the lowest inertia out of the n_init runs.

.. versionadded:: 24.12
Added 'auto' option for `n_init`.

.. versionchanged:: 25.02
The default value of `n_init` will change from 1 to `'auto'` in version 25.02.

oversampling_factor : float64 (default = 2.0)
The number of points to sample
in scalable k-means++ initialization for potential centroids.
@@ -210,15 +219,30 @@ class KMeans(UniversalBase,
params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501
params.batch_samples = <int>self.max_samples_per_batch
params.oversampling_factor = <double>self.oversampling_factor
params.n_init = <int>self.n_init
n_init = self.n_init
if n_init == "warn":
warnings.warn(
"The default value of `n_init` will change from"
" 1 to 'auto' in 25.02. Set the value of `n_init`"
" explicitly to suppress this warning.",
FutureWarning,
)
n_init = 1
if n_init == "auto":
if self.init in ("k-means||", "scalable-k-means++"):
params.n_init = 1
else:
params.n_init = 10
else:
params.n_init = <int>n_init
return <size_t>params
ELSE:
return None

@device_interop_preparation
def __init__(self, *, handle=None, n_clusters=8, max_iter=300, tol=1e-4,
verbose=False, random_state=1,
init='scalable-k-means++', n_init=1, oversampling_factor=2.0,
init='scalable-k-means++', n_init="warn", oversampling_factor=2.0,
max_samples_per_batch=1<<15, convert_dtype=True,
output_type=None):
super().__init__(handle=handle,
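For reference, the `n_init` handling added above amounts to the following minimal sketch (plain Python rather than the PR's Cython; `resolve_n_init` is an illustrative name, not a function in cuml):

import warnings

def resolve_n_init(n_init, init):
    # "warn" is the sentinel default: behave like the old default (1),
    # but tell the user that the default is about to change.
    if n_init == "warn":
        warnings.warn(
            "The default value of `n_init` will change from"
            " 1 to 'auto' in 25.02. Set the value of `n_init`"
            " explicitly to suppress this warning.",
            FutureWarning,
        )
        n_init = 1
    if n_init == "auto":
        # One run is enough for the scalable k-means++ (k-means||)
        # initialization; other inits fall back to 10 runs, analogous
        # to scikit-learn's handling of n_init="auto".
        return 1 if init in ("k-means||", "scalable-k-means++") else 10
    return int(n_init)

Passing `n_init` explicitly, e.g. `KMeans(n_clusters=2, n_init=1)` to keep the old behavior or `n_init="auto"` to opt in early, suppresses the FutureWarning.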
8 changes: 7 additions & 1 deletion python/cuml/cuml/experimental/accel/estimator_proxy.py
@@ -207,7 +213,7 @@ def __init__(self, *args, **kwargs):
self._cpu_model_class = (
original_class_a # Store a reference to the original class
)
kwargs, self._gpuaccel = self._hyperparam_translator(**kwargs)
sklearn_args = inspect.signature(self._cpu_model_class)
sklearn_args = sklearn_args.bind(*args, **kwargs)
sklearn_args.apply_defaults()
Comment on lines +210 to +212 (Member Author):

This means we take the constructor arguments and their default values from the scikit-learn class, combine them with what the user passed in, and then feed the result to the hyperparameter translator. This makes a difference when the default values in scikit-learn and cuml differ and the user does not explicitly pass that argument.

# The scikit-learn class
class SkTimMeans:
  def __init__(self, foo='bar'):
    ...

# The cuml class
class CuTimMeans:
  def __init__(self, foo='baz'):
    ...

# User code
est = SkTimMeans()

foo should be set to 'bar' in est because that is the default value of the scikit-learn class.

It also fixes the many deprecation warnings for KMeans that we were getting in the accelerator tests as a result of the main change in this PR.
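Concretely, the binding step works like the following runnable sketch (reusing the hypothetical SkTimMeans class from above):

import inspect

# Hypothetical stand-in for the scikit-learn class.
class SkTimMeans:
    def __init__(self, foo="bar"):
        self.foo = foo

sig = inspect.signature(SkTimMeans)  # call signature of the class (__init__ minus self)
bound = sig.bind()                   # bind what the user passed; nothing in this case
bound.apply_defaults()               # fill in the scikit-learn defaults
print(bound.arguments)               # {'foo': 'bar'} -- what the translator now sees

Without apply_defaults(), bound.arguments would be empty here, the translator would never see foo, and cuml's own default ('baz') would silently win.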


kwargs, self._gpuaccel = self._hyperparam_translator(
**sklearn_args.arguments
)
super().__init__(*args, **kwargs)

self._cpu_hyperparams = list(
7 changes: 5 additions & 2 deletions python/cuml/cuml/explainer/sampling.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -91,7 +91,10 @@ def kmeans_sampling(X, k, round_values=True, detailed=False, random_state=0):
X = imp.fit_transform(X)

kmeans = KMeans(
n_clusters=k, random_state=random_state, output_type=_output_dtype_str
n_clusters=k,
random_state=random_state,
output_type=_output_dtype_str,
n_init="auto",
).fit(X)

if round_values:
25 changes: 19 additions & 6 deletions python/cuml/cuml/tests/dask/test_dask_kmeans.py
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,7 +63,10 @@ def test_end_to_end(
X_train, y_train = X, y

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

cumlModel.fit(X_train)
@@ -120,7 +123,7 @@ def test_large_data_no_overflow(nrows_per_part, ncols, nclusters, client):
X.compute_chunk_sizes().persist()

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||", n_clusters=nclusters, random_state=10, n_init="auto"
)

cumlModel.fit(X_train)
@@ -171,7 +174,11 @@ def test_weighted_kmeans(nrows, ncols, nclusters, n_parts, client):
wt[cp.argmax(cp.array(y.compute()) == i).item()] = 5000.0

cumlModel = cumlKMeans(
verbose=0, init="k-means||", n_clusters=nclusters, random_state=10
verbose=0,
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

chunk_parts = int(nrows / n_parts)
@@ -237,7 +244,10 @@ def test_transform(nrows, ncols, nclusters, n_parts, input_type, client):
labels = cp.squeeze(y_train.compute())

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

cumlModel.fit(X_train)
@@ -302,7 +312,10 @@ def test_score(nrows, ncols, nclusters, n_parts, input_type, client):
X_train, y_train = X, y

cumlModel = cumlKMeans(
init="k-means||", n_clusters=nclusters, random_state=10
init="k-means||",
n_clusters=nclusters,
random_state=10,
n_init="auto",
)

cumlModel.fit(X_train)
2 changes: 2 additions & 0 deletions python/cuml/cuml/tests/test_api.py
@@ -237,6 +237,8 @@ def test_fit_function(dataset, model_name):
model = models[model_name](np.random.normal(0.0, 1.0, (10,)))
elif model_name in ["RandomForestClassifier", "RandomForestRegressor"]:
model = models[model_name](n_bins=32)
elif model_name == "KMeans":
model = models[model_name](n_init="auto")
else:
if n_pos_args_constr == 1:
model = models[model_name]()
2 changes: 1 addition & 1 deletion python/cuml/cuml/tests/test_device_selection.py
@@ -981,7 +981,7 @@ def test_kmeans_methods(train_device, infer_device):
ref_model.fit(X_train_blob)
ref_output = ref_model.predict(X_test_blob)

model = KMeans(n_clusters=n_clusters)
model = KMeans(n_clusters=n_clusters, n_init="auto")
with using_device_type(train_device):
model.fit(X_train_blob)
with using_device_type(infer_device):
5 changes: 4 additions & 1 deletion python/cuml/cuml/tests/test_input_estimators.py
@@ -115,7 +115,10 @@ def test_estimators_all_dtypes(model_name, dtype):

X_train, y_train, X_test = make_dataset(dtype, nrows, ncols, ninfo)
print(model_name)
model = models[model_name]()
if model_name == "KMeans":
model = models[model_name](n_init="auto")
else:
model = models[model_name]()
sign = inspect.signature(model.fit)
if "y" in sign.parameters:
model.fit(X=X_train, y=y_train)
29 changes: 27 additions & 2 deletions python/cuml/cuml/tests/test_kmeans.py
@@ -67,6 +67,24 @@ def random_state():
return random_state


def test_n_init_deprecation():
X, y = make_blobs(
random_state=0,
)

# Warn about default changing
kmeans = cuml.KMeans()
with pytest.warns(
FutureWarning, match="The default value of `n_init` will change from"
):
kmeans.fit(X)

# No warning when explicitly set to integer or 'auto'
for n_init in ("auto", 2):
kmeans = cuml.KMeans(n_init=n_init)
kmeans.fit(X)


@pytest.mark.xfail
def test_n_init_cluster_consistency(random_state):

@@ -127,7 +145,9 @@ def test_traditional_kmeans_plus_plus_init(
cuml_kmeans.fit(X)
cu_score = cuml_kmeans.score(X)

kmeans = cluster.KMeans(random_state=random_state, n_clusters=nclusters)
kmeans = cluster.KMeans(
random_state=random_state, n_clusters=nclusters, n_init=1
)
kmeans.fit(cp.asnumpy(X))
sk_score = kmeans.score(cp.asnumpy(X))

@@ -167,7 +187,9 @@ def test_weighted_kmeans(nrows, ncols, nclusters, max_weight, random_state):
cuml_kmeans.fit(X, sample_weight=wt)
cu_score = cuml_kmeans.score(X)

sk_kmeans = cluster.KMeans(random_state=random_state, n_clusters=nclusters)
sk_kmeans = cluster.KMeans(
random_state=random_state, n_clusters=nclusters, n_init=1
)
sk_kmeans.fit(cp.asnumpy(X), sample_weight=wt)
sk_score = sk_kmeans.score(cp.asnumpy(X))

@@ -196,6 +218,7 @@ def test_kmeans_clusters_blobs(
n_clusters=nclusters,
random_state=random_state,
output_type="numpy",
n_init=1,
)

preds = cuml_kmeans.fit_predict(X)
@@ -327,6 +350,7 @@ def test_all_kmeans_params(
oversampling_factor=oversampling_factor,
max_samples_per_batch=max_samples_per_batch,
output_type="cupy",
n_init=1,
)

cuml_kmeans.fit_predict(X)
@@ -355,6 +379,7 @@ def test_score(nrows, ncols, nclusters, random_state):
n_clusters=nclusters,
random_state=random_state,
output_type="numpy",
n_init=1,
)

cuml_kmeans.fit(X)
2 changes: 1 addition & 1 deletion python/cuml/cuml/tests/test_metrics.py
@@ -286,7 +286,7 @@ def test_rand_index_score(name, nrows):
params = default_base.copy()
params.update(pat[1])

cuml_kmeans = cuml.KMeans(n_clusters=params["n_clusters"])
cuml_kmeans = cuml.KMeans(n_clusters=params["n_clusters"], n_init="auto")

X, y = pat[0]

5 changes: 4 additions & 1 deletion python/cuml/cuml/tests/test_pickle.py
@@ -299,7 +299,10 @@ def test_cluster_pickle(tmpdir, datatype, keys, data_size):
def create_mod():
nrows, ncols, n_info = data_size
X_train, y_train, X_test = make_dataset(datatype, nrows, ncols, n_info)
model = cluster_models[keys]()
if keys == "KMeans":
model = cluster_models[keys](n_init="auto")
else:
model = cluster_models[keys]()
model.fit(X_train)
result["cluster"] = model.predict(X_test)
return model, X_test