
ENHANCEMENT: Add a prediction function that also calculates uncertainty in the prediction #584

Conversation

@degenfabian (Contributor) commented Oct 31, 2024

This PR intends to fix issue #235. Since release v0.3.0, the individual bagged models are stored inside the EBM, which makes it possible to create a prediction function that also calculates uncertainty by computing the mean and standard deviation across bags (as suggested in issue #235 by @interpret-ml).

Since that issue was created, some implementation details have changed, and I was no longer able to access the transform function from EBMPreprocessor. I therefore edited the construct_bins function in _preprocessor.py to return the main_preprocessor and then save it as a class attribute on EBMModel.

I am not sure whether it should stay this way, because an earlier release removed this attribute from EBMModel, but I do not know how else to access the functionality of the transform method (short of duplicating it). I would appreciate guidance on this.

Another implementation detail that changed is the structure of bins in the transform function, which caused the discretize function to throw an error. I fixed this by always accessing the last element in the bins list, but I would also welcome guidance on a cleaner or more robust solution here.

Last but not least, this currently only works when max_bins == max_interaction_bins, so at the moment I simply inform the user at call time that these attributes must be equal.

I also tested the uncertainty estimates on the breast cancer dataset from sklearn and attached the resulting plot to this PR. I think it looks good :)
[Figure: uncertainty estimates on the breast cancer dataset]
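The aggregation described above — mean and standard deviation across the per-bag predictions — can be sketched in plain NumPy. The array names and shapes here are illustrative stand-ins, not interpret's actual internals:

```python
import numpy as np

# Illustrative stand-in for per-bag predictions: 8 bagged models, 5 samples.
# In an EBM the per-bag information lives in attributes such as
# bagged_scores_ / bagged_intercept_; this sketch only shows the aggregation.
rng = np.random.default_rng(0)
per_bag_predictions = rng.normal(loc=0.2, scale=0.05, size=(8, 5))

# Point estimate per sample: average over the bag axis.
mean_prediction = per_bag_predictions.mean(axis=0)
# Uncertainty per sample: spread of the bagged models' predictions.
uncertainty = per_bag_predictions.std(axis=0)

print(mean_prediction.shape, uncertainty.shape)  # → (5,) (5,)
```

The standard deviation is one common choice here; a quantile interval over the bag axis would be a drop-in alternative if asymmetric uncertainty is wanted.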

codecov bot commented Oct 31, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 75.73%. Comparing base (e1182f1) to head (b4b07a5).
Report is 1 commit behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop     #584      +/-   ##
===========================================
+ Coverage    75.71%   75.73%   +0.02%     
===========================================
  Files           72       72              
  Lines         9101     9109       +8     
===========================================
+ Hits          6891     6899       +8     
  Misses        2210     2210              
| Flag | Coverage Δ |
|------|------------|
| bdist_linux_310_python | 75.38% <100.00%> (+0.13%) ⬆️ |
| bdist_linux_311_python | 75.40% <100.00%> (+0.02%) ⬆️ |
| bdist_linux_312_python | 75.38% <100.00%> (+0.07%) ⬆️ |
| bdist_linux_39_python | 75.25% <100.00%> (-0.09%) ⬇️ |
| bdist_mac_310_python | 75.50% <100.00%> (+0.02%) ⬆️ |
| bdist_mac_311_python | 75.50% <100.00%> (+0.05%) ⬆️ |
| bdist_mac_312_python | 75.50% <100.00%> (-0.04%) ⬇️ |
| bdist_mac_39_python | 75.56% <100.00%> (+0.23%) ⬆️ |
| bdist_win_310_python | 75.40% <100.00%> (-0.09%) ⬇️ |
| bdist_win_311_python | 75.59% <100.00%> (+0.07%) ⬆️ |
| bdist_win_312_python | 75.61% <100.00%> (+0.02%) ⬆️ |
| bdist_win_39_python | 75.58% <100.00%> (+0.04%) ⬆️ |
| sdist_linux_310_python | 75.32% <100.00%> (+0.09%) ⬆️ |
| sdist_linux_311_python | 75.34% <100.00%> (+0.04%) ⬆️ |
| sdist_linux_312_python | 75.32% <100.00%> (+0.13%) ⬆️ |
| sdist_linux_39_python | 75.21% <100.00%> (-0.06%) ⬇️ |
| sdist_mac_310_python | 75.48% <100.00%> (+0.07%) ⬆️ |
| sdist_mac_311_python | 75.39% <100.00%> (+0.04%) ⬆️ |
| sdist_mac_312_python | 75.27% <100.00%> (-0.12%) ⬇️ |
| sdist_mac_39_python | 75.25% <100.00%> (-0.12%) ⬇️ |
| sdist_win_310_python | 75.59% <100.00%> (+0.09%) ⬆️ |
| sdist_win_311_python | 75.59% <100.00%> (+0.02%) ⬆️ |
| sdist_win_312_python | 75.59% <100.00%> (-0.01%) ⬇️ |
| sdist_win_39_python | 75.48% <100.00%> (-0.06%) ⬇️ |

Flags with carried forward coverage won't be shown.


@degenfabian degenfabian marked this pull request as ready for review November 2, 2024 11:50
@paulbkoch (Collaborator)

Thanks for this PR @degenfabian. I do have some suggestions regarding your questions. There is an internal function called ebm_predict_scores that you can use to simplify this processing. You can see an example of it being used here:

https://github.com/interpretml/interpret/blob/develop/python/interpret-core/interpret/glassbox/_ebm/_ebm.py#L1453-L1473

and also here: https://github.com/interpretml/interpret/blob/develop/python/interpret-core/interpret/glassbox/_ebm/_ebm.py#L1040-L1049

The ebm_predict_scores function is designed to handle individual bagged models. To do this you would replace self.intercept_ and self.term_scores_ with slices from self.bagged_intercept_ and self.bagged_scores_. With this change you shouldn't need the preprocessor anymore, and it should also handle the condition where max_bins != max_interaction_bins.
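The per-bag slicing pattern suggested above can be sketched without interpret installed. Here `score_one_bag` is a hypothetical stand-in for the internal ebm_predict_scores (it just adds up term scores plus an intercept), and the bagged arrays are synthetic; only the shape of the loop — one slice of `bagged_intercept`/`bagged_scores` per bag, then aggregation — reflects the suggestion:

```python
import numpy as np

def score_one_bag(intercept, term_scores, term_indices):
    """Hypothetical scorer: intercept plus the looked-up score of each term.

    term_scores:  (n_terms, n_bins) per-term score tables for ONE bag
    term_indices: (n_terms, n_samples) bin index of each sample for each term
    """
    total = np.full(term_indices.shape[1], intercept, dtype=float)
    for scores, idx in zip(term_scores, term_indices):
        total += scores[idx]
    return total

n_bags, n_terms, n_bins, n_samples = 4, 3, 6, 10
rng = np.random.default_rng(1)
# Synthetic stand-ins for bagged_intercept_ and bagged_scores_.
bagged_intercept = rng.normal(size=n_bags)
bagged_scores = rng.normal(size=(n_bags, n_terms, n_bins))
term_indices = rng.integers(0, n_bins, size=(n_terms, n_samples))

# One slice of the bagged arrays per bag, then mean/std across bags.
per_bag = np.stack([
    score_one_bag(bagged_intercept[b], bagged_scores[b], term_indices)
    for b in range(n_bags)
])
mean, std = per_bag.mean(axis=0), per_bag.std(axis=0)
```

Because each bag is scored through the same code path as an ordinary prediction, no separate preprocessor handle is needed, which is what removes the max_bins == max_interaction_bins restriction.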

@degenfabian (Contributor, Author) commented Nov 3, 2024

Thank you so much for the feedback @paulbkoch, it was really helpful! I implemented it to the best of my ability, and it looks a lot cleaner now! Do you mind taking another look? As you said, max_bins != max_interaction_bins is no longer an issue with this updated version!

Attached is another plot with decision boundary on the left and prediction uncertainty on the right, calculated with the updated version of the function.
[Figure: decision boundary (left) and prediction uncertainty (right), calculated with the updated function]

@degenfabian degenfabian force-pushed the add-prediction-function-with-uncertainty branch from e760490 to 15216ae Compare November 3, 2024 13:58
@paulbkoch (Collaborator)

Thanks @degenfabian, the code looks good. Before I merge though, could you add a test?

@degenfabian (Contributor, Author)

Of course, just added them! Sorry I didn't do that earlier, and thank you so much for your valuable and insightful feedback!

Fabian Degen added 5 commits November 3, 2024 23:36
…ction

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
…when calling uncertainty prediction function

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
…s_with_uncertainty

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
…tion function

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
@degenfabian degenfabian force-pushed the add-prediction-function-with-uncertainty branch from c79aa52 to 8fcf7c4 Compare November 3, 2024 22:36
Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
@paulbkoch paulbkoch merged commit 3fdcab5 into interpretml:develop Nov 4, 2024
57 checks passed
@degenfabian degenfabian deleted the add-prediction-function-with-uncertainty branch November 4, 2024 09:02