Fix: improve speed of trees with MAE criterion from O(n^2) to O(n log n) #32100
base: main
Conversation
…dded print everywhere to debug; fixed some bugs
…al PR but not all
```cython
# MAE split precomputations algorithm
# =============================================================================

def _any_isnan_axis0(const float32_t[:, :] X):
```
I moved this one up, in the helpers section.
@adam2392 could you please have a look here?
adam2392 left a comment:
First of all, thanks @cakedev0 for taking a look at this challenging but impactful issue and proposing a fix.
I took an initial glance. Overall this looks like the right direction to me, so I want to make sure others take a look before we dive into the nitty-gritty of making the PR mergeable and maintainable.
I have an open question: for decision trees, we can imagine imposing a quantile-criterion split (e.g. the pinball loss). Naively, I think we can make the WeightedHeaps work to maintain any quantile, right?
Perhaps @thomasjpfan wants to take a look as well before we dive deeper into the code.
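For reference, the pinball loss mentioned above can be sketched in a few lines of NumPy (illustrative only, not code from this PR; at q=0.5 it reduces to half the absolute error):

```python
import numpy as np

def pinball_loss(y_true, y_pred, q):
    # Pinball (quantile) loss at level q; q=0.5 gives 0.5 * mean absolute error.
    diff = y_true - y_pred
    return np.mean(np.maximum(q * diff, (q - 1) * diff))
```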
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
I've never used this kind of tool. I might give it a try if I'm in the mood for learning new things ^^
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
ogrisel left a comment:
I started to review the code but won't have time to complete my review today, so here are a first few comments.
Something that seems to be missing from the tests is checks that the code works as expected for multi-output y with non-default criteria.
True! It should be easy enough to add such tests with some …
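For illustration, such a test could look roughly like this (a hypothetical sketch, not the test actually added to the PR; it only uses the public estimator API):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def test_absolute_error_multioutput_smoke():
    # Fit a tree on a 2-output target with the MAE criterion.
    rng = np.random.RandomState(0)
    X = rng.rand(50, 3)
    Y = rng.rand(50, 2)  # multi-output y
    tree = DecisionTreeRegressor(criterion="absolute_error", random_state=0)
    tree.fit(X, Y)
    assert tree.predict(X).shape == (50, 2)
```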
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
```cython
# WeightedFenwickTree data structure
# =============================================================================

cdef class WeightedFenwickTree:
```
Catching up on the comments.
Am I understanding correctly that this class is slower when using np.empty + memoryviews? Is it the class itself, or something to do with training a lot of trees?
What I measured to be slower was just one long run of _py_precompute_absolute_errors, so it's not related to training a lot of trees. It was ~20% slower.
I think it's because Cython memoryviews are backed by C structs: when you write mem_v[i] in Cython, it compiles to something like mem_v->data[i] in C.
But I think this was known by the people who wrote sklearn/tree initially. Typically in sort, which is what usually dominates the execution time, raw pointers are used rather than memoryviews.
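To illustrate the two access patterns being compared (a hedged sketch, not code from this PR; even with bounds checking disabled, the generated C can differ because a memoryview carries a struct with data and stride fields):

```cython
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef double sum_memoryview(const double[:] values) noexcept nogil:
    # Every values[i] goes through the memoryview struct (data + strides).
    cdef Py_ssize_t i
    cdef double total = 0.0
    for i in range(values.shape[0]):
        total += values[i]
    return total

@cython.boundscheck(False)
@cython.wraparound(False)
cdef double sum_pointer(const double* values, Py_ssize_t n) noexcept nogil:
    # Raw pointer indexing, as the existing sort() code in sklearn/tree does.
    cdef Py_ssize_t i
    cdef double total = 0.0
    for i in range(n):
        total += values[i]
    return total
```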
+1 for a concurrent PR then.
ogrisel left a comment:
I am slowly wrapping my head around the code of this PR, but unfortunately won't have time to finalize my review today. Still, here is my current feedback, inline below.
Also, a note for later: instead of computing the loss for the median or a specific quantile q, we could leverage this code to efficiently compute the aggregate loss for a uniform grid of quantile values: a lot of the computation (e.g. compute_ranks, progressively adding points to the Fenwick tree) would be shared.
Integrating the pinball loss over q in [0, 1] is a way to estimate the CRPS, which is a strictly proper scoring rule for a probabilistic estimate, so in effect we would get a distributional tree estimator quite cheaply.
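For reference, the identity alluded to above is a standard result (notation mine, not the PR's):

$$\mathrm{CRPS}(F, y) = \int_{-\infty}^{\infty} \bigl(F(x) - \mathbf{1}\{y \le x\}\bigr)^2 \, dx = 2 \int_0^1 \rho_q\bigl(y - F^{-1}(q)\bigr) \, dq$$

where $\rho_q(u) = u\,(q - \mathbf{1}\{u < 0\})$ is the pinball loss at level $q$. So averaging the pinball loss over a uniform grid of $q$ values approximates the CRPS up to a constant factor.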
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
…to mae-split-optim
LGTM (besides the nitpicks below)! Thanks for the great PR.
I'll let @adam2392 do the merge if he is still +1 after the latest changes.
Possible follow-ups:
- generalize to regression for an arbitrary quantile;
- add support for missing values (if not overly complex).
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Actually, this is very simple (it even simplifies the current code base), and has nothing to do with criteria. Criteria don't interact with feature values, just with the target values and their ordering via …

I missed that PR despite the notification...
This PR re-implements how `DecisionTreeRegressor(criterion='absolute_error')` works under the hood, for performance. The current algorithm for computing the AE of a split incurs an O(n^2) overall complexity for building a tree, which quickly becomes impractical. My implementation makes it O(n log n), which is tremendously faster. For instance, with d=2, n=100_000 and max_depth=1 (just one split), the execution time went from ~30s to ~100ms on my machine.
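A minimal way to reproduce this kind of measurement (a hedged sketch; exact timings are machine-dependent):

```python
import time

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.rand(100_000, 2)  # n=100_000, d=2
y = rng.rand(100_000)

tree = DecisionTreeRegressor(criterion="absolute_error", max_depth=1)
tic = time.perf_counter()
tree.fit(X, y)  # a single split
print(f"fit time: {time.perf_counter() - tic:.3f}s")
```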
Referenced Issues
Fixes #9626 by reducing the complexity from O(n^2) to O(n log n).
Also fixes #32099 & #10725 (which are probably duplicates), but that's more of a side effect of completely re-implementing the criterion logic for MAE.
Supersedes #11649 (which was opened to fix #10725 7 years ago but never merged).
Explanation of my changes
The changes focus solely on the class `MAE(RegressionCriterion)`. The previous implementation had an O(n^2) overall complexity emerging from several methods in this class:

- `update`: O(n) cost due to updating a data structure that keeps the data sorted (`WeightedMedianCalculator`/`WeightedPQueue`). Called O(n) times to find the best split => O(n^2) overall.
- `children_impurity`: O(n) cost due to looping over all the data points. Called O(n) times to find the best split => O(n^2) overall.

Those can't really be fixed by small local changes, as the overall algorithm is O(n^2) independently of how you implement it. Hence a complete rewrite was needed. As discussed in this technical report I wrote, there are several efficient algorithms to solve the problem (computing the absolute errors for all the possible splits along one feature).
The one I chose initially was an intuitive adaptation of the well-known two-heap solution to the "find the median from a data stream" problem. But even though it has an O(n log n) expected complexity, it can degrade to O(n^2 log n) in some pathological cases. So after some discussion, it was decided to implement another solution: the "Fenwick tree option". This solution is based on a Fenwick tree, a data structure specialized in efficient prefix-sum computations and updates; a minimal sketch follows.
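Here is a minimal Python sketch of the core Fenwick-tree operations, assuming target values have been mapped to ranks 1..n (illustrative only; the PR's Cython `WeightedFenwickTree` is the actual implementation and is more elaborate):

```python
class WeightedFenwickTree:
    """Minimal sketch: prefix sums of weights over ranks 1..n in O(log n)."""

    def __init__(self, n):
        self.n = n
        self.tree = [0.0] * (n + 1)  # 1-based indexing

    def add(self, rank, weight):
        # Add `weight` at position `rank` (the 1-based rank of a target value).
        while rank <= self.n:
            self.tree[rank] += weight
            rank += rank & (-rank)

    def prefix_sum(self, rank):
        # Total weight of positions 1..rank.
        total = 0.0
        while rank > 0:
            total += self.tree[rank]
            rank -= rank & (-rank)
        return total
```

A Fenwick tree also supports finding the smallest rank whose prefix sum exceeds half the total weight (i.e. the weighted median) in O(log n) via a top-down binary descent; presumably the PR's version additionally accumulates weighted target values alongside weights so that the sums needed below come out of the same traversal.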
See the technical report for a detailed explanation of the algorithm, but in short, the main steps are:

- The AE of a child decomposes into 4 prefix/suffix sums around its weighted median; their values can be found while searching for the median in the tree, and once you have those, the computation becomes O(1) (see the identity below).
- Iterate over the data from left to right to compute the AE for every possible left child, and iterate from right to left to compute the AE for every possible right child.
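For concreteness, the O(1) step can be written with the following identity (the notation for the four sums is mine, not necessarily the PR's). With $m$ the weighted median of a child node, and $W_{\le}, S_{\le}$ (resp. $W_{>}, S_{>}$) the total weight and weighted sum of the target values $y_i \le m$ (resp. $y_i > m$):

$$\mathrm{AE} = \sum_i w_i \,\lvert y_i - m \rvert = \bigl(m\,W_{\le} - S_{\le}\bigr) + \bigl(S_{>} - m\,W_{>}\bigr)$$

All four quantities are prefix/suffix sums over ranks, which is exactly what the Fenwick tree provides in O(log n).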
This logic is implemented in `tree/_criterion.pyx::precompute_absolute_errors`, as I wanted to be able to unit test it. After some research, I found a paper about the same problem. Their approach uses the two-heaps idea and generalizes to arbitrary quantiles (as done in my follow-up PR), but it does not handle weighted samples. Also, the paper uses a more elaborate formula for the absolute error/loss computation than mine; TBH it looks unnecessarily complex.