interpret predictions enhancements #736

Merged
merged 8 commits into from
Oct 20, 2023


GStechschulte
Collaborator

@GStechschulte GStechschulte commented Oct 13, 2023

This PR adds new functionality to the predictions and plot_predictions functions in interpret and resolves #735. Users can now

  • compute unit-level predictions
  • compute predictions on user-provided values

Previously, users could only pass a string or list of covariates to compute conditional adjusted predictions. Now, predictions has "most of" the functionality that {marginaleffects} has. Additionally, the changes result in a more consistent API across comparisons, predictions, and slopes. Each function takes a conditional argument that specifies the values the estimates are conditioned on. Depending on what the user passes, each function can now compute estimates for:

  • unit-level (user passes None)
  • user-provided (user passes a dictionary)
  • grid of values (user passes a string or list of covariates)
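To make the three cases concrete, here is a minimal, library-free sketch of how a conditional argument could be dispatched on its type. The helper build_rows is hypothetical and is not Bambi's implementation; in particular, using the unique observed values for the string/list case is a simplification assumed for illustration.

```python
from itertools import product


def build_rows(data, conditional=None):
    """Hypothetical helper: pick the rows that predictions are evaluated on."""
    if conditional is None:
        # Unit-level: one prediction per observed row.
        return [dict(row) for row in data]
    if isinstance(conditional, dict):
        # User-provided values: scalars are wrapped so every entry is a list,
        # then all combinations are formed (a pairwise grid).
        names = list(conditional)
        values = [v if isinstance(v, (list, tuple)) else [v] for v in conditional.values()]
        return [dict(zip(names, combo)) for combo in product(*values)]
    if isinstance(conditional, str):
        conditional = [conditional]
    if isinstance(conditional, list):
        # Grid of values: assumed here to be the unique observed values.
        values = [sorted({row[name] for row in data}) for name in conditional]
        return [dict(zip(conditional, combo)) for combo in product(*values)]
    raise TypeError("conditional must be None, a str, a list, or a dict")


observed = [
    {"hp": 110, "gear": "B"},
    {"hp": 93, "gear": "B"},
    {"hp": 175, "gear": "A"},
]

unit_rows = build_rows(observed)
user_rows = build_rows(observed, {"hp": [100, 120], "cyl": ["low", "medium", "high"], "gear": "A"})
grid_rows = build_rows(observed, ["hp", "gear"])
```

The point of the sketch is only that one argument selects between observed rows, an explicit grid, and a grid derived from the data.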

Below, you will find a couple of demos:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import warnings

import bambi as bmb

data = bmb.load_data('mtcars')
data["cyl"] = data["cyl"].replace({4: "low", 6: "medium", 8: "high"})
data["gear"] = data["gear"].replace({3: "A", 4: "B", 5: "C"})
data["cyl"] = pd.Categorical(data["cyl"], categories=["low", "medium", "high"], ordered=True)

model = bmb.Model("mpg ~ 0 + hp * wt + cyl + gear", data)
idata = model.fit(draws=1000, target_accept=0.95, random_seed=1234)

# unit-level predictions
bmb.interpret.predictions(
    model,
    idata,
    conditional=None
)
cyl     gear  hp   wt     estimate   lower_3.0%  upper_97.0%
medium  B     110  2.620  22.233424  20.051966   24.476544
medium  B     110  2.875  21.320402  19.196886   23.344765
low     B     93   2.320  25.901435  24.255648   27.558330
medium  A     110  3.215  18.751708  16.259737   21.293185
high    A     175  3.440  16.908354  15.261489   18.666662

Compute unit-level predictions, then average over hp and wt to obtain the marginal effects of gear and cyl:

bmb.interpret.plot_predictions(
    model,
    idata,
    conditional=None,
    average_by=["gear", "cyl"],
    fig_kwargs={"figsize": (7, 3)},
);

image
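As a rough illustration of what averaging by gear and cyl means, the sketch below groups the five unit-level point estimates from the table above and takes the mean within each group. The function average_by is a hypothetical stand-in written for this example. It only averages point estimates, whereas the real average_by argument may operate differently (for example on posterior draws, with intervals).

```python
from collections import defaultdict
from statistics import mean

# The five unit-level rows shown in the output above (estimate column only).
unit_level = [
    {"cyl": "medium", "gear": "B", "hp": 110, "wt": 2.620, "estimate": 22.233424},
    {"cyl": "medium", "gear": "B", "hp": 110, "wt": 2.875, "estimate": 21.320402},
    {"cyl": "low", "gear": "B", "hp": 93, "wt": 2.320, "estimate": 25.901435},
    {"cyl": "medium", "gear": "A", "hp": 110, "wt": 3.215, "estimate": 18.751708},
    {"cyl": "high", "gear": "A", "hp": 175, "wt": 3.440, "estimate": 16.908354},
]


def average_by(rows, by, value="estimate"):
    """Hypothetical helper: mean of `value` within each combination of `by`."""
    groups = defaultdict(list)
    for row in rows:
        key = tuple(row[name] for name in by)
        groups[key].append(row[value])
    return {key: mean(vals) for key, vals in groups.items()}


# hp and wt are averaged out; one value remains per (gear, cyl) pair.
averaged = average_by(unit_level, by=["gear", "cyl"])
```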

Build a pairwise grid from user-provided values and compute predictions on it:

bmb.interpret.plot_predictions(
    model,
    idata,
    conditional={
        "hp": [100, 120],
        "cyl": ["low", "medium", "high"],
        "gear": "A",
    },
    subplot_kwargs={"main": "hp", "group": "gear", "panel": "cyl"},
    fig_kwargs={"figsize": (10, 4), "sharey": True},
    legend=True
);

image

To do:

  • finish updating docs that use plot_predictions
  • review the added test cases in plot_predictions

@GStechschulte GStechschulte changed the title interpret predictions enhancements interpret predictions enhancements Oct 13, 2023
@GStechschulte
Collaborator Author

GStechschulte commented Oct 13, 2023

@tomicapretto although ~660 lines of code were added 😵‍💫, a lot of that comes from docstring additions, more error handling, and tests. Most of the added functionality leverages existing functions.

@tomicapretto
Collaborator

tomicapretto commented Oct 18, 2023

@GStechschulte being 100% honest I can't understand all the changes as well as you do. I do see in many cases you made things more general and you are reusing code more, which is great. So once you finish the implementation, I trust your judgment to merge this.

Two more comments:

  1. The test failure is because of some non-compatible type hint on Python 3.9. I think it's fixed here but I'm not sure.
  2. Do you think we can cut a 0.13.0 release after this? I just realized we don't have a release with the interpret submodule and I think it's quite mature at this point.
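The thread does not say which type hint was at fault, so the following is only an assumed illustration of this class of failure. One common case is the `X | None` union syntax: on Python 3.9, an annotation such as `str | list | dict | None` is evaluated when the function is defined and raises a TypeError, unless annotation evaluation is postponed. Spelling the hint with typing.Optional and typing.Union works on 3.9 and later. The names Conditional and describe are hypothetical.

```python
from typing import Dict, List, Optional, Union

# Portable spelling of "str | list[str] | dict[str, object] | None";
# the bare union-operator form would fail at definition time on Python 3.9.
Conditional = Optional[Union[str, List[str], Dict[str, object]]]


def describe(conditional: Conditional = None) -> str:
    """Hypothetical helper: name the prediction mode implied by `conditional`."""
    if conditional is None:
        return "unit-level"
    if isinstance(conditional, dict):
        return "user-provided"
    return "grid of values"
```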

@GStechschulte
Collaborator Author

being 100% honest I can't understand all the changes as well as you do

Mmm. That's not a good sign in my opinion. Is it the code diff that you are unsure about, or what has changed in the predictions functionality?

The test failure is because of some non-compatible type hint on Python 3.9. I think it's fixed in pymc-devs/pymc#6945 but I'm not sure.

It has been resolved.

Do you think we can cut a 0.13.0 release after this? I just realized we don't have a release with the interpret submodule and I think it's quite mature at this point.

Yup, I think we can go ahead with that 👍🏼

@tomicapretto
Collaborator

Mmm. That's not a good sign in my opinion. Is it the code diff that you are unsure about, or what has changed in the predictions functionality?

Oh no, I don't mean this in a bad way at all. What I'm saying is that the submodule grew a lot, for good reasons, and I'm not as familiar with everything as you are. So I can only provide a high-level review without getting deep into details because it would take much more time.

@GStechschulte GStechschulte marked this pull request as ready for review October 18, 2023 16:25
@tomicapretto
Collaborator

Looks great! Go ahead and merge if you don't plan to add anything.

@GStechschulte GStechschulte merged commit 9b6bec4 into bambinos:main Oct 20, 2023
1 of 4 checks passed
@jt-lab
Contributor

jt-lab commented Oct 20, 2023

@GStechschulte, many thanks for the work on the submodule, it really helps a lot! I can report that the preliminary average_by solutions already worked very well not only on the simulated data set but also on others. The predictions align well with the observed data.

@GStechschulte
Collaborator Author

@jt-lab Thank you and this is great to hear! 😄 We / I really appreciate you taking the time to open the issues and to give feedback. Cheers!

@GStechschulte GStechschulte deleted the predictions-enhancements branch January 21, 2024 20:19
Successfully merging this pull request may close these issues.

plot_predictions with random effects
3 participants