
Conversation

@cakedev0
Contributor

@cakedev0 cakedev0 commented Sep 3, 2025

This PR re-implements how DecisionTreeRegressor(criterion='absolute_error') works under the hood, for optimization purposes. The current algorithm for computing the AE of a split incurs an O(n^2) overall complexity for building a tree, which quickly becomes impractical. My implementation makes it O(n log n), which is tremendously faster.

For instance with d=2, n=100_000 and max_depth=1 (just one split), the execution time went from ~30s to ~100ms on my machine.
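
A minimal script to reproduce this kind of measurement (exact timings will of course vary with hardware and scikit-learn version):

```python
import time

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.standard_normal((100_000, 2))
y = rng.standard_normal(100_000)

tree = DecisionTreeRegressor(criterion="absolute_error", max_depth=1)
tic = time.perf_counter()
tree.fit(X, y)
print(f"single-split fit: {time.perf_counter() - tic:.3f}s")
```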

Referenced Issues

Fixes #9626 by reducing the complexity from O(n^2) to O(n log n).
Also fixes #32099 & #10725 (which are probably duplicates of each other), though that's more of a side effect of completely re-implementing the criterion logic for MAE.

Supersedes #11649 (which was opened to fix #10725 7 years ago but never merged).

Explanation of my changes

The changes focus solely on the class MAE(RegressionCriterion).

The previous implementation had O(n^2) overall complexity, emerging from two methods in this class:

  • in update: O(n) cost due to updating a data structure that keeps the data sorted (WeightedMedianCalculator/WeightedPQueue); called O(n) times to find the best split => O(n^2) overall
  • in children_impurity: O(n) cost due to looping over all the data points; called O(n) times to find the best split => O(n^2) overall

These can't really be fixed by small local changes: no matter how each method is implemented, the overall algorithm stays O(n^2), so a complete rewrite was needed. As discussed in this technical report I wrote, there are several efficient algorithms to solve the problem (computing the absolute errors for all the possible splits along one feature).

The one I chose initially was an intuitive adaptation of the well-known two-heap solution to the "find median from a data stream" problem. But even though it has O(n log n) expected complexity, it can degrade to O(n^2 log n) in some pathological cases. So after some discussion, we chose to implement another solution: the "Fenwick tree option". This solution is based on a Fenwick tree, a data structure specialized in efficient prefix-sum computations and updates.
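
For readers unfamiliar with Fenwick trees, here is a minimal pure-Python sketch of a weighted variant supporting the two operations the algorithm needs. The real WeightedFenwickTree in this PR is written in Cython; the names and details below are only illustrative:

```python
import numpy as np

class WeightedFenwickTree:
    """Fenwick tree indexed by y-rank, tracking sum(w) and sum(w * y)."""

    def __init__(self, n):
        self.n = n
        self.w = np.zeros(n + 1)   # node i covers a power-of-two range of ranks
        self.wy = np.zeros(n + 1)

    def insert(self, rank, y, w):
        # O(log n): add (w, w * y) at 0-based position `rank`.
        i = rank + 1
        while i <= self.n:
            self.w[i] += w
            self.wy[i] += w * y
            i += i & (-i)

    def median_search(self, half_weight):
        # O(log n): binary-lift to the smallest rank whose cumulative weight
        # reaches `half_weight` (a lower weighted median), accumulating the
        # prefix sums of w and w * y strictly below that rank on the way down.
        idx, w_below, wy_below = 0, 0.0, 0.0
        step = 1 << self.n.bit_length()
        while step:
            nxt = idx + step
            if nxt <= self.n and w_below + self.w[nxt] < half_weight:
                idx = nxt
                w_below += self.w[nxt]
                wy_below += self.wy[nxt]
            step >>= 1
        return idx, w_below, wy_below  # idx = 0-based rank of the median
```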

See the technical report for a detailed explanation of the algorithm; in short, the main steps are:

  • insert a new element (y, w) into the tree, and search by prefix sum to find the weighted median: O(log n);
  • rewrite the AE computation by taking advantage of the following decomposition:
    $\sum_i w_i |y_i - m| = \sum_{y_i \ge m} w_i (y_i - m) + \sum_{y_i < m} w_i (m - y_i)$
    $= \sum_{y_i \ge m} w_i y_i - m \sum_{y_i \ge m} w_i + m \sum_{y_i < m} w_i - \sum_{y_i < m} w_i y_i$
    The values of these 4 prefix/suffix sums can be found while searching for the median in the tree, and once you have them, the computation becomes O(1);

  • iterate over the data from left to right to compute the AE for every possible left child, then from right to left to compute the AE for every possible right child (see the sketch below).
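
Putting the steps together, here is a sketch of the left-to-right pass using the class above (the right-to-left pass is symmetric). Again, this is illustrative Python; the PR's actual implementation is in Cython and differs in its details:

```python
def prefix_abs_errors(y, w):
    # AE of the left child for every possible split position, in O(n log n),
    # reusing the WeightedFenwickTree sketch above.
    n = len(y)
    order = np.argsort(y)
    ranks = np.empty(n, dtype=np.intp)
    ranks[order] = np.arange(n)  # rank of each y_i among all n values
    sorted_y = y[order]

    tree = WeightedFenwickTree(n)
    out = np.empty(n)
    w_total = wy_total = 0.0
    for i in range(n):
        tree.insert(ranks[i], y[i], w[i])
        w_total += w[i]
        wy_total += w[i] * y[i]
        rank, w_below, wy_below = tree.median_search(w_total / 2.0)
        m = sorted_y[rank]
        # The four prefix/suffix sums from the decomposition above:
        out[i] = ((wy_total - wy_below) - m * (w_total - w_below)
                  + m * w_below - wy_below)
    return out
```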

This logic is implemented in tree/_criterion.pyx::precompute_absolute_errors, as I wanted to be able to unit test it.
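
One way to unit test it is against a brute-force O(n^2) reference. A hypothetical sketch, reusing prefix_abs_errors from above (not the PR's actual test code):

```python
import numpy as np

def naive_prefix_abs_errors(y, w):
    # O(n^2) reference: recompute the weighted median and AE for each prefix.
    out = np.empty(len(y))
    for i in range(len(y)):
        ys, ws = y[:i + 1], w[:i + 1]
        order = np.argsort(ys)
        cum = np.cumsum(ws[order])
        m = ys[order][np.searchsorted(cum, cum[-1] / 2.0)]  # lower weighted median
        out[i] = np.sum(ws * np.abs(ys - m))
    return out

rng = np.random.default_rng(0)
y, w = rng.normal(size=200), rng.uniform(0.5, 2.0, size=200)
np.testing.assert_allclose(prefix_abs_errors(y, w), naive_prefix_abs_errors(y, w))
```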

After some research I found a paper about the same problem. Their approach uses the two-heap idea and generalizes to arbitrary quantiles (as done in my follow-up PR), but it does not handle weighted samples. The paper also uses a more elaborate formula for the absolute error/loss computation than mine; TBH, it looks unnecessarily complex.

@github-actions

github-actions bot commented Sep 3, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: d12585d. Link to the linter CI: here

# MAE split precomputations algorithm
# =============================================================================

def _any_isnan_axis0(const float32_t[:, :] X):
Contributor Author

I moved this one up, into the helpers section.

@cakedev0 cakedev0 marked this pull request as ready for review September 4, 2025 16:52
@adrinjalali
Member

@adam2392 could you please have a look here?

@adam2392 adam2392 self-requested a review September 9, 2025 00:02
Member

@adam2392 adam2392 left a comment


First of all, thanks @cakedev0 for taking a look at this challenging but impactful issue and proposing a fix.

I took an initial glance. Overall this looks like the right direction to me, so I want to make sure others take a look before we dive into the nitty-gritty of making the PR mergeable and maintainable.

I have an open question: for decision trees, we can imagine imposing a quantile-criterion split (e.g. the pinball loss). Naively, I think we can make the WeightedHeaps work to maintain any quantile, right?

Perhaps @thomasjpfan wants to take a look as well before we dive deeper into the code.

@cakedev0
Contributor Author

It might be interesting to double-check that everything behaves as expected using a memory profiler such as scalene or memray but from afar, it looks ok.

I've never used this kind of tool. I might give it a try if I'm in the mood to learn new things ^^

Member

@ogrisel ogrisel left a comment


I started to review the code but won't have time to complete my review today, so here are a few initial comments.

Something that seems to be missing from the tests is a check that the code works as expected for multi-output y with non-default criteria.

@cakedev0
Contributor Author

Something that seems to be missing from the tests is a check that the code works as expected for multi-output y with non-default criteria.

True!

It should be easy enough to add such tests with some @pytest.mark.parametrize (see the sketch below). I can open another PR to do that, as I believe it's out of scope for this PR (though having such tests would increase confidence in this one).
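
For illustration, such a test could look roughly like this (the criteria list and shapes here are mine, not the actual test):

```python
import numpy as np
import pytest

from sklearn.tree import DecisionTreeRegressor

@pytest.mark.parametrize(
    "criterion", ["squared_error", "friedman_mse", "absolute_error"]
)
def test_multioutput_regression_criteria(criterion):
    rng = np.random.RandomState(0)
    X = rng.rand(50, 3)
    y = rng.rand(50, 2)  # multi-output target
    tree = DecisionTreeRegressor(criterion=criterion, random_state=0).fit(X, y)
    assert tree.predict(X).shape == (50, 2)
```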

cakedev0 and others added 2 commits October 21, 2025 16:27
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
# WeightedFenwickTree data structure
# =============================================================================

cdef class WeightedFenwickTree:
Member

Catching up through comments.

Am I understanding correctly that this class's methods are slower when using np.empty + memoryviews? Is it the class itself, or something to do with training a lot of trees?

https://github.com/scikit-learn/scikit-learn/pull/32100/files/7523930d97e42306bda44682a7a3efb8c91d71c6#r2445862261

Contributor Author

What I measured to be slower was just one long run of _py_precompute_absolute_errors, so it's not related to training a lot of trees. It was ~20% slower.

I think it's because Cython memoryviews are backed by C structs: when you write mem_v[i] in Cython, it compiles to something like mem_v->data[i] in C.

But I think this was known by the people who wrote sklearn/tree initially. Typically in sort, which usually dominates the execution time, raw pointers are used rather than memoryviews.

@ogrisel
Member

ogrisel commented Oct 23, 2025

It should be easy enough to add such tests with some @pytest.mark.parametrize. I can open another PR to do that, as I believe it's out of scope for this PR (though having such tests would increase confidence in this one).

+1 for a concurrent PR then.

Member

@ogrisel ogrisel left a comment


I am slowly wrapping my head around the code of this PR, but unfortunately won't have time to finalize my review today. Still, here is my current feedback, inline below.

Also, a note for later: instead of computing the loss for the median or a specific quantile q, we could leverage this code to efficiently compute the aggregate loss over a uniform grid of quantile values; a lot of the computation (e.g. compute_ranks, progressively adding points to the Fenwick tree) would be shared.

Integrating the pinball loss over q in [0, 1] is a way to estimate the CRPS, which is a strictly proper scoring rule for probabilistic estimates, so in effect we would get a distributional tree estimator quite cheaply.
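
A rough numerical illustration of that idea, using the identity CRPS(F, y) = 2 ∫₀¹ pinball_q(F⁻¹(q), y) dq (helper names below are mine, not part of the PR):

```python
import numpy as np

def pinball_loss(y_true, pred, q):
    # Mean pinball (quantile) loss of a constant prediction `pred` at level q.
    diff = y_true - pred
    return np.mean(np.where(diff >= 0, q * diff, (q - 1.0) * diff))

def approx_crps(y_true, y_samples, qs):
    # Average the pinball loss over a uniform grid of quantile levels;
    # twice that average approximates the CRPS of the empirical distribution.
    preds = np.quantile(y_samples, qs)
    return 2.0 * np.mean([pinball_loss(y_true, p, q) for q, p in zip(qs, preds)])

rng = np.random.default_rng(0)
leaf_y = rng.normal(size=1000)     # targets falling into one leaf
test_y = rng.normal(size=200)      # held-out targets
qs = np.linspace(0.05, 0.95, 19)   # uniform quantile grid
print(approx_crps(test_y, leaf_y, qs))
```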

Member

@ogrisel ogrisel left a comment


LGTM (besides the nitpicks below)! Thanks for the great PR.

I'll let @adam2392 do the merge if he is still +1 after the latest changes.

Possible follow-ups:

  • generalize to regression for an arbitrary quantile;
  • add support for missing values (if not overly complex).

@cakedev0
Contributor Author

add support for missing values (if not overly complex).

Actually, this is very simple (it even simplifies the current code base), and has nothing to do with criteria. Criteria don't interact with feature values, only with the target values and their ordering via sample_indices, so they shouldn't have anything to do with missing values in features. See my PR: #32119

@ogrisel
Member

ogrisel commented Oct 28, 2025

I missed that PR despite the notification...

@adam2392 adam2392 self-requested a review October 29, 2025 18:10
