Skip to content

Commit

Permalink
Use number of rows from out-of-sample data in multivariate families (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Sep 26, 2024
1 parent 516d7bd commit 46d5572
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
43 changes: 39 additions & 4 deletions bambi/families/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ def transform_coords(self, model, mean):
return mean

def posterior_predictive(self, model, posterior, **kwargs):
n = model.response_component.term.data.sum(1).astype(int)
data = kwargs["data"]
if data is None:
y = model.response_component.term.data
trials = model.response_component.term.data.sum(1).astype(int)
else:
y = response_evaluate_new_data(model, data).astype(int)
trials = y.sum(1).astype(int)

# Prepend 'draw' and 'chain' dimensions
trials = trials[np.newaxis, np.newaxis, :]
dont_reshape = ["n"]
return super().posterior_predictive(model, posterior, n=n, dont_reshape=dont_reshape)
return super().posterior_predictive(model, posterior, n=trials, dont_reshape=dont_reshape)

def log_likelihood(self, model, posterior, data, **kwargs):
if data is None:
Expand Down Expand Up @@ -91,9 +100,35 @@ class DirichletMultinomial(MultivariateFamily):
SUPPORTED_LINKS = {"a": ["log"]}

def posterior_predictive(self, model, posterior, **kwargs):
n = model.response_component.term.data.sum(1).astype(int)
data = kwargs["data"]
if data is None:
y = model.response_component.term.data
trials = model.response_component.term.data.sum(1).astype(int)
else:
y = response_evaluate_new_data(model, data).astype(int)
trials = y.sum(1).astype(int)

# Prepend 'draw' and 'chain' dimensions
trials = trials[np.newaxis, np.newaxis, :]
dont_reshape = ["n"]
return super().posterior_predictive(model, posterior, n=n, dont_reshape=dont_reshape)
return super().posterior_predictive(model, posterior, n=trials, dont_reshape=dont_reshape)

def log_likelihood(self, model, posterior, data, **kwargs):
if data is None:
y = model.response_component.term.data
trials = model.response_component.term.data.sum(1).astype(int)
else:
y = response_evaluate_new_data(model, data).astype(int)
trials = y.sum(1).astype(int)

# Prepend 'draw' and 'chain' dimensions
y = y[np.newaxis, np.newaxis, :]
trials = trials[np.newaxis, np.newaxis, :]

dont_reshape = ["n"]
return super().log_likelihood(
model, posterior, data=None, y=y, n=trials, dont_reshape=dont_reshape, **kwargs
)

def get_coords(self, response):
name = get_aliased_name(response) + "_dim"
Expand Down
8 changes: 8 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,10 @@ def test_intercept_only(self, multinomial_data):
idata = self.predict_oos(model, idata, data=model.data)
self.assert_posterior_predictive(model, idata)

# Out of sample with different number of rows, see issue #845
idata = self.predict_oos(model, idata, data=model.data.sample(frac=0.8, random_state=1211))
self.assert_posterior_predictive(model, idata)

def test_numerical_predictors(self, multinomial_data):
model = bmb.Model(
"c(y1, y2, y3, y4) ~ treat + carry", multinomial_data, family="multinomial"
Expand Down Expand Up @@ -1242,6 +1246,10 @@ def test_intercept_only(self, multinomial_data):
idata = self.predict_oos(model, idata, model.data)
self.assert_posterior_predictive(model, idata)

# Out of sample with different number of rows, see issue #845
idata = self.predict_oos(model, idata, data=model.data.sample(frac=0.8, random_state=1211))
self.assert_posterior_predictive(model, idata)

def test_predictor(self, multinomial_data):
model = bmb.Model(
"c(y1, y2, y3, y4) ~ 0 + treat", multinomial_data, family="dirichlet_multinomial"
Expand Down

0 comments on commit 46d5572

Please sign in to comment.