Draft
Changes from all commits (31 commits):
b0f8bdf  implem done; clean-up & comments todo (cakedev0, Sep 28, 2025)
2140c82  conform to array-API (cakedev0, Sep 28, 2025)
84c0240  cleanup (cakedev0, Sep 28, 2025)
7822673  comments; docstring; cleanups (cakedev0, Sep 28, 2025)
c82c75f  swap functions order for easier diff (cakedev0, Sep 28, 2025)
f6f877c  use new signature where useful (cakedev0, Sep 28, 2025)
cad8614  update docstring for new signature (cakedev0, Sep 28, 2025)
7f5d47f  adapt fully to array-API (cakedev0, Sep 28, 2025)
649b271  fix array API compat (cakedev0, Sep 28, 2025)
cda231c  another array API fix: TypeError: object of type 'Array' has no len() (cakedev0, Sep 28, 2025)
9c4a5ad  more array-API fixes; tested locally; but I cant test everything I do… (cakedev0, Sep 28, 2025)
4ea221e  tmp: old for benchmark (cakedev0, Sep 29, 2025)
837c287  Merge branch 'main' into optim-weighted-percentile (ogrisel, Oct 2, 2025)
6194d15  Merge branch 'optim-weighted-percentile' of github.com:cakedev0/sciki… (cakedev0, Oct 9, 2025)
f14eb58  Merge remote-tracking branch 'upstream/main' into optim-weighted-perc… (cakedev0, Oct 9, 2025)
f96d334  remove comment about floating dtype (cakedev0, Oct 9, 2025)
013725d  Fix device error (cakedev0, Oct 9, 2025)
7b9e50b  mitigate perf loss with d>>1 (cakedev0, Oct 13, 2025)
1407248  Merge branch 'optim-weighted-percentile' of github.com:cakedev0/sciki… (cakedev0, Oct 13, 2025)
4a8b9df  Merge remote-tracking branch 'upstream/main' into optim-weighted-perc… (cakedev0, Oct 13, 2025)
35f6d8d  minor fix for average (cakedev0, Oct 13, 2025)
fd8b2c9  WIP: inner func handles 2D but only 1 quantile (cakedev0, Oct 17, 2025)
d59abf5  restore back prev implem. and loop to compute multiple percentiles wi… (cakedev0, Oct 17, 2025)
357c337  Merge branch 'main' into optim-weighted-percentile (cakedev0, Oct 19, 2025)
c1fbcdb  Merge remote-tracking branch 'upstream/main' into optim-weighted-perc… (cakedev0, Oct 20, 2025)
cd45b94  Merge branch 'optim-weighted-percentile' into exp/use_xpx_quantile (cakedev0, Oct 26, 2025)
e327b37  using xpx.quantile everywhere (cakedev0, Oct 26, 2025)
de9b63c  wip (cakedev0, Oct 26, 2025)
89e6503  wip (cakedev0, Oct 26, 2025)
202430c  fixed backward compat test (cakedev0, Oct 26, 2025)
84472cf  Merge branch 'update_xpx' into exp/use_xpx_quantile (cakedev0, Oct 26, 2025)

sklearn/_loss/loss.py (33 changes: 21 additions & 12 deletions)

@@ -45,8 +45,8 @@
     LogLink,
     MultinomialLogit,
 )
+from sklearn.externals import array_api_extra as xpx
 from sklearn.utils import check_scalar
-from sklearn.utils.stats import _weighted_percentile


 # Note: The shape of raw_prediction for multiclass classifications are
@@ -588,7 +588,13 @@ def fit_intercept_only(self, y_true, sample_weight=None):
         if sample_weight is None:
             return np.median(y_true, axis=0)
         else:
-            return _weighted_percentile(y_true, sample_weight, 50)
+            return xpx.quantile(
+                y_true,
+                0.5,
+                axis=0,
+                weights=sample_weight,
+                method="averaged_inverted_cdf",
+            )


 class PinballLoss(BaseLoss):
@@ -646,12 +652,10 @@ def fit_intercept_only(self, y_true, sample_weight=None):
         This is the weighted median of the target, i.e. over the samples
         axis=0.
         """
-        if sample_weight is None:
-            return np.percentile(y_true, 100 * self.closs.quantile, axis=0)
-        else:
-            return _weighted_percentile(
-                y_true, sample_weight, 100 * self.closs.quantile
-            )
+        method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+        return xpx.quantile(
+            y_true, self.closs.quantile, axis=0, method=method, weights=sample_weight
+        )


 class HuberLoss(BaseLoss):
@@ -718,10 +722,15 @@ def fit_intercept_only(self, y_true, sample_weight=None):
         # not to the residual y_true - raw_prediction. An estimator like
         # HistGradientBoostingRegressor might then call it on the residual, e.g.
         # fit_intercept_only(y_true - raw_prediction).
-        if sample_weight is None:
-            median = np.percentile(y_true, 50, axis=0)
-        else:
-            median = _weighted_percentile(y_true, sample_weight, 50)
+
+        method = "linear" if sample_weight is None else "inverted_cdf"
+        # XXX: it would be better to use method "averaged_inverted_cdf"
+        # for the weighted case
+        # (otherwise passing 1s weights is not equivalent to no weights)
+        # but this would break this test:
+        # ensemble/tests/test_gradient_boosting.py::test_huber_exact_backward_compat
+
+        median = xpx.quantile(y_true, 0.5, axis=0, method=method, weights=sample_weight)
         diff = y_true - median
         term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff))
         return median + np.average(term, weights=sample_weight)
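
The XXX note above is the crux of the quantile-method choices throughout this diff. A minimal sketch with plain NumPy (np.quantile accepts weights= for method="inverted_cdf" since NumPy 2.0) showing why the two methods disagree; this only illustrates the estimator definitions, not the vendored xpx implementation:

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

# Unweighted default: interpolate between the two middle order statistics.
np.quantile(x, 0.5, method="linear")  # 2.5

# "inverted_cdf" takes the first value whose cumulative weight reaches 50%,
# so all-ones weights do NOT reproduce the unweighted result.
np.quantile(x, 0.5, method="inverted_cdf", weights=np.ones(4))  # 2.0

# "averaged_inverted_cdf" averages the two candidates at the 50% boundary,
# restoring the unit-weights == no-weights equivalence for the median.
np.quantile(x, 0.5, method="averaged_inverted_cdf")  # 2.5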

sklearn/calibration.py (2 changes: 1 addition & 1 deletion)

@@ -1294,7 +1294,7 @@ def calibration_curve(

     if strategy == "quantile":  # Determine bin edges by distribution of data
         quantiles = np.linspace(0, 1, n_bins + 1)
-        bins = np.percentile(y_prob, quantiles * 100)
+        bins = xpx.quantile(y_prob, quantiles)
     elif strategy == "uniform":
         bins = np.linspace(0.0, 1.0, n_bins + 1)
     else:

sklearn/dummy.py (27 changes: 12 additions & 15 deletions)

@@ -16,11 +16,11 @@
     RegressorMixin,
     _fit_context,
 )
+from sklearn.externals import array_api_extra as xpx
 from sklearn.utils import check_random_state
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.utils.multiclass import class_distribution
 from sklearn.utils.random import _random_choice_csc
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_sample_weight,
     _num_samples,
@@ -581,27 +581,24 @@ def fit(self, X, y, sample_weight=None):
             if sample_weight is None:
                 self.constant_ = np.median(y, axis=0)
             else:
-                self.constant_ = [
-                    _weighted_percentile(y[:, k], sample_weight, percentile_rank=50.0)
-                    for k in range(self.n_outputs_)
-                ]
+                self.constant_ = xpx.quantile(
+                    y,
+                    0.5,
+                    axis=0,
+                    weights=sample_weight,
+                    method="averaged_inverted_cdf",
+                )

         elif self.strategy == "quantile":
             if self.quantile is None:
                 raise ValueError(
                     "When using `strategy='quantile', you have to specify the desired "
                     "quantile in the range [0, 1]."
                 )
-            percentile_rank = self.quantile * 100.0
-            if sample_weight is None:
-                self.constant_ = np.percentile(y, axis=0, q=percentile_rank)
-            else:
-                self.constant_ = [
-                    _weighted_percentile(
-                        y[:, k], sample_weight, percentile_rank=percentile_rank
-                    )
-                    for k in range(self.n_outputs_)
-                ]
+            method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+            self.constant_ = xpx.quantile(
+                y, float(self.quantile), axis=0, weights=sample_weight, method=method
+            )

         elif self.strategy == "constant":
             if self.constant is None:
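
Compared to the old per-output list comprehension, the new code computes all outputs in one vectorized call. A small sketch of the shape contract, assuming the vendored xpx.quantile follows np.quantile's output layout (which the diff relies on):

import numpy as np

from sklearn.externals import array_api_extra as xpx  # as vendored on this branch

rng = np.random.RandomState(0)
y = rng.rand(20, 3)  # multi-output target, shape (n_samples, n_outputs)
w = rng.rand(20)     # one weight per sample

# One call over axis=0 replaces the old `for k in range(self.n_outputs_)` loop:
constant = xpx.quantile(y, 0.5, axis=0, weights=w, method="averaged_inverted_cdf")
assert constant.shape == (3,)  # one weighted median per output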

sklearn/ensemble/_gb.py (8 changes: 6 additions & 2 deletions)

@@ -47,14 +47,14 @@
     predict_stages,
 )
 from sklearn.exceptions import NotFittedError
+from sklearn.externals import array_api_extra as xpx
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import LabelEncoder
 from sklearn.tree import DecisionTreeRegressor
 from sklearn.tree._tree import DOUBLE, DTYPE, TREE_LEAF
 from sklearn.utils import check_array, check_random_state, column_or_1d
 from sklearn.utils._param_validation import HasMethods, Interval, StrOptions
 from sklearn.utils.multiclass import check_classification_targets
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_sample_weight,
     check_is_fitted,
@@ -275,7 +275,11 @@ def set_huber_delta(loss, y_true, raw_prediction, sample_weight=None):
     """Calculate and set self.closs.delta based on self.quantile."""
     abserr = np.abs(y_true - raw_prediction.squeeze())
     # sample_weight is always a ndarray, never None.
-    delta = _weighted_percentile(abserr, sample_weight, 100 * loss.quantile)
+    delta = xpx.quantile(
+        abserr, loss.quantile, axis=0, weights=sample_weight, method="inverted_cdf"
+    )
+    # XXX: it would probably be better to use method "averaged_inverted_cdf";
+    # see HuberLoss.fit_intercept_only for why we can't.
     loss.closs.delta = float(delta)


sklearn/metrics/_regression.py (26 changes: 12 additions & 14 deletions)

@@ -23,10 +23,10 @@
     get_namespace,
     get_namespace_and_device,
     size,
+    xpx,
 )
 from sklearn.utils._array_api import _xlogy as xlogy
 from sklearn.utils._param_validation import Interval, StrOptions, validate_params
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_sample_weight,
     _num_samples,
@@ -923,8 +923,12 @@ def median_absolute_error(
     if sample_weight is None:
         output_errors = _median(xp.abs(y_pred - y_true), axis=0)
     else:
-        output_errors = _weighted_percentile(
-            xp.abs(y_pred - y_true), sample_weight=sample_weight, average=True
+        output_errors = xpx.quantile(
+            xp.abs(y_pred - y_true),
+            0.5,
+            axis=0,
+            weights=sample_weight,
+            method="averaged_inverted_cdf",
         )
     if isinstance(multioutput, str):
         if multioutput == "raw_values":
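
The weighted branch is observable through the public metric. A short usage sketch; the weighted value assumes the "averaged_inverted_cdf" semantics described above:

import numpy as np

from sklearn.metrics import median_absolute_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])  # absolute errors: [0.5, 0.5, 0.0, 1.0]

# Unweighted: median of [0.5, 0.5, 0.0, 1.0] -> 0.5
print(median_absolute_error(y_true, y_pred))

# Weighting the zero-error sample heavily pulls the weighted median to 0.0.
w = np.array([1.0, 1.0, 5.0, 1.0])
print(median_absolute_error(y_true, y_pred, sample_weight=w))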
@@ -1820,17 +1824,11 @@ def d2_pinball_score(
         multioutput="raw_values",
     )

-    if sample_weight is None:
-        y_quantile = np.tile(
-            np.percentile(y_true, q=alpha * 100, axis=0), (len(y_true), 1)
-        )
-    else:
-        y_quantile = np.tile(
-            _weighted_percentile(
-                y_true, sample_weight=sample_weight, percentile_rank=alpha * 100
-            ),
-            (len(y_true), 1),
-        )
+    method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+    y_quantile = np.tile(
+        xpx.quantile(y_true, alpha, axis=0, weights=sample_weight, method=method),
+        (len(y_true), 1),
+    )

     denominator = mean_pinball_loss(
         y_true,

sklearn/preprocessing/_discretization.py (42 changes: 8 additions & 34 deletions)

@@ -8,10 +8,10 @@
 import numpy as np

 from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
+from sklearn.externals import array_api_extra as xpx
 from sklearn.preprocessing._encoders import OneHotEncoder
 from sklearn.utils import resample
 from sklearn.utils._param_validation import Interval, Options, StrOptions
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     _check_feature_names_in,
     _check_sample_weight,
@@ -350,39 +350,13 @@ def fit(self, X, y=None, sample_weight=None):
                 bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1)

             elif self.strategy == "quantile":
-                percentile_levels = np.linspace(0, 100, n_bins[jj] + 1)
-
-                # method="linear" is the implicit default for any numpy
-                # version. So we keep it version independent in that case by
-                # using an empty param dict.
-                percentile_kwargs = {}
-                if quantile_method != "linear" and sample_weight is None:
-                    percentile_kwargs["method"] = quantile_method
-
-                if sample_weight is None:
-                    bin_edges[jj] = np.asarray(
-                        np.percentile(column, percentile_levels, **percentile_kwargs),
-                        dtype=np.float64,
-                    )
-                else:
-                    # TODO: make _weighted_percentile accept an array of
-                    # quantiles instead of calling it multiple times and
-                    # sorting the column multiple times as a result.
-                    average = (
-                        True if quantile_method == "averaged_inverted_cdf" else False
-                    )
-                    bin_edges[jj] = np.asarray(
-                        [
-                            _weighted_percentile(
-                                column,
-                                sample_weight,
-                                percentile_rank=p,
-                                average=average,
-                            )
-                            for p in percentile_levels
-                        ],
-                        dtype=np.float64,
-                    )
+                quantile_levels = np.linspace(0, 1, n_bins[jj] + 1)
+                bin_edges[jj] = xpx.quantile(
+                    column,
+                    quantile_levels,
+                    weights=sample_weight,
+                    method=quantile_method,
+                )
             elif self.strategy == "kmeans":
                 from sklearn.cluster import KMeans  # fixes import loops

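
This resolves the removed TODO: xpx.quantile as called here accepts the full array of levels, so each column is sorted once instead of once per bin edge. A sketch on synthetic data, assuming the signature used in this diff:

import numpy as np

from sklearn.externals import array_api_extra as xpx  # as vendored on this branch

rng = np.random.RandomState(0)
column = rng.exponential(size=1000)
w = rng.rand(1000)
quantile_levels = np.linspace(0, 1, 5 + 1)  # 5 bins -> 6 edges

# One vectorized call; the old code ran _weighted_percentile once per level,
# re-sorting `column` every time.
edges = xpx.quantile(
    column, quantile_levels, weights=w, method="averaged_inverted_cdf"
)
assert edges.shape == (6,)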

sklearn/preprocessing/_polynomial.py (19 changes: 9 additions & 10 deletions)

@@ -16,6 +16,7 @@
 from scipy.special import comb

 from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
+from sklearn.externals import array_api_extra as xpx
 from sklearn.preprocessing._csr_polynomial_expansion import (
     _calc_expanded_nnz,
     _calc_total_nnz,
@@ -30,7 +31,6 @@
 from sklearn.utils._mask import _get_mask
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.utils.fixes import parse_version, sp_version
-from sklearn.utils.stats import _weighted_percentile
 from sklearn.utils.validation import (
     FLOAT_DTYPES,
     _check_feature_names_in,
@@ -784,18 +784,17 @@ def _get_base_knot_positions(X, n_knots=10, knots="uniform", sample_weight=None)
            Knot positions (points) of base interval.
        """
        if knots == "quantile":
-            percentile_ranks = 100 * np.linspace(
-                start=0, stop=1, num=n_knots, dtype=np.float64
-            )
+            quantile_ranks = np.linspace(start=0, stop=1, num=n_knots, dtype=np.float64)

            if sample_weight is None:
-                knots = np.nanpercentile(X, percentile_ranks, axis=0)
+                knots = np.nanquantile(X, quantile_ranks, axis=0)
            else:
-                knots = np.array(
-                    [
-                        _weighted_percentile(X, sample_weight, percentile_rank)
-                        for percentile_rank in percentile_ranks
-                    ]
+                knots = xpx.quantile(
+                    X,
+                    quantile_ranks,
+                    axis=0,
+                    weights=sample_weight,
+                    method="averaged_inverted_cdf",
+                )

        else:
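
Same vectorization as in _discretization.py, but here over quantile levels and features at once; the weighted branch now matches the (n_knots, n_features) layout that np.nanquantile already returns. A sketch, assuming xpx.quantile broadcasts like np.quantile:

import numpy as np

from sklearn.externals import array_api_extra as xpx  # as vendored on this branch

rng = np.random.RandomState(0)
X = rng.rand(50, 2)  # (n_samples, n_features)
w = np.ones(50)
levels = np.linspace(start=0, stop=1, num=10)

knots = xpx.quantile(X, levels, axis=0, weights=w, method="averaged_inverted_cdf")
assert knots.shape == (10, 2)  # (n_knots, n_features), like np.nanquantile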

sklearn/tests/test_dummy.py (7 changes: 4 additions & 3 deletions)

@@ -7,13 +7,13 @@
 from sklearn.base import clone
 from sklearn.dummy import DummyClassifier, DummyRegressor
 from sklearn.exceptions import NotFittedError
+from sklearn.externals import array_api_extra as xpx
 from sklearn.utils._testing import (
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
 )
 from sklearn.utils.fixes import CSC_CONTAINERS
-from sklearn.utils.stats import _weighted_percentile


 def _check_predict_proba(clf, X, y):
@@ -631,11 +631,12 @@ def test_dummy_regressor_sample_weight(global_random_seed, n_samples=10):
     est = DummyRegressor(strategy="mean").fit(X, y, sample_weight)
     assert est.constant_ == np.average(y, weights=sample_weight)

+    method = "averaged_inverted_cdf"
     est = DummyRegressor(strategy="median").fit(X, y, sample_weight)
-    assert est.constant_ == _weighted_percentile(y, sample_weight, 50.0)
+    assert est.constant_ == xpx.quantile(y, 0.5, weights=sample_weight, method=method)

     est = DummyRegressor(strategy="quantile", quantile=0.95).fit(X, y, sample_weight)
-    assert est.constant_ == _weighted_percentile(y, sample_weight, 95.0)
+    assert est.constant_ == xpx.quantile(y, 0.95, weights=sample_weight, method=method)


 def test_dummy_regressor_on_3D_array():