diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 8b43680e1f5ab..ff286af29d6a2 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -77,6 +77,7 @@ "friedman_mse": _criterion.FriedmanMSE, "absolute_error": _criterion.MAE, "poisson": _criterion.Poisson, + "pinball": _criterion.Pinball, } DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter} @@ -383,7 +384,14 @@ def _fit( self.n_outputs_, self.n_classes_ ) else: - criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) + args = (self.n_outputs_, n_samples) + if self.criterion == "pinball": + args = (*args, self.pinball_alpha) + if self.criterion == "absolute_error": + # FIXME: this is coupled with code at a much lower level + # because of the inheritance behavior of __cinit__ + args = (*args, 0.5) + criterion = CRITERIA_REG[self.criterion](*args) else: # Make a deepcopy in case the criterion has mutable attributes that # might be shared and modified concurrently during parallel fitting @@ -1338,9 +1346,18 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): _parameter_constraints: dict = { **BaseDecisionTree._parameter_constraints, "criterion": [ - StrOptions({"squared_error", "friedman_mse", "absolute_error", "poisson"}), + StrOptions( + { + "squared_error", + "friedman_mse", + "absolute_error", + "poisson", + "pinball", + } + ), Hidden(Criterion), ], + "pinball_alpha": [Interval(RealNotInt, 0.0, 1.0, closed="neither")], } def __init__( @@ -1358,6 +1375,7 @@ def __init__( min_impurity_decrease=0.0, ccp_alpha=0.0, monotonic_cst=None, + pinball_alpha=0.5, ): super().__init__( criterion=criterion, @@ -1373,6 +1391,7 @@ def __init__( ccp_alpha=ccp_alpha, monotonic_cst=monotonic_cst, ) + self.pinball_alpha = pinball_alpha @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y, sample_weight=None, check_input=True): @@ -1971,6 +1990,7 @@ def __init__( max_leaf_nodes=None, ccp_alpha=0.0, monotonic_cst=None, + pinball_alpha=0.5, ): super().__init__( criterion=criterion, @@ -1985,6 +2005,7 @@ def __init__( random_state=random_state, ccp_alpha=ccp_alpha, monotonic_cst=monotonic_cst, + pinball_alpha=pinball_alpha, ) def __sklearn_tags__(self): diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index fa7925597b9b8..6b99a4a8803d2 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1176,7 +1176,7 @@ cdef class MSE(RegressionCriterion): # Helper for MAE criterion: -cdef void precompute_absolute_errors( +cdef void precompute_pinball_losses( const float64_t[:, ::1] ys, const float64_t[:] sample_weight, const intp_t[:] sample_indices, @@ -1185,18 +1185,19 @@ cdef void precompute_absolute_errors( intp_t k, intp_t start, intp_t end, - float64_t[::1] abs_errors, - float64_t[::1] medians + float64_t q, + float64_t[::1] losses, + float64_t[::1] quantiles ) noexcept nogil: """ - Fill `abs_errors` and `medians`. + Fill `losses` and `quantiles`. 
 If start < end:
-        Computes the "prefix" AEs/medians, i.e the AEs for each set of indices
+        Computes the "prefix" losses/quantiles, i.e., the losses for each set of indices
         sample_indices[start:start + i] with i in {1, ..., n}
         where n = end - start
     Else:
-        Computes the "suffix" AEs/medians, i.e the AEs for each set of indices
+        Computes the "suffix" losses/quantiles, i.e., the losses for each set of indices
         sample_indices[i:] with i in {0, ..., n-1}
 
     Parameters
@@ -1215,24 +1216,27 @@ cdef void precompute_absolute_errors(
         Start index in `sample_indices`
     end : intp_t
         End index (exclusive) in `sample_indices`
-    abs_errors : float64_t[::1]
-        array to store (increment) the computed absolute errors.
+    q : float64_t
+        Quantile level, i.e., the alpha parameter of the pinball loss
+    losses : float64_t[::1]
+        array to store (increment) the computed pinball losses.
         Shape: (n,) with n := end - start
-    medians : float64_t[::1]
-        array to store (overwrite) the computed medians. Shape: (n,)
+    quantiles : float64_t[::1]
+        array to store (overwrite) the computed quantiles. Shape: (n,)
 
     Complexity: O(n log n)
 
     This algorithm is an adaptation of the two heaps solution of
     the "find median from a data stream" problem
     See for instance:
     https://www.geeksforgeeks.org/dsa/median-of-stream-of-integers-running-integers/
-    But here, it's the weighted median and we also need to compute the AE, so:
+    But here, it's the weighted quantile and we also need to compute the pinball loss, so:
     - instead of balancing the heaps based on their number of elements,
       rebalance them based on the summed weights of their elements
-    - rewrite the AE computation by splitting the sum between elements
-      above and below the median, which allow to express it as a simple
+    - rewrite the pinball loss computation by splitting the sum between
+      elements above and below the quantile, which allows expressing it as a simple
       O(1) computation.
-    See the maths in the PR desc:
+    See the maths in the PR description: TODO
+    See also the PR for the initial implementation (weighted median and absolute error):
+    https://github.com/scikit-learn/scikit-learn/pull/32100
     """
     cdef intp_t j, p, i, step, n
@@ -1250,8 +1254,9 @@
     cdef float64_t y
     cdef float64_t w = 1.0
     cdef float64_t top_val, top_weight
-    cdef float64_t median = 0.0
-    cdef float64_t half_weight
+    cdef float64_t quantile = 0.0
+    cdef float64_t split_weight
+    cdef float64_t total_weight
 
     p = start
     for _ in range(n):
@@ -1268,40 +1273,47 @@
         else:
             below.push(y, w)
 
-        half_weight = (above.total_weight + below.total_weight) / 2.0
+        total_weight = above.total_weight + below.total_weight
+        split_weight = total_weight - q * total_weight
+        # ^ doing this instead of total_weight * (1 - q) to align
+        # with the rounding errors in the implementation of
+        # sklearn.utils.stats._weighted_percentile
 
         # Rebalance heaps
-        while above.total_weight < half_weight and not below.is_empty():
+        while above.total_weight < split_weight and not below.is_empty():
             below.pop(&top_val, &top_weight)
             above.push(top_val, top_weight)
         while (
             not above.is_empty()
-            and (above.total_weight - above.top_weight()) >= half_weight
+            and (above.total_weight - above.top_weight()) >= split_weight
         ):
             above.pop(&top_val, &top_weight)
             below.push(top_val, top_weight)
 
-        # Current median
-        if above.total_weight == half_weight:
-            median = (above.top() + below.top()) / 2.
+        # Current quantile
+        if above.total_weight == split_weight:
+            # above and below heaps are exactly balanced:
+            # choose the midpoint for determinism, to match
+            # sklearn.utils.stats._weighted_percentile(..., average=True)
+            quantile = 0.5 * (above.top() + below.top())
         else:
-            median = above.top()
-        medians[j] = median
-        abs_errors[j] += (
-            (below.total_weight - above.total_weight) * median
-            - below.weighted_sum
-            + above.weighted_sum
+            quantile = above.top()
+        quantiles[j] = quantile
+        losses[j] += (
+            q * (above.weighted_sum - quantile * above.total_weight)
+            + (1 - q) * (quantile * below.total_weight - below.weighted_sum)
         )
         p += step
         j += step
 
 
-def _py_precompute_absolute_errors(
+def _py_precompute_pinball_losses(
     const float64_t[:, ::1] ys,
     const float64_t[:] sample_weight,
     const intp_t[:] sample_indices,
     const intp_t start,
     const intp_t end,
+    float64_t q=0.5,
 ):
-    """Used for testing precompute_absolute_errors."""
+    """Used for testing precompute_pinball_losses."""
     cdef:
@@ -1309,17 +1321,17 @@ def _py_precompute_absolute_errors(
         WeightedHeap above = WeightedHeap(n, True)
         WeightedHeap below = WeightedHeap(n, False)
         intp_t k = 0
-        float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64)
-        float64_t[::1] medians = np.zeros(n, dtype=np.float64)
+        float64_t[::1] losses = np.zeros(n, dtype=np.float64)
+        float64_t[::1] quantiles = np.zeros(n, dtype=np.float64)
 
-    precompute_absolute_errors(
+    precompute_pinball_losses(
         ys, sample_weight, sample_indices, above, below,
-        k, start, end, abs_errors, medians
+        k, start, end, q, losses, quantiles
     )
-    return np.asarray(abs_errors), np.asarray(medians)
+    return np.asarray(losses), np.asarray(quantiles)
 
 
-cdef class MAE(Criterion):
-    r"""Mean absolute error impurity criterion.
+cdef class Pinball(Criterion):
+    r"""Pinball loss impurity criterion.
 
-    MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true
-    value and f_i is the predicted value.
+    pinball = (1 / n) * \sum_i [alpha * max(y_i - f_i, 0)
+                                + (1 - alpha) * max(f_i - y_i, 0)],
+    where y_i is the true value and f_i is the predicted alpha-quantile.
 
-    It has almost nothing in common with other regression criterions
-    so it doesn't inherit from RegressionCriterion
+    It has almost nothing in common with other regression criteria,
+    so it doesn't inherit from RegressionCriterion.
     """
-    cdef float64_t[::1] node_medians
-    cdef float64_t[::1] left_abs_errors
-    cdef float64_t[::1] right_abs_errors
-    cdef float64_t[::1] left_medians
-    cdef float64_t[::1] right_medians
+    cdef float64_t alpha
+    cdef float64_t[::1] node_quantiles
+    cdef float64_t[::1] left_pinball_losses
+    cdef float64_t[::1] right_pinball_losses
+    cdef float64_t[::1] left_quantiles
+    cdef float64_t[::1] right_quantiles
     cdef WeightedHeap above
     cdef WeightedHeap below
 
-    def __cinit__(self, intp_t n_outputs, intp_t n_samples):
+    def __cinit__(self, intp_t n_outputs, intp_t n_samples, float64_t alpha):
         """Initialize parameters for this criterion.
 Parameters
@@ -1347,6 +1360,8 @@ cdef class MAE(Criterion):
         n_samples : intp_t
             The total number of samples to fit on
+
+        alpha : float64_t
+            The quantile level in (0, 1) targeted by the pinball loss
         """
+        self.alpha = alpha
+
         # Default values
         self.start = 0
         self.pos = 0
@@ -1359,14 +1374,14 @@
         self.weighted_n_left = 0.0
         self.weighted_n_right = 0.0
 
-        self.node_medians = np.zeros(n_outputs, dtype=np.float64)
+        self.node_quantiles = np.zeros(n_outputs, dtype=np.float64)
 
         # Note: this criterion has a n_samples x 64 bytes memory footprint, which is
         # fine as it's instantiated only once to build an entire tree
-        self.left_abs_errors = np.empty(n_samples, dtype=np.float64)
-        self.right_abs_errors = np.empty(n_samples, dtype=np.float64)
-        self.left_medians = np.empty(n_samples, dtype=np.float64)
-        self.right_medians = np.empty(n_samples, dtype=np.float64)
+        self.left_pinball_losses = np.empty(n_samples, dtype=np.float64)
+        self.right_pinball_losses = np.empty(n_samples, dtype=np.float64)
+        self.left_quantiles = np.empty(n_samples, dtype=np.float64)
+        self.right_quantiles = np.empty(n_samples, dtype=np.float64)
 
         self.above = WeightedHeap(n_samples, True)   # min-heap
         self.below = WeightedHeap(n_samples, False)  # max-heap
@@ -1432,37 +1447,37 @@
         self.pos = self.start
 
         n_bytes = self.n_node_samples * sizeof(float64_t)
-        memset(&self.left_abs_errors[0], 0, n_bytes)
-        memset(&self.right_abs_errors[0], 0, n_bytes)
+        memset(&self.left_pinball_losses[0], 0, n_bytes)
+        memset(&self.right_pinball_losses[0], 0, n_bytes)
 
-        # Precompute absolute errors (summed over each output)
-        # and medians (used only when n_outputs=1)
+        # Precompute pinball losses (summed over each output)
+        # and quantiles (used only when n_outputs=1)
         # of the right and left child of all possible splits
        # for the current ordering of `sample_indices`
         # Precomputation is needed here and can't be done step-by-step in the update method
-        # like for other criterions. Indeed, we don't have efficient ways to update right child
-        # statistics when removing samples from it. So we compute right child AEs/medians by
+        # like for other criteria. Indeed, we don't have efficient ways to update right child
+        # statistics when removing samples from it. So we compute right child losses/quantiles by
         # traversing from right to left (and hence only adding samples).
         for k in range(self.n_outputs):
-            # Note that at each iteration of this loop, we overwrite `self.left_medians`
-            # and `self.right_medians`. They are used to check for monoticity constraints,
+            # Note that at each iteration of this loop, we overwrite `self.left_quantiles`
+            # and `self.right_quantiles`. They are used to check for monotonicity constraints,
             # which are allowed only with n_outputs=1.
-            precompute_absolute_errors(
+            precompute_pinball_losses(
                 self.y, self.sample_weight, self.sample_indices,
-                self.above, self.below, k, self.start, self.end,
-                # left_abs_errors is incremented, left_medians is overwritten
-                self.left_abs_errors, self.left_medians
+                self.above, self.below, k, self.start, self.end, self.alpha,
+                # left_pinball_losses is incremented, left_quantiles is overwritten
+                self.left_pinball_losses, self.left_quantiles
             )
             # For the right child, we consider samples from end-1 to start-1
-            # i.e., reversed, and abs error & median are filled in reverse order to.
-            precompute_absolute_errors(
+            # i.e., reversed, and loss & quantile are filled in reverse order too.
+            precompute_pinball_losses(
                 self.y, self.sample_weight, self.sample_indices,
-                self.above, self.below, k, self.end - 1, self.start - 1,
-                # right_abs_errors is incremented, right_medians is overwritten
-                self.right_abs_errors, self.right_medians
+                self.above, self.below, k, self.end - 1, self.start - 1, self.alpha,
+                # right_pinball_losses is incremented, right_quantiles is overwritten
+                self.right_pinball_losses, self.right_quantiles
             )
-            # Store the median for the current node
-            self.node_medians[k] = self.right_medians[0]
+            # Store the quantile for the current node
+            self.node_quantiles[k] = self.right_quantiles[0]
 
         return 0
 
@@ -1499,7 +1514,7 @@
         """Computes the node value of sample_indices[start:end] into dest."""
         cdef intp_t k
         for k in range(self.n_outputs):
-            dest[k] = self.node_medians[k]
+            dest[k] = self.node_quantiles[k]
 
     cdef inline float64_t middle_value(self) noexcept nogil:
         """Compute the middle value of a split for monotonicity constraints as the simple average
@@ -1510,8 +1525,8 @@
         """
         cdef intp_t j = self.pos - self.start
         return (
-            self.left_medians[j - 1]
-            + self.right_medians[j]
+            self.left_quantiles[j - 1]
+            + self.right_quantiles[j]
         ) / 2
 
     cdef inline bint check_monotonicity(
@@ -1525,7 +1540,7 @@
         return self._check_monotonicity(
             monotonic_cst, lower_bound, upper_bound,
-            self.left_medians[j - 1], self.right_medians[j])
+            self.left_quantiles[j - 1], self.right_quantiles[j])
 
     cdef float64_t node_impurity(self) noexcept nogil:
         """Evaluate the impurity of the current node.
@@ -1537,7 +1552,7 @@
         Time complexity: O(1) (precomputed in `.reset()`)
         """
         return (
-            self.right_abs_errors[0]
+            self.right_pinball_losses[0]
             / (self.weighted_n_node_samples * self.n_outputs)
         )
@@ -1556,20 +1571,20 @@
 
         # if pos == start, left child is empty, hence impurity is 0
         if self.pos > self.start:
-            impurity_left += self.left_abs_errors[j - 1]
+            impurity_left += self.left_pinball_losses[j - 1]
         p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs)
 
         # if pos == end, right child is empty, hence impurity is 0
         if self.pos < self.end:
-            impurity_right += self.right_abs_errors[j]
+            impurity_right += self.right_pinball_losses[j]
         p_impurity_right[0] = impurity_right / (self.weighted_n_right * self.n_outputs)
 
-    # those 2 methods are copied from the RegressionCriterion abstract class:
     def __reduce__(self):
-        return (type(self), (self.n_outputs, self.n_samples), self.__getstate__())
+        return (type(self), (self.n_outputs, self.n_samples, self.alpha), self.__getstate__())
 
+    # this method is copied from the RegressionCriterion abstract class:
     cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil:
         """Clip the value in dest between lower_bound and upper_bound for monotonic constraints."""
         if dest[0] < lower_bound:
@@ -1578,6 +1593,24 @@
             dest[0] = upper_bound
 
 
+cdef class MAE(Pinball):
+    """
+    The median is just the alpha=0.5 quantile, and the absolute error
+    is twice the pinball loss with alpha=0.5.
+    """
+
+    # FIXME/XXX: Trust the instantiator to pass alpha=0.5 to the __cinit__...
+    cdef float64_t node_impurity(self) noexcept nogil:
+        return 2 * Pinball.node_impurity(self)
+
+    cdef void children_impurity(self, float64_t* p_impurity_left,
+                                float64_t* p_impurity_right) noexcept nogil:
+        Pinball.children_impurity(self, p_impurity_left, p_impurity_right)
+        p_impurity_left[0] *= 2
+        p_impurity_right[0] *= 2
+
+
 cdef class FriedmanMSE(MSE):
     """Mean squared error impurity criterion with improvement score by Friedman.
 
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index fff2a47769b2c..bc07544f1d87d 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -20,7 +20,12 @@
 from sklearn.dummy import DummyRegressor
 from sklearn.exceptions import NotFittedError
 from sklearn.impute import SimpleImputer
-from sklearn.metrics import accuracy_score, mean_poisson_deviance, mean_squared_error
+from sklearn.metrics import (
+    accuracy_score,
+    mean_pinball_loss,
+    mean_poisson_deviance,
+    mean_squared_error,
+)
 from sklearn.model_selection import cross_val_score, train_test_split
 from sklearn.pipeline import make_pipeline
 from sklearn.random_projection import _sparse_random_matrix
@@ -36,7 +41,7 @@
     DENSE_SPLITTERS,
     SPARSE_SPLITTERS,
 )
-from sklearn.tree._criterion import _py_precompute_absolute_errors
+from sklearn.tree._criterion import _py_precompute_pinball_losses
 from sklearn.tree._partitioner import _py_sort
 from sklearn.tree._tree import (
     NODE_DTYPE,
@@ -1855,13 +1860,15 @@ def _pickle_copy(obj):
     assert n_outputs == n_outputs_
     assert_array_equal(n_classes, n_classes_)
 
-    for _, typename in CRITERIA_REG.items():
-        criteria = typename(n_outputs, n_samples)
+    for name, typename in CRITERIA_REG.items():
+        args = (n_outputs, n_samples)
+        if name == "absolute_error" or name == "pinball":
+            args = (*args, 0.5)
+        criteria = typename(*args)
         result = copy_func(criteria).__reduce__()
-        typename_, (n_outputs_, n_samples_), _ = result
+        typename_, args_, _ = result
         assert typename == typename_
-        assert n_outputs == n_outputs_
-        assert n_samples == n_samples_
+        assert args == args_
 
 
 @pytest.mark.parametrize("sparse_container", [None] + CSC_CONTAINERS)
@@ -2918,7 +2925,8 @@ def test_sort_log2_build():
     assert_array_equal(samples, expected_samples)
 
 
-def test_absolute_errors_precomputation_function(global_random_seed):
+@pytest.mark.parametrize("q", [0.5, 0.2, 0.9, 0.4, 0.75])
+def test_pinball_loss_precomputation_function(q, global_random_seed):
     """
-    Test the main bit of logic of the MAE(RegressionCriterion) class
-    (used by DecisionTreeRegressor(criterion="absolute_error")).
+    Test the main bit of logic of the Pinball(Criterion) class (used by
+    DecisionTreeRegressor with criterion="pinball" or "absolute_error").
@@ -2928,33 +2936,39 @@
     part of the computation, in case of major refactor of the MAE class,
     it can be safely removed.
""" + global_random_seed = np.random.choice(10**7) - def compute_prefix_abs_errors_naive(y, w): + def compute_prefix_losses_naive(y, w): + """ + Computes the pinball loss for all (y[:i], w[:i]) + Naive: O(n^2 log n) + """ y = y.ravel().copy() - medians = [ - _weighted_percentile(y[:i], w[:i], 50, average=True) + quantiles = [ + _weighted_percentile(y[:i], w[:i], q * 100, average=True) for i in range(1, y.size + 1) ] - errors = [ - (np.abs(y[:i] - m) * w[:i]).sum() - for i, m in zip(range(1, y.size + 1), medians) + losses = [ + mean_pinball_loss(y[:i], np.full(i, quantile), sample_weight=w[:i], alpha=q) + * w[:i].sum() + for i, quantile in zip(range(1, y.size + 1), quantiles) ] - return np.array(errors), np.array(medians) + return np.array(losses), np.array(quantiles) def assert_same_results(y, w, indices, reverse=False): args = (n - 1, -1) if reverse else (0, n) - abs_errors, medians = _py_precompute_absolute_errors(y, w, indices, *args) + losses, quantiles = _py_precompute_pinball_losses(y, w, indices, *args, q=q) y_sorted = y[indices] w_sorted = w[indices] if reverse: y_sorted = y_sorted[::-1] w_sorted = w_sorted[::-1] - abs_errors_, medians_ = compute_prefix_abs_errors_naive(y_sorted, w_sorted) + losses_, quantiles_ = compute_prefix_losses_naive(y_sorted, w_sorted) if reverse: - abs_errors_ = abs_errors_[::-1] - medians_ = medians_[::-1] - assert_allclose(abs_errors, abs_errors_, atol=1e-12) - assert_allclose(medians, medians_, atol=1e-12) + losses_ = losses_[::-1] + quantiles_ = quantiles_[::-1] + assert_allclose(losses, losses_, atol=1e-12) + assert_allclose(quantiles, quantiles_, atol=1e-12) rng = np.random.default_rng(global_random_seed)