plot_predictions with random effects #735

Closed
jt-lab opened this issue Oct 12, 2023 · 11 comments · Fixed by #736

Comments

@jt-lab
Contributor

jt-lab commented Oct 12, 2023

@GStechschulte, we have played around further with plot_predictions and came across some behavior we don't understand. It might be an issue with the handling of random effects, hence I describe it here:

For a model with random effects (p(correct, count) ~ 0 + factor3:factor2 + factor1 + (0 + factor3:factor2 | individual)), the predictions seem to be off compared to the data points (e.g. see Factor1=A, Factor2=orange, Factor3=1 but also others):

[Image: w_rf (predictions from the model with random effects, overlaid on the data points)]

Also, in factor C, the hdi bars get a bit smaller. It's barely visible here, but in another data set (which I cannot share), they are about 3 times smaller than those of levels A and B, apparently without any reason related to the data.

The same model but without the random effects produces predictions which are pretty close to the empirical means:

[Image: wo_rf (predictions from the model without random effects, overlaid on the data points)]

Many thanks in advance!

Code to reproduce this and the dataset:

import pandas as pd
import bambi as bmb
import seaborn as sns

data = pd.read_csv('simulated_data_order_prob2.csv')

priors = {
    "factor3:factor2": bmb.Prior("Normal", mu=0, sigma=1),
    "factor1": bmb.Prior("Normal", mu=0, sigma=1),
    "factor3:factor2|individual": bmb.Prior("Normal", mu=0, sigma=bmb.Prior("Gamma", alpha=3, beta=3))
}

model = bmb.Model(
    "p(correct, count) ~ 0 + factor3:factor2 + factor1  + (0 + factor3:factor2 | individual)",
    data,
    family="binomial",
    categorical=["factor3", "individual", "factor1", "factor2"],
    priors=priors,
    noncentered=False)

idata = model.fit(tune=2000, draws=2000, random_seed=123, init='adapt_diag',
                  target_accept=0.9, idata_kwargs={'log_likelihood':True})

data['report frequency (%)'] = data['correct'] / data['count']

g = sns.catplot(data=data, kind='strip', x='factor3', y='report frequency (%)', 
                hue='factor2', col='factor1', jitter=False, dodge=True)


axs = bmb.interpret.plot_predictions(
    model=model, 
    idata=idata, 
    covariates=["factor3", "factor2", "factor1"],
    pps=False,
    legend=True,
    fig_kwargs={"figsize": (20, 8), "sharey": True},
    prob=0.95,
    ax=g.axes
)

Data Set:

simulated_data_order_prob2.csv

@GStechschulte
Collaborator

GStechschulte commented Oct 12, 2023

Hey @jt-lab thanks a lot for raising the issue and sharing the code / dataset!

At a quick glance, this is because in the random effects model, individual is included as a term in the Bambi model. Since this term is not specified in plot_predictions, a default value of individual=0 is used (since it is categorical, Bambi takes the mode).

bmb.interpret.predictions(
    model=model, 
    idata=idata, 
    covariates=["factor3", "factor2", "factor1"],
)
factor3  factor2  factor1  individual  estimate  lower_3.0%  upper_97.0%
1        X        A        0           0.370441  0.331295    0.407852
2        X        A        0           0.545766  0.504590    0.586051
1        Y        A        0           0.301027  0.268440    0.337534
...      ...      ...      ...         ...       ...         ..
2        X        C        0           0.453661  0.426735    0.480137
1        Y        C        0           0.229191  0.208711    0.249756
2        Y        C        0           0.383008  0.360000    0.407484

Thus, your first plot above seems to be comparing individual=0 predictions with the population level data points. Then, since the second model does not include a random effect (and thus no individual term), the predictions more closely match the population level data points (as you stated).
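To make the difference concrete, here is a minimal NumPy sketch. It is not Bambi's internal code: the population-level logit and the per-individual offsets are made-up values chosen only for illustration. It compares a prediction conditional on one individual (what you get when individual is fixed at a single default level) with the average of the per-individual predictions.

```python
import numpy as np


def inv_logit(x):
    """Map values on the logit scale to probabilities."""
    return 1.0 / (1.0 + np.exp(-x))


# Hypothetical values (assumptions, not taken from the fitted model):
# one population-level effect on the logit scale and one offset per individual.
population_logit = 0.0
individual_offsets = np.array([-1.0, -0.5, 0.0, 0.5, 2.0])

# Predicted probability of a correct response for each individual.
per_individual = inv_logit(population_logit + individual_offsets)

# Prediction conditional on the first individual only (individual=0).
conditional_on_first = per_individual[0]

# Prediction averaged over all individuals.
averaged = per_individual.mean()

print(f"individual 0: {conditional_on_first:.3f}")
print(f"averaged:     {averaged:.3f}")
```

Whenever the chosen individual sits away from the group average, the two numbers differ, which is the kind of gap visible between the predictions and the pooled data points in the first plot.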

I have a couple more comments about this, but don't have the time this morning. I will communicate here in the next day. Thanks!

@GStechschulte GStechschulte changed the title plot_predicitons with random effects plot_predictions with random effects Oct 12, 2023
@jt-lab
Contributor Author

jt-lab commented Oct 12, 2023

@GStechschulte, thanks a lot! That makes sense. So perhaps we should predict a single out-of-sample individual in this situation.

A quick update: the hdi bar issue described above was unrelated to this...

Looking forward to your further comments!

Many thanks for the support!

@bambinos bambinos deleted a comment from jt-lab Oct 12, 2023
@GStechschulte
Collaborator

GStechschulte commented Oct 12, 2023

Of course, anytime! I appreciate getting the feedback. Regarding

the hdi bar issue described above was unrelated to this...

I am looking into this. Thanks for also pointing this out.

plot_predictions was the first interpret plotting function, developed in early 2023. Over summer 2023, I added plot_comparisons and plot_slopes. In the latter two functions, it is possible to predict using the observed (empirical) data. These are "unit-level" predictions. Additionally, it is possible to pass your own values and/or to create a grid of values to use as the data fed to the model to perform predictions.

However, plot_predictions does not have this additional functionality yet. In marginaleffects, it does. Thus, once I add the ability to perform unit-level predictions, you could achieve the "desired / correct" plot with plot_predictions. Below, I give you an example in {marginaleffects} using your data and the random effects model:

library(brms)
library(marginaleffects)

dat <- read.csv("simulated_data_order_prob2.csv")

dat$factor3 = as.factor(dat$factor3)
dat$factor2 = as.factor(dat$factor2)
dat$factor1 = as.factor(dat$factor1)
formula <- bf(correct | trials(count) ~ 0 + factor3:factor2 + factor1 + (0 + factor3:factor2 | individual))
model <- brm(formula, data = dat, family = binomial)

# unit-level predictions averaged over individuals
plot_predictions(
  model,
  by=c("factor3", "factor2", "factor1")
)

image

The plot shows the marginal estimates, averaged over all individuals.

@tomicapretto what do you think? All the pieces are there in interpret. It is just a matter of putting it together. Then, the predictions and plot_predictions functionality will be "close to" {marginaleffects}. Also, the function calls will be similar to comparisons and slopes resulting in a more standard API.

@GStechschulte
Collaborator

Since the majority of the functions are already there, here's a working demo in Bambi:

model = bmb.Model(
    "p(correct, count) ~ 0 + factor3:factor2 + factor1  + (0 + factor3:factor2 | individual)",
    data,
    family="binomial",
    categorical=["factor3", "individual", "factor1", "factor2"],
    priors=priors,
    noncentered=False
)

idata = model.fit(tune=2000, draws=2000, random_seed=123, init='adapt_diag',
                  target_accept=0.9, idata_kwargs={'log_likelihood':True})

bmb.interpret.plot_predictions(
    model=model, 
    idata=idata,
    average_by=["factor3", "factor2", "factor1"],
    fig_kwargs={"figsize": (12, 4), "sharey": True},
)

image

Side note: I know the {marginaleffects} plot and the Bambi plot aren't the same. I would take the {marginaleffects} plot above with caution as I quickly did that and there were divergences, etc. It was used as an example implementation.

@jt-lab
Contributor Author

jt-lab commented Oct 12, 2023

Thank you so much, @GStechschulte!

I just came here to thank you for your previous post with the explanations and examples! But this is of course even greater!
So do I get this right that using the average_by specification instead of the covariates does the trick?

By the way, what do you think of these ideas:

  • Add (optional) verbosity when plots are created. E.g., print out some info when defaults are applied or averaging etc. happens implicitly
  • Add possibility to plot observations along with predictions (like I did manually above).

If you like I could try implementing these!

Many thanks again

@GStechschulte
Collaborator

GStechschulte commented Oct 12, 2023

@jt-lab thank you! 😄

So do I get this right that using the average_by specification instead of the covariates does the trick?

Right. Not passing any variables into covariates results in unit-level predictions. Then, since a prediction is made for each individual, a pandas groupby(average_by).mean() is applied: it groups by the factors (the variables passed to average_by) and takes the mean to compute the marginal effect for each combination of factor3, factor2, factor1.
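As a rough illustration of that averaging step, here is a small pandas sketch. The unit_level table and its estimate values are invented for the example, and the real implementation may well work on posterior draws rather than on a single point estimate per row; this only shows the group-then-average idea.

```python
import pandas as pd

# Hypothetical unit-level predictions: one row per individual and factor
# combination (values are made up for illustration).
unit_level = pd.DataFrame(
    {
        "individual": [0, 0, 1, 1, 2, 2],
        "factor3": ["1", "2", "1", "2", "1", "2"],
        "factor2": ["X"] * 6,
        "factor1": ["A"] * 6,
        "estimate": [0.30, 0.50, 0.40, 0.60, 0.50, 0.70],
    }
)

# The variables to average by; everything else (here: individual) is
# averaged out.
average_by = ["factor3", "factor2", "factor1"]

# Group by the factors and take the mean of the per-individual estimates.
marginal = unit_level.groupby(average_by, as_index=False)["estimate"].mean()

print(marginal)
```

Each row of marginal is then the mean over individuals for one combination of the factors.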

Add (optional) verbosity when plots are created. E.g., print out some info when defaults are applied or averaging etc. happens implicitly

I have come to realise that unless the user really studies the docs, it is difficult to understand what exactly is being created and computed. Thus, I do like the idea to be more transparent (optionally). @tomicapretto do you have any thoughts?

Add possibility to plot observations along with predictions (like I did manually above).

I had not thought of this until I saw you do it. At the moment, I would like to limit the amount of plotting code we introduce (Matplotlib is not the most fun to develop with, and it is difficult to write tests for the content of the plots). Unless more users ask for this, I personally won't pursue it. Nonetheless, I liked your solution with seaborn 😄

Thanks for the ideas! 👍🏼

@jt-lab
Contributor Author

jt-lab commented Oct 12, 2023

@GStechschulte,

We just wanted to try the average_by solution, but there is no average_by argument in plot_predictions. I also don't see it in the docs or the code on GitHub. Even in your fork it's not there. So maybe I misunderstood that this was an already existing workaround? Or is there some secret branch it is on? :-D

At the moment, I would like to limit the amount of plotting code we introduce (Matplotlib is not the most fun to develop with, and it is difficult to write tests for the content of the plots).

I see, makes sense.

Nonetheless, I liked your solution with seaborn 😄

Yes, that works okay-ish. We had some order-related trouble again, as seaborn creates the order depending on the order in the data (if not specified otherwise) and plot_predictions uses a different order. So one has to watch out for that.
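One way to guard against this kind of mismatch is to fix the level order explicitly instead of relying on order of appearance. The sketch below uses made-up data and assumes that sorted order is the one you want to match; check which order the prediction plot actually uses before relying on it. The resulting lists can be passed to seaborn's order and hue_order arguments.

```python
import pandas as pd

# Made-up data in which the levels do not appear in sorted order.
data = pd.DataFrame(
    {
        "factor3": ["2", "1", "2", "1"],
        "factor2": ["Y", "Y", "X", "X"],
        "value": [0.40, 0.30, 0.50, 0.35],
    }
)

# Order of appearance, which is what a plot falls back to for plain
# string columns when no order is given.
appearance_order = list(data["factor3"].unique())

# Explicit orders (assumption: sorted order is the desired one).
factor3_order = sorted(data["factor3"].unique())
factor2_order = sorted(data["factor2"].unique())

# Storing the columns as ordered categoricals keeps the order attached
# to the data itself.
data["factor3"] = pd.Categorical(data["factor3"], categories=factor3_order, ordered=True)
data["factor2"] = pd.Categorical(data["factor2"], categories=factor2_order, ordered=True)

# The same lists can be handed to seaborn, for example:
# sns.catplot(data=data, x="factor3", y="value", hue="factor2",
#             order=factor3_order, hue_order=factor2_order)
print(appearance_order, factor3_order, factor2_order)
```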

@GStechschulte
Collaborator

@jt-lab you are too quick for me 😉 haha. I just pushed these changes in this branch.

Please note I still need to add error handling and tests to ensure everything works.

Cheers!

@jt-lab
Contributor Author

jt-lab commented Oct 12, 2023

@jt-lab you are too quick for me 😉 haha. I just pushed these changes in this branch.

Please note I still need to add error handling and tests to ensure everything works.

Cheers!

Thanks and sorry for the impatience 😬

I thought you remembered an existing workaround

@GStechschulte
Collaborator

@jt-lab I just pushed some more changes to that branch and opened draft PR #736

@tomicapretto
Collaborator

@tomicapretto what do you think? All the pieces are there in interpret. It is just a matter of putting it together. Then, the predictions and plot_predictions functionality will be "close to" {marginaleffects}. Also, the function calls will be similar to comparisons and slopes resulting in a more standard API.

Just seeing the issue. Thanks for proactively writing the code and opening the PR :)

I have come to realise that unless the user really studies the docs, it is difficult to understand what exactly is being created and computed. Thus, I do like the idea to be more transparent (optionally). @tomicapretto do you have any thoughts?

I like the idea too! I only want to make 2 points

  • It should be easy to silence
  • It shouldn't be hard to maintain

I think one possible approach is to have a configuration instance that comes with a default option and users can do something like bmb.config.interpret_messages = False. This is what we have in formulae, check https://github.com/bambinos/formulae/blob/master/formulae/config.py and https://github.com/bambinos/formulae/blob/master/tests/test_config.py
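A minimal sketch of what such a configuration object could look like is below. It is purely illustrative: the Config class, the notify helper, and the validation rules are hypothetical and are not the actual code in Bambi or formulae; only the option name interpret_messages comes from the suggestion above.

```python
class Config:
    """Hypothetical settings holder with a single boolean option."""

    # Known options and their default values (assumed for this sketch).
    _DEFAULTS = {"interpret_messages": True}

    def __init__(self):
        for name, default in self._DEFAULTS.items():
            super().__setattr__(name, default)

    def __setattr__(self, name, value):
        # Reject unknown options so that typos fail loudly.
        if name not in self._DEFAULTS:
            raise KeyError(f"'{name}' is not a valid configuration option")
        if not isinstance(value, bool):
            raise ValueError(f"'{name}' must be True or False")
        super().__setattr__(name, value)


config = Config()


def notify(message):
    """Print a message only if messages are enabled; return whether it printed."""
    if config.interpret_messages:
        print(message)
        return True
    return False


# Messages are on by default and can be silenced with a single assignment.
notify("Default computed for unspecified variable: individual")
config.interpret_messages = False
notify("This message is suppressed")
```

Keeping the switch in one object means every message goes through a single check, which speaks to both points: one line to silence, and one place to maintain.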

Add possibility to plot observations along with predictions (like I did manually above)

@jt-lab I agree with @GStechschulte's response here. It's already a lot to maintain the existing functionality. Anything related to observed data should be deferred to the user.

@GStechschulte let's continue any related discussion in the PR where you implement the changes :)


3 participants