Adding default priors for Binomial/Bernoulli families with logit link (
julianlheureux authored Aug 22, 2024
1 parent d574614 commit e7b079d
Showing 7 changed files with 415 additions and 341 deletions.
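
Not part of the commit, but for context: a minimal sketch of how the new defaults would surface to a user, assuming the standard bambi.Model API. The data frame and column names below are made up for illustration; the exact prior values come from the scaler changes shown further down.

import numpy as np
import pandas as pd
import bambi as bmb

# Hypothetical data: a binary outcome and one numeric predictor
rng = np.random.default_rng(0)
df = pd.DataFrame({"x": rng.normal(size=200)})
df["y"] = rng.binomial(n=1, p=1 / (1 + np.exp(-df["x"])))

# Bernoulli family with the default logit link
model = bmb.Model("y ~ x", df, family="bernoulli")
model.build()
print(model)  # should now report Normal(mu=0, sigma=1.5) on the Intercept
              # and roughly Normal(mu=0, sigma=1/std(x)) on the slope of x
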
2 changes: 2 additions & 0 deletions .gitignore
@@ -20,3 +20,5 @@ docs/_build
 pytest.ini
 /.quarto/
 .Rproj.user
+# Git ignore all the notebook files in the root
+*.ipynb
75 changes: 66 additions & 9 deletions bambi/priors/scaler.py
@@ -1,7 +1,15 @@
 import numpy as np
 import pymc as pm
 
-from bambi.families.univariate import Cumulative, Gaussian, StoppingRatio, StudentT, VonMises
+from bambi.families.univariate import (
+    Bernoulli,
+    Binomial,
+    Cumulative,
+    Gaussian,
+    StoppingRatio,
+    StudentT,
+    VonMises,
+)
 from bambi.model_components import ConstantComponent
 from bambi.priors.prior import Prior
 
@@ -30,8 +38,7 @@ def __init__(self, model):
     def get_intercept_stats(self):
         mu = self.response_mean
         sigma = self.STD * self.response_std
-
-        # Only adjust mu and sigma if there is at least one Normal prior for a common term.
+        # Only adjust sigma if there is at least one Normal prior for a common term.
         if self.priors:
             sigmas = np.hstack([prior["sigma"] for prior in self.priors.values()])
             x_mean = np.hstack(
@@ -58,23 +65,73 @@ def scale_response(self):
     def scale_intercept(self, term):
         if term.prior.name != "Normal":
             return
-        mu, sigma = self.get_intercept_stats()
+        # Special case for logit/probit links with bernoulli or binomial family
+        if isinstance(self.model.family, (Bernoulli, Binomial)) and self.model.family.link[
+            "p"
+        ].name in ["logit", "probit"]:
+            mu = 0
+            sigma = 1.5
+        else:
+            mu, sigma = self.get_intercept_stats()
         term.prior.update(mu=mu, sigma=sigma)
 
     def scale_common(self, term):
         if term.prior.name != "Normal":
             return
 
         # It can be greater than 1 for categorical variables
         if term.data.ndim == 1:
             mu = 0
-            sigma = self.get_slope_sigma(term.data)
+            # Special case for logit/probit links with bernoulli or binomial family
+            if isinstance(self.model.family, (Bernoulli, Binomial)) and self.model.family.link[
+                "p"
+            ].name in ["logit", "probit"]:
+                # For interaction terms, distinguish cases where all factor terms are categorical
+                if term.kind == "interaction":
+                    all_categoric = all(
+                        component.kind == "categoric" for component in term.term.components
+                    )
+                    if all_categoric:
+                        sigma = 1
+                    else:
+                        sigma = 1 / np.std(term.data, axis=0)
+                # Single categorical term
+                elif term.categorical:
+                    sigma = 1
+                # Single numerical term
+                else:
+                    sigma = 1 / np.std(term.data, axis=0)
+            # If not, fall back to the regular case
+            else:
+                sigma = self.get_slope_sigma(term.data)
         # It's a term that spans multiple columns of the design matrix
         else:
             mu = np.zeros(term.data.shape[1])
             sigma = np.zeros(term.data.shape[1])
-            # Iterate over columns in the data
-            for i, value in enumerate(term.data.T):
-                sigma[i] = self.get_slope_sigma(value)
+            # Special case for logit/probit links with bernoulli or binomial family
+            if isinstance(self.model.family, (Bernoulli, Binomial)) and self.model.family.link[
+                "p"
+            ].name in ["logit", "probit"]:
+                # Iterate over columns in the data
+                for i, value in enumerate(term.data.T):
+                    if term.kind == "interaction":
+                        # Distinguish cases where all interaction factor terms are categorical
+                        all_categoric = all(
+                            component.kind == "categoric" for component in term.term.components
+                        )
+                        if all_categoric:
+                            sigma[i] = 1
+                        # It's the std dev of the marginal numerical variable (_not_ by group)
+                        else:
+                            sigma[i] = 1 / np.std(np.sum(term.data, axis=1))
+                    # Single categorical term
+                    elif term.categorical:
+                        sigma[i] = 1
+                    # Single numerical term
+                    else:
+                        sigma[i] = 1 / np.std(term.data, axis=0)
+            else:
+                for i, value in enumerate(term.data.T):
+                    sigma[i] = self.get_slope_sigma(value)
 
         # Save and set prior
         self.priors.update({term.name: {"mu": mu, "sigma": sigma}})
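
A quick back-of-the-envelope check (not part of the commit) of what these defaults imply on the probability scale, assuming the logit link: an intercept prior of Normal(0, 1.5) keeps roughly 95% of the prior mass between probabilities of about 0.05 and 0.95, and a slope sigma of 1/std(x) means a one-standard-deviation change in a numeric predictor corresponds to about one prior standard deviation on the logit. The predictor x below is made up for illustration.

import numpy as np
from scipy.special import expit  # inverse logit

rng = np.random.default_rng(1)

# Intercept prior Normal(0, 1.5) pushed through the inverse logit link
intercept_draws = rng.normal(loc=0, scale=1.5, size=100_000)
print(np.quantile(expit(intercept_draws), [0.025, 0.5, 0.975]))  # ~[0.05, 0.50, 0.95]

# Default slope sigma for a hypothetical numeric predictor x: 1 / std(x)
x = rng.normal(loc=0, scale=2.0, size=500)
print(1 / np.std(x))  # ~0.5, i.e. one SD of x moves the logit by about one prior SD
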
72 changes: 50 additions & 22 deletions docs/notebooks/alternative_links_binary.ipynb
61 changes: 27 additions & 34 deletions docs/notebooks/hierarchical_binomial_bambi.ipynb
237 changes: 118 additions & 119 deletions docs/notebooks/logistic_regression.ipynb
21 changes: 10 additions & 11 deletions docs/notebooks/plot_comparisons.ipynb
288 changes: 142 additions & 146 deletions docs/notebooks/plot_predictions.ipynb

Large diffs are not rendered by default.
