Plot comparisons #684

Merged: 85 commits, Jul 9, 2023

GStechschulte
Collaborator

@GStechschulte GStechschulte commented Jun 12, 2023

This draft PR introduces plot_comparisons, a function for comparing the predictions made by a fitted model for different contrasts while holding all other covariates constant or at a user-defined value. Inspiration was taken from the great marginaleffects R package.

At a high level, plot_comparisons allows the modeller to define a contrast through contrast_predictor and the covariates to condition on through conditional. If a user does not pass specific values (for either contrast_predictor or conditional), then a default grid of values is computed. The comparison of predictions thus allows a modeller to "see through the eyes of the model", i.e., to read the comparison on the scale of the outcome. The comparison of predictions is computed using all chains and draws of the posterior.
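To illustrate what computing the comparison over all chains and draws amounts to, here is a minimal standard-library sketch. It is not the code in this PR: the simulated predictions, the `hdi` helper, and the direction of the difference are all assumptions made for illustration.

```python
import math
import random


def hdi(samples, prob=0.94):
    """Return the narrowest interval that contains `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    window = max(1, math.ceil(prob * n))  # number of points inside the interval
    # Slide a window of `window` points over the sorted draws, keep the narrowest
    starts = range(n - window + 1)
    best = min(starts, key=lambda i: ordered[i + window - 1] - ordered[i])
    return ordered[best], ordered[best + window - 1]


rng = random.Random(0)
n_chains, n_draws = 4, 500

# Hypothetical posterior predictions of P(survival) at two contrast levels,
# one value per (chain, draw) pair, flattened into a single list.
pred_first = [rng.gauss(0.60, 0.05) for _ in range(n_chains * n_draws)]
pred_third = [rng.gauss(0.25, 0.05) for _ in range(n_chains * n_draws)]

# The comparison is taken draw by draw, so posterior uncertainty carries through
diff = [third - first for first, third in zip(pred_first, pred_third)]
estimate = sum(diff) / len(diff)
lower, upper = hdi(diff)
print(f"estimate={estimate:.3f}, hdi_3%={lower:.3f}, hdi_97%={upper:.3f}")
```

In a real model the two sets of predictions come from the same posterior draws rather than from independent simulations, which is what makes the draw-by-draw difference meaningful.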

Currently, plot_comparisons only allows a user to compare the predictions for a single contrast (one pair of levels), e.g., how does the probability of survival change if a person moves from 1st to 3rd class given Age = 50 and SexCode = [0, 1].

fig, ax = plot_comparison(
    model=titanic_model,
    idata=titanic_idata,
    contrast_predictor={"PClass": [1, 3]},
    conditional={"Age": [50], "SexCode": [0, 1]}
)

image

In the above example, the user defined a value for each covariate. However, default values are computed if the user does not provide any for conditional:

fig, ax = plot_comparison(
    model=titanic_model,
    idata=titanic_idata,
    contrast_predictor={"PClass": [1, 3]},
    conditional=["Age", "SexCode"]
)

image
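As a rough picture of how user-passed and default values could be combined into a prediction grid, here is a small standard-library sketch. The default rule (an evenly spaced grid for numeric columns, unique levels otherwise), the helper names `default_values` and `build_grid`, and the toy data are all hypothetical and not necessarily what this PR implements.

```python
from itertools import product


def default_values(column, num=5):
    """Hypothetical default rule: even grid for numbers, unique levels otherwise."""
    if all(isinstance(v, (int, float)) for v in column):
        lo, hi = min(column), max(column)
        step = (hi - lo) / (num - 1)
        return [lo + i * step for i in range(num)]
    return sorted(set(column))


def build_grid(data, contrast, conditional):
    """Cross every contrast value with every combination of conditional values."""
    if not isinstance(conditional, dict):
        # Only covariate names were passed, so fall back to default values
        conditional = {name: default_values(data[name]) for name in conditional}
    values = {**contrast, **conditional}
    names = list(values)
    return [dict(zip(names, combo)) for combo in product(*values.values())]


# Toy data with made-up column names and values
data = {"Age": [2, 30, 50, 71], "Sex": ["female", "male", "male", "female"]}

grid = build_grid(data, contrast={"PClass": [1, 3]}, conditional=["Age", "Sex"])
print(len(grid))  # 2 contrast values x 5 ages x 2 sexes
```

Each row of the grid is one covariate combination at which the model would be asked for a prediction; the comparison is then the difference between rows that differ only in the contrast value.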

Another example, where default values are computed for a categorical contrast predictor and a numerical conditional covariate:

fig, ax, comparisons_df, contrast_df, idata = plot_comparison(
    model=fish_model,
    idata=fish_idata,
    contrast_predictor="livebait",
    conditional="persons"
) 

image

These examples and further explanations can be found in the following notebook.

@GStechschulte
Collaborator Author

GStechschulte commented Jun 15, 2023

To plot subplots, I have added an additional argument, subplot_kwargs, where the user can specify the main, group, and panel covariates. I believe this additional argument is needed because with plot_comparisons the user can pass their own values for the covariates in conditional as a dict. If we allowed the same subplot figure control as plot_cap, the user would have to pass a nested dictionary, and more code would be needed to "unnest" the dictionary to access the desired key-value pairs.

For example, if a modeller wanted to only consider a given interval for bill_length_mm and unique species:

fig, ax = plot_comparison(
    model=penguin_model,
    idata=penguin_idata,
    contrast_predictor=["flipper_length_mm"],
    conditional={
        "bill_length_mm": np.arange(40, 50, 1),
        "species": ["Adelie", "Chinstrap", "Gentoo"]
    },
    subplot_kwargs={"main": "bill_length_mm", "group": "species", "panel": "species"},
    fig_kwargs={"figsize": (10, 3), "sharey": True},
    legend=False
)

image

If I followed the same convention as plot_cap, this would require something like:

fig, ax = plot_comparison(
    model=penguin_model,
    idata=penguin_idata,
    contrast_predictor=["flipper_length_mm"],
    conditional={
        "main": {"bill_length_mm": np.arange(40, 50, 1)}, 
        "group": {"species": ["Adelie", "Chinstrap", "Gentoo"]}, 
        "panel": {"species": ["Adelie", "Chinstrap", "Gentoo"]}
    }
)

This raises the question of whether we should add the same subplot_kwargs to plot_cap to stay consistent.
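To make the "unnest" cost concrete, here is a minimal sketch of the extra step the nested convention would need before the values could be used. The helper name `unnest_conditional` is hypothetical and is not part of this PR.

```python
def unnest_conditional(nested):
    """Split a role-keyed `conditional` into flat values plus a role mapping."""
    values, roles = {}, {}
    for role, covariate in nested.items():  # role is "main", "group" or "panel"
        for name, vals in covariate.items():
            values[name] = list(vals)
            roles[role] = name
    return values, roles


nested = {
    "main": {"bill_length_mm": range(40, 50)},
    "group": {"species": ["Adelie", "Chinstrap", "Gentoo"]},
    "panel": {"species": ["Adelie", "Chinstrap", "Gentoo"]},
}

values, roles = unnest_conditional(nested)
print(values)
print(roles)
```

The two results correspond to what the flat design asks for directly: `values` plays the role of conditional and `roles` the role of subplot_kwargs, so the user states each covariate's values once even when it is used for both group and panel.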

@GStechschulte
Collaborator Author

GStechschulte commented Jun 16, 2023

Commit 41d4565 adds the ability to compute multi-level contrast comparisons. This is achieved by first computing all pairwise combinations of the contrast values. Then, the xr.DataArray is indexed using each pair.

For example, the contrast prog has 3 values (General, Vocational, Academic). Each pairwise contrast is shown in the output below.

# To compare more than 2 levels, use the comparisons function directly
comparisons(
    model=model_interaction,
    idata=idata_interaction,
    contrast="prog",
    conditional="math"
)
math       term  contrast                    estimate   hdi_3%      hdi_97%
1.000000   prog  ['Academic', 'General']     5.705757   -2.780949   15.583213
1.000000   prog  ['Academic', 'Vocational']  5.678390   -2.237369   15.903185
1.000000   prog  ['General', 'Vocational']   5.651256   -2.166023   15.727247
...        ...   ...                         ...        ...         ...
99.000000  prog  ['Academic', 'General']     -4.931466  -10.382592  -0.191101
99.000000  prog  ['Academic', 'Vocational']  -4.909654  -10.369010  -0.098217
99.000000  prog  ['General', 'Vocational']   -4.887939  -10.395197  -0.081602
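The pairing step described above can be sketched with itertools.combinations. This is only an illustration: the draws below are made up, plain lists stand in for the xr.DataArray, and the direction of the difference (second level minus first) is an assumption.

```python
from itertools import combinations
from statistics import mean

levels = ["General", "Vocational", "Academic"]

# Every unordered pair of contrast levels, in sorted order
pairs = list(combinations(sorted(levels), 2))

# Hypothetical posterior draws of the prediction at each level
draws = {
    "Academic": [52.0, 54.0, 53.0],
    "General": [48.0, 49.0, 47.0],
    "Vocational": [45.0, 46.0, 44.0],
}

# Index the draws by each pair and take the draw-by-draw difference
comparison = {
    (a, b): mean(y - x for x, y in zip(draws[a], draws[b])) for a, b in pairs
}

for pair, estimate in comparison.items():
    print(pair, estimate)
```

With three levels this yields three pairs, matching the three rows per value of math in the output above.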

To do before moving to a normal PR:

  • shape handling for comparisons where a user passes > 1 level
  • allow plotting of subplots (panels)
  • docstrings and type hints
  • optional return of contrast_df (a dataframe containing descriptive statistics of the contrast comparison): dropped, the user should use comparisons if they want a dataframe returned
  • refactor plot_cap code so it works
  • plot comparisons of other model parameters? (to be added in a later PR)
  • add and run tests, and black

@GStechschulte GStechschulte marked this pull request as ready for review June 19, 2023 18:36
Review threads (resolved):

  • bambi/plots/__init__.py (outdated)
  • bambi/plots/create_data.py (outdated)
  • bambi/plots/create_data.py (outdated)
  • bambi/plots/effects.py (outdated)
  • bambi/plots/effects.py (outdated)
  • bambi/plots/effects.py (outdated)
  • bambi/plots/plot_types.py
  • bambi/plots/plotting.py (outdated)
  • bambi/plots/plotting.py
  • bambi/plots/utils.py (outdated)

@GStechschulte
Collaborator Author

In addition to the requested changes by @tomicapretto, the latest commits in this PR added the following functionality:

  • average_by argument in comparisons and plot_comparisons
  • subplot_kwargs in plot_cap to follow the same design as plot_comparisons
  • organised the test_plots.py file into three classes: (1) TestCommon, which tests arguments common to plot_cap and plot_comparisons (usually Matplotlib figure arguments), (2) TestCap, which tests arguments specific to plot_cap, and (3) TestComparisons, which tests arguments specific to plot_comparisons

Here is a brief example of average_by. First, without averaging:

import bambi as bmb
import pandas as pd

fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["livebait"] = fish_data["livebait"].astype("category")
fish_data["camper"] = fish_data["camper"].astype("category")

likelihood = bmb.Likelihood("ZeroInflatedPoisson", params=["mu", "psi"], parent="mu")
links = {"mu": "log", "psi": "logit"}
zip_family = bmb.Family("zip", likelihood, links)
priors = {"psi": bmb.Prior("Beta", alpha=3, beta=3)}

fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    priors=priors,
    family=zip_family
)

fish_idata = fish_model.fit(draws=1000, target_accept=0.95, random_seed=1234, chains=4)

comparisons(
    model=fish_model,
    idata=fish_idata,
    contrast="camper",
    conditional=["livebait", "child", "persons"]
)
term    contrast    livebait  child  persons  estimate  hdi_0.03%  hdi_0.97%
camper  (0.0, 1.0)  0.0       0.0    1.0      0.185616  0.086099   0.291901
camper  (0.0, 1.0)  0.0       0.0    2.0      0.443242  0.212235   0.679702
camper  (0.0, 1.0)  0.0       0.0    4.0      2.542203  1.309796   3.867047
...     ...         ...       ...    ...      ...       ...        ...
camper  (0.0, 1.0)  1.0       3.0    1.0      0.016145  0.006866   0.025888
camper  (0.0, 1.0)  1.0       3.0    2.0      0.038481  0.016707   0.060103
camper  (0.0, 1.0)  1.0       3.0    4.0      0.219881  0.112066   0.343412

A user can pass one or more covariates to average by. For example:

# marginalizes over child and persons
comparisons(
    model=fish_model,
    idata=fish_idata,
    contrast="camper",
    conditional=["livebait", "child", "persons"],
    average_by="livebait"
)
term    contrast    livebait  estimate  hdi_0.03%  hdi_0.97%
camper  (0.0, 1.0)  0.0       0.445599  0.223418   0.683028
camper  (0.0, 1.0)  1.0       2.442100  1.830418   3.058506

Passing livebait to average_by groups the comparisons by the levels of livebait, [0, 1], and marginalises over the other covariates, child and persons, to get the average estimate and uncertainty of the camper contrast at each level of livebait. This can also be plotted:

plot_comparison(
    model=fish_model,
    idata=fish_idata,
    contrast="camper",
    conditional=["livebait", "child", "persons"],
    average_by="livebait"
)

image
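The grouping behind average_by can be pictured with a small standard-library sketch. The rows and the `average_by` helper below are made up for illustration and only average point estimates; to get the uncertainty as well, the same grouping would presumably be applied to the posterior draws before summarising, which is an assumption about the implementation rather than a description of it.

```python
from collections import defaultdict
from statistics import mean

# Hypothetical comparison rows, one per combination of conditional covariates
rows = [
    {"livebait": 0, "child": 0, "persons": 1, "estimate": 0.25},
    {"livebait": 0, "child": 0, "persons": 2, "estimate": 0.75},
    {"livebait": 1, "child": 0, "persons": 1, "estimate": 1.0},
    {"livebait": 1, "child": 0, "persons": 2, "estimate": 3.0},
]


def average_by(rows, by):
    """Average the estimates within each level of `by`, marginalising the rest."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[by]].append(row["estimate"])
    return {level: mean(values) for level, values in groups.items()}


print(average_by(rows, "livebait"))
```

Each level of livebait ends up with a single averaged estimate, which is why the averaged table has one row per level instead of one row per covariate combination.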

I will be adding a notebook for the docs explaining the functionality and how to use plot_comparisons in the coming week.

@GStechschulte GStechschulte marked this pull request as draft June 28, 2023 20:26
@GStechschulte GStechschulte marked this pull request as ready for review June 29, 2023 16:46
Collaborator

@tomicapretto tomicapretto left a comment

The PR is 95% done. It's super clean and you wrote very high-quality code. Thanks a lot for that!

Most of the changes I'm requesting are "cosmetic" or docstring related changes. The only "big" thing is the updates needed to make sure aliases on non-parent parameters still work. Don't hesitate to ask if you want help here (either to build an example/test or to implement it)

Review threads (resolved):

  • bambi/plots/__init__.py (outdated)
  • bambi/plots/create_data.py (outdated)
  • bambi/plots/create_data.py
  • bambi/plots/effects.py (outdated)
  • bambi/plots/effects.py (outdated)
  • bambi/plots/effects.py (outdated)
  • bambi/plots/plot_types.py
  • bambi/plots/plot_types.py
  • bambi/plots/plotting.py (outdated)

@GStechschulte
Collaborator Author

> The PR is 95% done. It's super clean and you wrote very high-quality code. Thanks a lot for that!
>
> Most of the changes I'm requesting are "cosmetic" or docstring related changes. The only "big" thing is the updates needed to make sure aliases on non-parent parameters still work. Don't hesitate to ask if you want help here (either to build an example/test or to implement it)

This comment slipped past me. Thanks a lot, and thanks for the code reviews 👍🏼

Collaborator

@tomicapretto tomicapretto left a comment

Looks like this is good to go!

@GStechschulte let me know if I can merge :)

@GStechschulte
Collaborator Author

> Looks like this is good to go!
>
> @GStechschulte let me know if I can merge :)

Looks like we can merge 👍🏼 Again, thanks for all the code reviews and insights. Much appreciated!

3 participants