Skip to content

Commit

Permalink
[fix] correct issue in aliased _onedal_cpu_supported and _onedal_gpu_…
Browse files Browse the repository at this point in the history
…supported in fit_check_before_support_check (#2124)

* Update test_common.py

* Update incremental_linear.py

* Update incremental_covariance.py

* Update k_means.py

* Update linear.py

* Update incremental_linear.py
  • Loading branch information
icfaust authored Oct 23, 2024
1 parent 2029a74 commit f350c0d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 39 deletions.
9 changes: 3 additions & 6 deletions sklearnex/cluster/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _onedal_predict_supported(self, method_name, *data):
@wrap_output_data
def predict(self, X):
self._validate_params()

check_is_fitted(self)
return dispatch(
self,
"predict",
Expand Down Expand Up @@ -280,7 +280,7 @@ def predict(
"will be removed in 1.5.",
FutureWarning,
)

check_is_fitted(self)
return dispatch(
self,
"predict",
Expand All @@ -293,8 +293,6 @@ def predict(
)

def _onedal_predict(self, X, sample_weight=None, queue=None):
check_is_fitted(self)

X = validate_data(
self,
X,
Expand Down Expand Up @@ -334,6 +332,7 @@ def transform(self, X):

@wrap_output_data
def score(self, X, y=None, sample_weight=None):
check_is_fitted(self)
return dispatch(
self,
"score",
Expand All @@ -347,8 +346,6 @@ def score(self, X, y=None, sample_weight=None):
)

def _onedal_score(self, X, y=None, sample_weight=None, queue=None):
check_is_fitted(self)

X = validate_data(
self,
X,
Expand Down
3 changes: 2 additions & 1 deletion sklearnex/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sklearn.covariance import EmpiricalCovariance as _sklearn_EmpiricalCovariance
from sklearn.covariance import log_likelihood
from sklearn.utils import check_array, gen_batches
from sklearn.utils.validation import _num_features
from sklearn.utils.validation import _num_features, check_is_fitted

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
Expand Down Expand Up @@ -226,6 +226,7 @@ def _onedal_partial_fit(self, X, queue=None, check_input=True):
def score(self, X_test, y=None):
xp, _ = get_namespace(X_test)

check_is_fitted(self)
location = self.location_
if sklearn_check_version("1.0"):
X = validate_data(
Expand Down
18 changes: 3 additions & 15 deletions sklearnex/linear_model/incremental_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import numpy as np
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.exceptions import NotFittedError
from sklearn.metrics import r2_score
from sklearn.utils import check_array, gen_batches
from sklearn.utils.validation import check_is_fitted

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
Expand Down Expand Up @@ -414,13 +414,7 @@ def predict(self, X, y=None):
C : array, shape (n_samples, n_targets)
Returns predicted values.
"""
if not hasattr(self, "coef_"):
msg = (
"This %(name)s instance is not fitted yet. Call 'fit' or 'partial_fit' "
"with appropriate arguments before using this estimator."
)
raise NotFittedError(msg % {"name": self.__class__.__name__})

check_is_fitted(self)
return dispatch(
self,
"predict",
Expand Down Expand Up @@ -472,13 +466,7 @@ def score(self, X, y, sample_weight=None):
regressors (except for
:class:`~sklearn.multioutput.MultiOutputRegressor`).
"""
if not hasattr(self, "coef_"):
msg = (
"This %(name)s instance is not fitted yet. Call 'fit' or 'partial_fit' "
"with appropriate arguments before using this estimator."
)
raise NotFittedError(msg % {"name": self.__class__.__name__})

check_is_fitted(self)
return dispatch(
self,
"score",
Expand Down
13 changes: 3 additions & 10 deletions sklearnex/linear_model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from abc import ABC

import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LinearRegression as _sklearn_LinearRegression
from sklearn.metrics import r2_score
from sklearn.utils.validation import check_array
Expand All @@ -33,7 +32,7 @@
from sklearn.linear_model._base import _deprecate_normalize

from scipy.sparse import issparse
from sklearn.utils.validation import check_X_y
from sklearn.utils.validation import check_is_fitted, check_X_y

from onedal.common.hyperparameters import get_hyperparameters
from onedal.linear_model import LinearRegression as onedal_LinearRegression
Expand Down Expand Up @@ -111,14 +110,7 @@ def fit(self, X, y, sample_weight=None):

@wrap_output_data
def predict(self, X):

if not hasattr(self, "coef_"):
msg = (
"This %(name)s instance is not fitted yet. Call 'fit' with "
"appropriate arguments before using this estimator."
)
raise NotFittedError(msg % {"name": self.__class__.__name__})

check_is_fitted(self)
return dispatch(
self,
"predict",
Expand All @@ -131,6 +123,7 @@ def predict(self, X):

@wrap_output_data
def score(self, X, y, sample_weight=None):
check_is_fitted(self)
return dispatch(
self,
"score",
Expand Down
9 changes: 2 additions & 7 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,9 @@ def runtime_property_check(text, estimator, method):

def fit_check_before_support_check(text, estimator, method):
if "fit" not in method:
if "_onedal_cpu_supported" in text["funcs"]:
onedal_support = "_onedal_cpu_supported"
elif "_onedal_gpu_supported" in text["funcs"]:
onedal_support = "_onedal_gpu_supported"
else:
if "dispatch" not in text["funcs"]:
pytest.skip(f"onedal dispatching not used in {estimator}.{method}")
# get location of _onedal_*_supported
idx = len(text["funcs"]) - 1 - text["funcs"][::-1].index(onedal_support)
idx = len(text["funcs"]) - 1 - text["funcs"][::-1].index("dispatch")
validfuncs = text["funcs"][:idx]
assert (
"check_is_fitted" in validfuncs
Expand Down

0 comments on commit f350c0d

Please sign in to comment.