-
-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Plot comparisons #684
Plot comparisons #684
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
To plot subplots, I have added an additional argument For example, if a modeller wanted to only consider a given interval for fig, ax = plot_comparison(
model=penguin_model,
idata=penguin_idata,
contrast_predictor=["flipper_length_mm"],
conditional={
"bill_length_mm": np.arange(40, 50, 1),
"species": ["Adelie", "Chinstrap", "Gentoo"]
},
subplot_kwargs={"main": "bill_length_mm", "group": "species", "panel": "species"},
fig_kwargs={"figsize": (10, 3), "sharey": True},
legend=False
) If I followed the same convention as fig, ax = plot_comparison(
model=penguin_model,
idata=penguin_idata,
contrast_predictor=["flipper_length_mm"],
conditional={
"main": {"bill_length_mm": np.arange(40, 50, 1)},
"group": {"species": ["Adelie", "Chinstrap", "Gentoo"]},
"panel": {"species": ["Adelie", "Chinstrap", "Gentoo"]}
}
) Now this raises the question if we should add the same |
Commit 41d4565 adds the ability to compute multi-level contrast comparisons. This is achieved by first computing all pairwise orderings of the contrast value. Then, the For example, # if the user wants to compare > 2 levels. Use the comparison function directly
comparisons(
model=model_interaction,
idata=idata_interaction,
contrast="prog",
conditional="math"
)
To do before moving to a normal PR:
|
cf68fd0
to
a647b27
Compare
In addition to the requested changes by @tomicapretto, the latest commits in this PR added the following functionality:
Here, I give a brief example of fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["livebait"] = fish_data["livebait"].astype("category")
fish_data["camper"] = fish_data["camper"].astype("category")
likelihood = bmb.Likelihood("ZeroInflatedPoisson", params=["mu", "psi"], parent="mu")
links = {"mu": "log", "psi": "logit"}
zip_family = bmb.Family("zip", likelihood, links)
priors = {"psi": bmb.Prior("Beta", alpha=3, beta=3)}
fish_model = bmb.Model(
"count ~ livebait + camper + persons + child",
fish_data,
priors=priors,
family=zip_family
)
fish_idata = fish_model.fit(draws=1000, target_accept=0.95, random_seed=1234, chains=4)
comparisons(
model=fish_model,
idata=fish_idata,
contrast="camper",
conditional=["livebait", "child", "persons"]
)
A user can pass a covariate(s) they would like to average by. For example: # marginalizes over child and persons
comparisons(
model=fish_model,
idata=fish_idata,
contrast="camper",
conditional=["livebait", "child", "persons"],
average_by="livebait"
)
Passing plot_comparison(
model=fish_model,
idata=fish_idata,
contrast="camper",
conditional=["livebait", "child", "persons"],
average_by="livebait"
) I will be adding a notebook for the docs explaining the functionality and how to use |
deleted docs/notebooks/plot_cap.ipynb#
…ontrasts 'average_by=True'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR is 95% done. It's super clean and you wrote very high-quality code. Thanks a lot for that!
Most of the changes I'm requesting are "cosmetic" or docstring related changes. The only "big" thing is the updates needed to make sure aliases on non-parent parameters still work. Don't hesitate to ask if you want help here (either to build an example/test or to implement it)
This comment sneaked past me. Thanks a lot, and for the code reviews 👍🏼 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this is good to go!
@GStechschulte let me know if I can merge :)
Looks like we can merge 👍🏼 Again, thanks for all the code reviews and insights. Much appreciated! |
This draft PR introduces
plot_comparisons
, a function for comparing the predictions made by a fitted model for different contrasts while holding all other covariates constant or at a user defined value. Inspiration was taken from the great marginaleffects R package.At a high level,
plot_comparisons
allows the modeller to define a contrastcontrast_predictor
and the covariate values to condition onconditional
. If a user does not pass specific values (for either the contrast_predictor or conditional), then a default grid of values is computed. Thus, the comparison in predictions allows a modeller to "see through the eyes of the model", i.e., the comparison on the scale of the outcome. The comparison of predictions is computed using all chains and draws of the posterior.Currently,
plot_comparisons
only allows a user to compare the predictions for 1 contrast level, i.e., how does the probability of survival change if a person moves from 1st to 3rd class givenAge = 50
andSex = [0, 1]
.In the above example, the user defined a value for each covariate. However, default values are computed if the user does not provide any for
conditional
:Another example of default values being computed for a categorical contrast predictor and numerical conditional covariate:
These examples, and further explanations can be found in the following notebook.