Skip to content

Commit

Permalink
add _check_sample_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Nov 22, 2024
1 parent edf0350 commit 4efad2c
Showing 1 changed file with 141 additions and 1 deletion.
142 changes: 141 additions & 1 deletion sklearnex/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,144 @@
# limitations under the License.
# ===============================================================================

from daal4py.sklearn.utils.validation import _assert_all_finite
import warnings

import numbers
import numpy as np
import scipy.sparse as sp
from sklearn.utils.validation import _num_samples, check_array, check_non_negative, _assert_all_finite as _sklearn_assert_all_finite

from daal4py.sklearn._utils import sklearn_check_version
from onedal.utils._array_api import _is_numpy_namespace, _get_sycl_namespace
from onedal.utils.validation import _assert_all_finite as _onedal_assert_all_finite

from ._array_api import get_namespace

# sklearn 1.6 exposes `validate_data` as a free function in
# sklearn.utils.validation and renames the finiteness-check keyword to
# `ensure_all_finite`; older versions only have the BaseEstimator method
# `_validate_data` and spell the keyword `force_all_finite`. Resolve both
# here once so the rest of the module stays version-agnostic.
if sklearn_check_version("1.6"):
    from sklearn.utils.validation import validate_data as _sklearn_validate_data

    _finite_keyword = "ensure_all_finite"

else:
    from sklearn.base import BaseEstimator

    # bound later with an estimator as the first positional argument,
    # matching the >=1.6 free-function call signature
    _sklearn_validate_data = BaseEstimator._validate_data
    _finite_keyword = "force_all_finite"


def _is_contiguous(X):
# array_api does not have a `strides` or `flags` attribute for testing memory
# order. When dlpack support is brought in for oneDAL, the dlpack python capsule
# can then be inspected for strides and this must be updated. _is_contiguous is
# therefore conservative in verifying attributes and does not support array_api.
# This will block onedal_assert_all_finite from being used for array_api inputs.
return hasattr(X, "flags") and (X.flags["C_CONTIGUOUS"] or X.flags["F_CONTIGUOUS"])


def _sklearnex_assert_all_finite(
    X,
    *,
    allow_nan=False,
    input_name="",
):
    """Run a finiteness check on ``X``, dispatching to oneDAL when profitable.

    oneDAL is used only for sufficiently large, contiguous float32/float64
    arrays; everything else falls back to sklearn's implementation. The size
    threshold is an initial match to daal4py for performance reasons and can
    be optimized later.
    """
    xp, _ = get_namespace(X)
    use_onedal = (
        X.size >= 32768
        and X.dtype in [xp.float32, xp.float64]
        and _is_contiguous(X)
    )
    checker = _onedal_assert_all_finite if use_onedal else _sklearn_assert_all_finite
    checker(X, allow_nan=allow_nan, input_name=input_name)


def assert_all_finite(
    X,
    *,
    allow_nan=False,
    input_name="",
):
    """Raise if ``X`` contains non-finite values (NaN allowed iff ``allow_nan``).

    Sparse inputs are unwrapped to their ``data`` buffer before checking,
    since only the stored values need to be validated.
    """
    data = X.data if sp.issparse(X) else X
    _sklearnex_assert_all_finite(data, allow_nan=allow_nan, input_name=input_name)


def validate_data(
    _estimator,
    /,
    X="no_validation",
    y="no_validation",
    **kwargs,
):
    """Validate ``X``/``y`` via sklearn's ``validate_data``, but run the
    finiteness check locally (possibly accelerated via oneDAL) instead of
    inside sklearn.

    Accepts the same keyword arguments as sklearn's ``validate_data``;
    the finite check is controlled by ``ensure_all_finite`` (True, False,
    or "allow-nan"), with the legacy ``force_all_finite`` spelling honored
    on sklearn < 1.6. Returns whatever ``_sklearn_validate_data`` returns
    (a single array or a tuple, depending on which of X/y are validated).
    """
    # Force the finite check to not occur in sklearn; it is done here instead.
    # `ensure_all_finite` is the most up-to-date keyword name in sklearn;
    # `_finite_keyword` provides backward compatibility for `force_all_finite`.
    # Honor whichever spelling the caller used rather than silently
    # discarding a legacy `force_all_finite` kwarg (default True).
    ensure_all_finite = kwargs.pop("ensure_all_finite", True)
    if _finite_keyword != "ensure_all_finite" and _finite_keyword in kwargs:
        ensure_all_finite = kwargs.pop(_finite_keyword)
    kwargs[_finite_keyword] = False

    out = _sklearn_validate_data(
        _estimator,
        X=X,
        y=y,
        **kwargs,
    )
    if ensure_all_finite:
        # run the local finite check on whichever of X/y were validated;
        # `out` is a tuple only when both were.
        allow_nan = ensure_all_finite == "allow-nan"
        arg = iter(out if isinstance(out, tuple) else (out,))
        if not isinstance(X, str) or X != "no_validation":
            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
        if not (y is None or isinstance(y, str) and y == "no_validation"):
            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")
    return out


def _check_sample_weight(
    sample_weight, X, dtype=None, copy=False, only_non_negative=False
):
    """Validate ``sample_weight``, producing a 1-D array of length
    ``_num_samples(X)``.

    Mirrors sklearn's ``_check_sample_weight`` but routes the finiteness
    check through this module's ``assert_all_finite`` so it may be
    accelerated. ``None`` yields unit weights, a scalar is broadcast, and
    anything else is validated with ``check_array``. Raises ``ValueError``
    on wrong dimensionality or length, and checks non-negativity when
    ``only_non_negative`` is set.
    """
    n_samples = _num_samples(X)
    xp, _ = get_namespace(X)

    # clamp any unsupported requested dtype to float64
    if dtype is not None and dtype not in [xp.float32, xp.float64]:
        dtype = xp.float64

    if sample_weight is None:
        weights = xp.ones(n_samples, dtype=dtype)
    elif isinstance(sample_weight, numbers.Number):
        weights = xp.full(n_samples, sample_weight, dtype=dtype)
    else:
        # build the kwargs as a dict so the version-dependent
        # `_finite_keyword` can be injected without direct
        # sklearn_check_version maintenance
        check_params = {
            "accept_sparse": False,
            "ensure_2d": False,
            "dtype": dtype if dtype is not None else [xp.float64, xp.float32],
            "order": "C",
            "copy": copy,
            "input_name": "sample_weight",
        }
        check_params[_finite_keyword] = False

        weights = check_array(sample_weight, **check_params)
        assert_all_finite(weights, input_name="sample_weight")

        if weights.ndim != 1:
            raise ValueError("Sample weights must be 1D array or scalar")

        if weights.shape != (n_samples,):
            raise ValueError(
                "sample_weight.shape == {}, expected {}!".format(
                    weights.shape, (n_samples,)
                )
            )

    if only_non_negative:
        check_non_negative(weights, "`sample_weight`")

    return weights

0 comments on commit 4efad2c

Please sign in to comment.