diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py
index 9cbaa5284d3a2..c8269fc46d13a 100644
--- a/sklearn/_loss/loss.py
+++ b/sklearn/_loss/loss.py
@@ -45,8 +45,8 @@
     LogLink,
     MultinomialLogit,
 )
+from sklearn.externals import array_api_extra as xpx
 from sklearn.utils import check_scalar
-from sklearn.utils.stats import _weighted_percentile


 # Note: The shape of raw_prediction for multiclass classifications are
@@ -588,7 +588,13 @@ def fit_intercept_only(self, y_true, sample_weight=None):
         if sample_weight is None:
             return np.median(y_true, axis=0)
         else:
-            return _weighted_percentile(y_true, sample_weight, 50)
+            return xpx.quantile(
+                y_true,
+                0.5,
+                axis=0,
+                weights=sample_weight,
+                method="averaged_inverted_cdf",
+            )


 class PinballLoss(BaseLoss):
@@ -646,12 +652,10 @@ def fit_intercept_only(self, y_true, sample_weight=None):
         This is the weighted median of the target, i.e. over the samples
         axis=0.
         """
-        if sample_weight is None:
-            return np.percentile(y_true, 100 * self.closs.quantile, axis=0)
-        else:
-            return _weighted_percentile(
-                y_true, sample_weight, 100 * self.closs.quantile
-            )
+        method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+        return xpx.quantile(
+            y_true, self.closs.quantile, axis=0, method=method, weights=sample_weight
+        )


 class HuberLoss(BaseLoss):
@@ -718,10 +722,15 @@ def fit_intercept_only(self, y_true, sample_weight=None):
         # not to the residual y_true - raw_prediction. An estimator like
         # HistGradientBoostingRegressor might then call it on the residual, e.g.
         # fit_intercept_only(y_true - raw_prediction).
-        if sample_weight is None:
-            median = np.percentile(y_true, 50, axis=0)
-        else:
-            median = _weighted_percentile(y_true, sample_weight, 50)
+
+        method = "linear" if sample_weight is None else "inverted_cdf"
+        # XXX: it would be better to use method "averaged_inverted_cdf"
+        # for the weighted case
+        # (otherwise passing all-ones weights is not equivalent to no weights),
+        # but this would break this test:
+        # ensemble/tests/test_gradient_boosting.py::test_huber_exact_backward_compat
+
+        median = xpx.quantile(y_true, 0.5, axis=0, method=method, weights=sample_weight)
         diff = y_true - median
         term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff))
         return median + np.average(term, weights=sample_weight)
diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index eaadc80cd503a..4ac7913deb958 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -1294,7 +1294,7 @@ def calibration_curve(

     if strategy == "quantile":  # Determine bin edges by distribution of data
         quantiles = np.linspace(0, 1, n_bins + 1)
-        bins = np.percentile(y_prob, quantiles * 100)
+        bins = np.quantile(y_prob, quantiles)
     elif strategy == "uniform":
         bins = np.linspace(0.0, 1.0, n_bins + 1)
     else:
diff --git a/sklearn/dummy.py b/sklearn/dummy.py
index 2eab0e53e2aa6..6eee159c8257c 100644
--- a/sklearn/dummy.py
+++ b/sklearn/dummy.py
@@ -16,11 +16,11 @@
     RegressorMixin,
     _fit_context,
 )
+from sklearn.externals import array_api_extra as xpx
 from sklearn.utils import check_random_state
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.utils.multiclass import class_distribution
 from sklearn.utils.random import _random_choice_csc
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_sample_weight,
     _num_samples,
@@ -581,10 +581,13 @@ def fit(self, X, y, sample_weight=None):
             if sample_weight is None:
                 self.constant_ = np.median(y, axis=0)
             else:
-                self.constant_ = [
-                    _weighted_percentile(y[:, k], sample_weight, percentile_rank=50.0)
-                    for k in range(self.n_outputs_)
-                ]
+                self.constant_ = xpx.quantile(
+                    y,
+                    0.5,
+                    axis=0,
+                    weights=sample_weight,
+                    method="averaged_inverted_cdf",
+                )

         elif self.strategy == "quantile":
             if self.quantile is None:
@@ -592,16 +595,10 @@
                     "When using `strategy='quantile', you have to specify the desired "
                     "quantile in the range [0, 1]."
                 )
-            percentile_rank = self.quantile * 100.0
-            if sample_weight is None:
-                self.constant_ = np.percentile(y, axis=0, q=percentile_rank)
-            else:
-                self.constant_ = [
-                    _weighted_percentile(
-                        y[:, k], sample_weight, percentile_rank=percentile_rank
-                    )
-                    for k in range(self.n_outputs_)
-                ]
+            method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+            self.constant_ = xpx.quantile(
+                y, float(self.quantile), axis=0, weights=sample_weight, method=method
+            )

         elif self.strategy == "constant":
             if self.constant is None:
diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py
index e64763123f270..d48a51d1cdf3a 100644
--- a/sklearn/ensemble/_gb.py
+++ b/sklearn/ensemble/_gb.py
@@ -47,6 +47,7 @@
     predict_stages,
 )
 from sklearn.exceptions import NotFittedError
+from sklearn.externals import array_api_extra as xpx
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import LabelEncoder
 from sklearn.tree import DecisionTreeRegressor
@@ -54,7 +55,6 @@
 from sklearn.utils import check_array, check_random_state, column_or_1d
 from sklearn.utils._param_validation import HasMethods, Interval, StrOptions
 from sklearn.utils.multiclass import check_classification_targets
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_sample_weight,
     check_is_fitted,
@@ -275,7 +275,11 @@ def set_huber_delta(loss, y_true, raw_prediction, sample_weight=None):
     """Calculate and set self.closs.delta based on self.quantile."""
     abserr = np.abs(y_true - raw_prediction.squeeze())
     # sample_weight is always a ndarray, never None.
-    delta = _weighted_percentile(abserr, sample_weight, 100 * loss.quantile)
+    delta = xpx.quantile(
+        abserr, loss.quantile, axis=0, weights=sample_weight, method="inverted_cdf"
+    )
+    # XXX: it would probably be better to use method "averaged_inverted_cdf";
+    # see the explanation of why we can't in HuberLoss.fit_intercept_only.
     loss.closs.delta = float(delta)
diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py
index 955014484fc5d..32ae59bc408e5 100644
--- a/sklearn/metrics/_regression.py
+++ b/sklearn/metrics/_regression.py
@@ -23,10 +23,10 @@
     get_namespace,
     get_namespace_and_device,
     size,
+    xpx,
 )
 from sklearn.utils._array_api import _xlogy as xlogy
 from sklearn.utils._param_validation import Interval, StrOptions, validate_params
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_sample_weight,
     _num_samples,
@@ -923,8 +923,12 @@ def median_absolute_error(
     if sample_weight is None:
         output_errors = _median(xp.abs(y_pred - y_true), axis=0)
     else:
-        output_errors = _weighted_percentile(
-            xp.abs(y_pred - y_true), sample_weight=sample_weight, average=True
+        output_errors = xpx.quantile(
+            xp.abs(y_pred - y_true),
+            0.5,
+            axis=0,
+            weights=sample_weight,
+            method="averaged_inverted_cdf",
         )
     if isinstance(multioutput, str):
         if multioutput == "raw_values":
@@ -1820,17 +1824,11 @@ def d2_pinball_score(
         multioutput="raw_values",
     )

-    if sample_weight is None:
-        y_quantile = np.tile(
-            np.percentile(y_true, q=alpha * 100, axis=0), (len(y_true), 1)
-        )
-    else:
-        y_quantile = np.tile(
-            _weighted_percentile(
-                y_true, sample_weight=sample_weight, percentile_rank=alpha * 100
-            ),
-            (len(y_true), 1),
-        )
+    method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+    y_quantile = np.tile(
+        xpx.quantile(y_true, alpha, axis=0, weights=sample_weight, method=method),
+        (len(y_true), 1),
+    )

     denominator = mean_pinball_loss(
         y_true,
diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py
index 5ab6fdd4b6576..a3cd99509cfe3 100644
--- a/sklearn/preprocessing/_discretization.py
+++ b/sklearn/preprocessing/_discretization.py
@@ -8,10 +8,10 @@
 import numpy as np

 from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
+from sklearn.externals import array_api_extra as xpx
 from sklearn.preprocessing._encoders import OneHotEncoder
 from sklearn.utils import resample
 from sklearn.utils._param_validation import Interval, Options, StrOptions
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_feature_names_in,
     _check_sample_weight,
@@ -350,39 +350,13 @@ def fit(self, X, y=None, sample_weight=None):
                 bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1)

             elif self.strategy == "quantile":
-                percentile_levels = np.linspace(0, 100, n_bins[jj] + 1)
-
-                # method="linear" is the implicit default for any numpy
-                # version. So we keep it version independent in that case by
-                # using an empty param dict.
-                percentile_kwargs = {}
-                if quantile_method != "linear" and sample_weight is None:
-                    percentile_kwargs["method"] = quantile_method
-
-                if sample_weight is None:
-                    bin_edges[jj] = np.asarray(
-                        np.percentile(column, percentile_levels, **percentile_kwargs),
-                        dtype=np.float64,
-                    )
-                else:
-                    # TODO: make _weighted_percentile accept an array of
-                    # quantiles instead of calling it multiple times and
-                    # sorting the column multiple times as a result.
-                    average = (
-                        True if quantile_method == "averaged_inverted_cdf" else False
-                    )
-                    bin_edges[jj] = np.asarray(
-                        [
-                            _weighted_percentile(
-                                column,
-                                sample_weight,
-                                percentile_rank=p,
-                                average=average,
-                            )
-                            for p in percentile_levels
-                        ],
-                        dtype=np.float64,
-                    )
+                quantile_levels = np.linspace(0, 1, n_bins[jj] + 1)
+                bin_edges[jj] = xpx.quantile(
+                    column,
+                    quantile_levels,
+                    weights=sample_weight,
+                    method=quantile_method,
+                )

             elif self.strategy == "kmeans":
                 from sklearn.cluster import KMeans  # fixes import loops
diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py
index acc2aa1138b68..6eee6e5d00bb9 100644
--- a/sklearn/preprocessing/_polynomial.py
+++ b/sklearn/preprocessing/_polynomial.py
@@ -16,6 +16,7 @@
 from scipy.special import comb

 from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
+from sklearn.externals import array_api_extra as xpx
 from sklearn.preprocessing._csr_polynomial_expansion import (
     _calc_expanded_nnz,
     _calc_total_nnz,
@@ -30,7 +31,6 @@
 from sklearn.utils._mask import _get_mask
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.utils.fixes import parse_version, sp_version
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     FLOAT_DTYPES,
     _check_feature_names_in,
@@ -784,18 +784,17 @@ def _get_base_knot_positions(X, n_knots=10, knots="uniform", sample_weight=None)
             Knot positions (points) of base interval.
         """
         if knots == "quantile":
-            percentile_ranks = 100 * np.linspace(
-                start=0, stop=1, num=n_knots, dtype=np.float64
-            )
+            quantile_ranks = np.linspace(start=0, stop=1, num=n_knots, dtype=np.float64)

             if sample_weight is None:
-                knots = np.nanpercentile(X, percentile_ranks, axis=0)
+                knots = np.nanquantile(X, quantile_ranks, axis=0)
             else:
-                knots = np.array(
-                    [
-                        _weighted_percentile(X, sample_weight, percentile_rank)
-                        for percentile_rank in percentile_ranks
-                    ]
+                knots = xpx.quantile(
+                    X,
+                    quantile_ranks,
+                    axis=0,
+                    weights=sample_weight,
+                    method="averaged_inverted_cdf",
                 )

         else:
diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py
index 61f1803b7a24f..995cf468273ee 100644
--- a/sklearn/tests/test_dummy.py
+++ b/sklearn/tests/test_dummy.py
@@ -7,13 +7,13 @@
 from sklearn.base import clone
 from sklearn.dummy import DummyClassifier, DummyRegressor
 from sklearn.exceptions import NotFittedError
+from sklearn.externals import array_api_extra as xpx
 from sklearn.utils._testing import (
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
 )
 from sklearn.utils.fixes import CSC_CONTAINERS
-from sklearn.utils.stats import _weighted_percentile


 def _check_predict_proba(clf, X, y):
@@ -631,11 +631,12 @@ def test_dummy_regressor_sample_weight(global_random_seed, n_samples=10):
     est = DummyRegressor(strategy="mean").fit(X, y, sample_weight)
     assert est.constant_ == np.average(y, weights=sample_weight)

+    method = "averaged_inverted_cdf"
     est = DummyRegressor(strategy="median").fit(X, y, sample_weight)
-    assert est.constant_ == _weighted_percentile(y, sample_weight, 50.0)
+    assert est.constant_ == xpx.quantile(y, 0.5, weights=sample_weight, method=method)

     est = DummyRegressor(strategy="quantile", quantile=0.95).fit(X, y, sample_weight)
-    assert est.constant_ == _weighted_percentile(y, sample_weight, 95.0)
+    assert est.constant_ == xpx.quantile(y, 0.95, weights=sample_weight, method=method)


 def test_dummy_regressor_on_3D_array():
diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py
index 8be143e9c9e5b..6c6e62a9e01fb 100644
--- a/sklearn/utils/stats.py
+++ b/sklearn/utils/stats.py
@@ -1,10 +1,7 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause

-from sklearn.utils._array_api import (
-    _find_matching_floating_dtype,
-    get_namespace_and_device,
-)
+from sklearn.externals import array_api_extra as xpx


 def _weighted_percentile(
@@ -62,9 +59,9 @@ def _weighted_percentile(
         Weights for each value in `array`. Must be same shape as `array` or
         of shape `(array.shape[0],)`.

-    percentile_rank: int or float, default=50
-        The probability level of the percentile to compute, in percent. Must be between
-        0 and 100.
+    percentile_rank: scalar or 1D array, default=50
+        The probability level(s) of the percentile(s) to compute, in percent. Must be
+        between 0 and 100. If a 1D array, computes multiple percentiles.

     average : bool, default=False
         If `True`, uses the "averaged_inverted_cdf" quantile method, otherwise
@@ -79,112 +76,22 @@ def _weighted_percentile(

     Returns
     -------
-    percentile : scalar or 0D array if `array` 1D (or 0D), array if `array` 2D
-        Weighted percentile at the requested probability level.
+    percentile : scalar, 1D array, or 2D array
+        Weighted percentile at the requested probability level(s).
+        If `array` is 1D and `percentile_rank` is scalar, returns a scalar.
+        If `array` is 2D and `percentile_rank` is scalar, returns a 1D array
+        of shape `(array.shape[1],)`.
+        If `array` is 1D and `percentile_rank` is 1D, returns a 1D array
+        of shape `(percentile_rank.shape[0],)`.
+        If `array` is 2D and `percentile_rank` is 1D, returns a 2D array
+        of shape `(percentile_rank.shape[0], array.shape[1])`.
     """
-    xp, _, device = get_namespace_and_device(array)
-    # `sample_weight` should follow `array` for dtypes
-    floating_dtype = _find_matching_floating_dtype(array, xp=xp)
-    array = xp.asarray(array, dtype=floating_dtype, device=device)
-    sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device)
-
-    n_dim = array.ndim
-    if n_dim == 0:
-        return array
-    if array.ndim == 1:
-        array = xp.reshape(array, (-1, 1))
-    # When sample_weight 1D, repeat for each array.shape[1]
-    if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]:
-        sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T
-    # Sort `array` and `sample_weight` along axis=0:
-    sorted_idx = xp.argsort(array, axis=0, stable=False)
-    sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0)
-
-    # Set NaN values in `sample_weight` to 0. Only perform this operation if NaN
-    # values present to avoid temporary allocations of size `(n_samples, n_features)`.
-    n_features = array.shape[1]
-    largest_value_per_column = array[
-        sorted_idx[-1, ...], xp.arange(n_features, device=device)
-    ]
-    # NaN values get sorted to end (largest value)
-    if xp.any(xp.isnan(largest_value_per_column)):
-        sorted_nan_mask = xp.take_along_axis(xp.isnan(array), sorted_idx, axis=0)
-        sorted_weights[sorted_nan_mask] = 0
-
-    # Compute the weighted cumulative distribution function (CDF) based on
-    # `sample_weight` and scale `percentile_rank` along it.
-    #
-    # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to
-    # ensure that the result is of shape `(n_features, n_samples)` so
-    # `xp.searchsorted` calls take contiguous inputs as a result (for
-    # performance reasons).
-    weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1)
-    adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1]
-
-    # Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528)
-    mask = adjusted_percentile_rank == 0
-    adjusted_percentile_rank[mask] = xp.nextafter(
-        adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1
+    method = "averaged_inverted_cdf" if average else "inverted_cdf"
+    return xpx.quantile(
+        array,
+        percentile_rank / 100,
+        axis=0,
+        method=method,
+        weights=sample_weight,
+        nan_policy="omit",
     )
-
-    # For each feature with index j, find sample index i of the scalar value
-    # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that:
-    # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i].
-    # Note `searchsorted` defaults to equality on the right, whereas Hyndman and Fan
-    # reference equation has equality on the left.
-    percentile_indices = xp.stack(
-        [
-            xp.searchsorted(
-                weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx]
-            )
-            for feature_idx in range(weight_cdf.shape[0])
-        ],
-    )
-    # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating
-    # point error (see #11813)
-    max_idx = sorted_idx.shape[0] - 1
-    percentile_indices = xp.clip(percentile_indices, 0, max_idx)
-
-    col_indices = xp.arange(array.shape[1], device=device)
-    percentile_in_sorted = sorted_idx[percentile_indices, col_indices]
-
-    if average:
-        # From Hyndman and Fan (1996), `fraction_above` is `g`
-        fraction_above = (
-            weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank
-        )
-        is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps
-        percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx)
-        percentile_plus_one_in_sorted = sorted_idx[
-            percentile_plus_one_indices, col_indices
-        ]
-        # Handle case when next index ('plus one') has sample weight of 0
-        zero_weight_cols = col_indices[
-            sample_weight[percentile_plus_one_in_sorted, col_indices] == 0
-        ]
-        for col_idx in zero_weight_cols:
-            cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]]
-            # Search for next index where `weighted_cdf` is greater
-            next_index = xp.searchsorted(
-                weight_cdf[col_idx, ...], cdf_val, side="right"
-            )
-            # Handle case where there are trailing 0 sample weight samples
-            # and `percentile_indices` is already max index
-            if next_index >= max_idx:
-                # use original `percentile_indices` again
-                next_index = percentile_indices[col_idx]
-
-            percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx]
-
-        result = xp.where(
-            is_fraction_above,
-            array[percentile_in_sorted, col_indices],
-            (
-                array[percentile_in_sorted, col_indices]
-                + array[percentile_plus_one_in_sorted, col_indices]
-            )
-            / 2,
-        )
-    else:
-        result = array[percentile_in_sorted, col_indices]
-
-    return result[0] if n_dim == 1 else result
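
Reviewer note, not part of the patch: the XXX comments in HuberLoss.fit_intercept_only and set_huber_delta hinge on the difference between the "inverted_cdf" and "averaged_inverted_cdf" quantile methods. The sketch below illustrates that difference with plain NumPy; weighted_quantile is a hypothetical helper written only for this note (it ignores the zero-weight and NaN edge cases handled by the vendored implementation) and is not the actual xpx.quantile.

import numpy as np


def weighted_quantile(a, q, weights, method="inverted_cdf"):
    # Weighted Hyndman & Fan style quantile along a 1D array: locate the
    # requested probability mass on the weighted CDF of the sorted values.
    order = np.argsort(a)
    a_sorted = a[order]
    cdf = np.cumsum(weights[order].astype(float))
    target = q * cdf[-1]
    # Smallest index i such that cdf[i] >= target.
    i = min(np.searchsorted(cdf, target, side="left"), a_sorted.size - 1)
    if method == "averaged_inverted_cdf" and np.isclose(cdf[i], target):
        # The target falls exactly on a CDF step: average with the next value.
        j = min(i + 1, a_sorted.size - 1)
        return 0.5 * (a_sorted[i] + a_sorted[j])
    return a_sorted[i]


rng = np.random.default_rng(0)
y = rng.normal(size=10)  # even sample size, so the two methods differ
ones = np.ones_like(y)

print(np.median(y))                                              # interpolated median
print(weighted_quantile(y, 0.5, ones, "inverted_cdf"))           # lower middle value
print(weighted_quantile(y, 0.5, ones, "averaged_inverted_cdf"))  # equals np.median(y)

With all-ones weights and an even sample size, "inverted_cdf" picks the lower of the two middle order statistics while "averaged_inverted_cdf" averages them and so reproduces np.median. That asymmetry is why the Huber code paths keep "inverted_cdf" (exact backward compatibility with test_huber_exact_backward_compat), while the other call sites adopt "averaged_inverted_cdf" so that, at the median, unit weights give the same result as no weights.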