Make KernelRidge inherit from UniversalBase #6327

Merged: 4 commits, Feb 20, 2025
2 changes: 1 addition & 1 deletion python/cuml/cuml/internals/base.pyx
@@ -619,7 +619,7 @@ class UniversalBase(Base):
     _experimental_dispatching = False

     def import_cpu_model(self):
-        # skip the CPU estimator has been imported already
+        # skip if the CPU estimator has been imported already
        if hasattr(self, '_cpu_model_class'):
            return
        if hasattr(self, '_cpu_estimator_import_path'):
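For context, the guard above belongs to UniversalBase's lazy CPU-model import: the CPU estimator class is imported once and cached on the instance. Below is a minimal sketch of that pattern using the attribute names visible in this PR (_cpu_estimator_import_path, _cpu_model_class); the import-resolution details are an assumption for illustration, not a copy of cuml's implementation.

    import importlib

    class UniversalBaseSketch:
        # Illustrative default; real estimators set this per class, as the
        # KernelRidge diff below does.
        _cpu_estimator_import_path = "sklearn.kernel_ridge.KernelRidge"

        def import_cpu_model(self):
            # skip if the CPU estimator has been imported already
            if hasattr(self, "_cpu_model_class"):
                return
            if hasattr(self, "_cpu_estimator_import_path"):
                # Split "pkg.module.Class" into module path and class name,
                # then import and cache the CPU estimator class.
                module_path, _, class_name = self._cpu_estimator_import_path.rpartition(".")
                module = importlib.import_module(module_path)
                self._cpu_model_class = getattr(module, class_name)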
17 changes: 14 additions & 3 deletions python/cuml/cuml/kernel_ridge/kernel_ridge.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2022-2024, NVIDIA CORPORATION.
+# Copyright (c) 2022-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,11 @@ from cuml.internals.safe_imports import gpu_only_import_from
 from cuml.internals.safe_imports import gpu_only_import
 from cupyx import lapack, geterr, seterr
 from cuml.common.array_descriptor import CumlArrayDescriptor
-from cuml.internals.base import Base
+from cuml.internals.base import UniversalBase
+from cuml.internals.api_decorators import (
+    device_interop_preparation,
+    enable_device_interop,
+)
 from cuml.internals.mixins import RegressorMixin
 from cuml.common.doc_utils import generate_docstring
 from cuml.common import input_to_cuml_array
@@ -101,7 +105,7 @@ def _solve_cholesky_kernel(K, y, alpha, sample_weight=None):
     return dual_coefs.T


-class KernelRidge(Base, RegressorMixin):
+class KernelRidge(UniversalBase, RegressorMixin):
     """
     Kernel ridge regression (KRR) performs l2 regularised ridge regression
     using the kernel trick. The kernel trick allows the estimator to learn a
@@ -203,7 +207,9 @@ class KernelRidge(Base, RegressorMixin):
     """

     dual_coef_ = CumlArrayDescriptor()
+    _cpu_estimator_import_path = "sklearn.kernel_ridge.KernelRidge"

+    @device_interop_preparation
     def __init__(
         self,
         *,
@@ -226,6 +232,9 @@ class KernelRidge(Base, RegressorMixin):
         self.coef0 = coef0
         self.kernel_params = kernel_params

+    def get_attr_names(self):
+        return ['dual_coef_', 'X_fit_']
+
     @classmethod
     def _get_param_names(cls):
         return super()._get_param_names() + [
@@ -247,6 +256,7 @@ class KernelRidge(Base, RegressorMixin):
             filter_params=True, **params)

     @generate_docstring()
+    @enable_device_interop
     def fit(self, X, y, sample_weight=None,
             convert_dtype=True) -> "KernelRidge":

Contributor:
Also, a small detail, but it is always nice to use the self.n_features_in_ attribute instead of a local variable n_rows when ingesting data. That is one more step toward replicating sklearn's behavior.

Member Author:
You mean we should set self.n_features_in_ instead of using a local variable? Can do that. One question: n_features_in_ is the number of features, so shouldn't it be set to n_cols instead?

Contributor:
You're right, I meant the number of columns. (A sketch of this suggestion follows the diff below.)

@@ -283,6 +293,7 @@ class KernelRidge(Base, RegressorMixin):
         self.X_fit_ = X_m
         return self

+    @enable_device_interop
     def predict(self, X):
         """
         Predict using the kernel ridge model.
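As noted in the review thread above, the suggestion is to record the number of input columns on the estimator as n_features_in_ (sklearn's convention) rather than keeping it only in a local variable. A minimal sketch; the class and variable names here are illustrative, not the merged code.

    import numpy as np

    class FitSketch:
        def fit(self, X, y):
            X = np.asarray(X)
            n_rows, n_cols = X.shape
            # Store the column count on the estimator so downstream code and
            # sklearn-compatible checks can read n_features_in_.
            self.n_features_in_ = n_cols
            return self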
@@ -14,6 +14,7 @@

 import pytest
 import numpy as np
+import cupy as cp
 from sklearn.datasets import make_classification, make_regression, make_blobs
 from sklearn.linear_model import (
     LinearRegression,
@@ -26,6 +27,7 @@
 from sklearn.decomposition import PCA, TruncatedSVD
 from sklearn.kernel_ridge import KernelRidge
 from sklearn.manifold import TSNE
+from sklearn.model_selection import GridSearchCV
 from sklearn.neighbors import (
     NearestNeighbors,
     KNeighborsClassifier,
@@ -169,3 +171,25 @@ def test_proxy_facade():
         proxy_value = getattr(PCA, attr)

         assert original_value == proxy_value
+
+
+def test_kernel_ridge():
+    rng = np.random.RandomState(42)
+
+    X = 5 * rng.rand(10000, 1)
+    y = np.sin(X).ravel()
+
+    kr = GridSearchCV(
+        KernelRidge(kernel="rbf", gamma=0.1),
+        param_grid={
+            "alpha": [1e0, 0.1, 1e-2, 1e-3],
+            "gamma": np.logspace(-2, 2, 5),
+        },
+    )
+    kr.fit(X, y)
+
+    y_pred = kr.predict(X)
+
+    assert not isinstance(
+        y_pred, cp.ndarray
+    ), f"y_pred should be a np.ndarray, but is a {type(y_pred)}"
19 changes: 19 additions & 0 deletions python/cuml/cuml/tests/test_device_selection.py
@@ -38,6 +38,7 @@
 from cuml.cluster import KMeans
 from cuml.cluster import DBSCAN
 from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
+from cuml.kernel_ridge import KernelRidge
 from cuml.common.device_selection import DeviceType, using_device_type
 from cuml.testing.utils import assert_dbscan_equal
 from hdbscan import HDBSCAN as refHDBSCAN
@@ -49,6 +50,7 @@
 from sklearn.linear_model import LinearRegression as skLinearRegression
 from sklearn.decomposition import PCA as skPCA
 from sklearn.decomposition import TruncatedSVD as skTruncatedSVD
+from sklearn.kernel_ridge import KernelRidge as skKernelRidge
 from sklearn.cluster import KMeans as skKMeans
 from sklearn.cluster import DBSCAN as skDBSCAN
 from sklearn.ensemble import RandomForestClassifier as skRFC
@@ -832,6 +834,23 @@ def test_elasticnet_methods(train_device, infer_device):
     assert ref_output - tol <= output <= ref_output + tol


+@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
+@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
+def test_kernelridge_methods(train_device, infer_device):
+    ref_model = skKernelRidge()
+    ref_model.fit(X_train_reg, y_train_reg)
+    ref_output = ref_model.score(X_test_reg, y_test_reg)
+
+    model = KernelRidge()
+    with using_device_type(train_device):
+        model.fit(X_train_reg, y_train_reg)
+    with using_device_type(infer_device):
+        output = model.score(X_test_reg, y_test_reg)
+
+    tol = 0.01
+    assert ref_output - tol <= output <= ref_output + tol
+
+
 @pytest.mark.parametrize("train_device", ["cpu", "gpu"])
 @pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
 def test_ridge_methods(train_device, infer_device):
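Finally, a sketch of the workflow the new test_kernelridge_methods exercises: with KernelRidge now inheriting from UniversalBase, training and inference can run on different devices via using_device_type. This assumes a cuml installation with GPU support; the data here is illustrative.

    import numpy as np
    from cuml.kernel_ridge import KernelRidge
    from cuml.common.device_selection import using_device_type

    rng = np.random.RandomState(0)
    X = rng.rand(100, 3)
    y = X @ np.array([1.0, 2.0, 3.0])

    model = KernelRidge()
    with using_device_type("cpu"):  # train with the CPU (sklearn) implementation
        model.fit(X, y)
    with using_device_type("gpu"):  # infer with the GPU (cuml) implementation
        score = model.score(X, y)
    print(score)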