Skip to content

Commit

Permalink
Don't pass dims to the likelihood distribution (#629)
Browse files Browse the repository at this point in the history
* Add extra dimensions and coordinates to model with fake multivariate response

* move comment

* tuple -> list

* Transform dims after the 'transform_backend_kwargs' is called

* Move logic to its own method

* Don't pass dims to the likelihood function

* update changelog

* Add test
  • Loading branch information
tomicapretto authored Feb 2, 2023
1 parent 9ba92e1 commit 5ad17fd
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
### Maintenance and fixes

* Moved the `tests` directory to the root of the repository (#607)
* Don't pass `dims` to the response of the likelihood distribution anymore (#629)

### Documentation

Expand Down
5 changes: 2 additions & 3 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def build(self, pymc_backend, bmb_model):

# Distributional parameters. A link funciton is used.
response_aliased_name = get_aliased_name(self.term)
dims = (response_aliased_name + "_obs",)
dims = [response_aliased_name + "_obs"]
for name, component in pymc_backend.distributional_components.items():
bmb_component = bmb_model.components[name]
if bmb_component.response_term: # The response is added later
Expand All @@ -251,10 +251,9 @@ def build(self, pymc_backend, bmb_model):
# Take the inverse link function that maps from linear predictor to the parent of likelihood
linkinv = get_linkinv(self.family.link[parent], pymc_backend.INVLINKS)

# Add parent parameter and observed data
# Add parent parameter and observed data. We don't need to pass dims.
kwargs[parent] = linkinv(nu)
kwargs["observed"] = data
kwargs["dims"] = dims

# Build the response distribution
dist = self.build_response_distribution(kwargs)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_built_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import numpy as np
import pandas as pd
import pymc as pm

from bambi import math
from bambi.families import Family, Likelihood, Link
from bambi.models import Model
from bambi.priors import Prior
from bambi.terms import GroupSpecificTerm
Expand Down Expand Up @@ -769,3 +771,31 @@ def test_group_specific_splines():

model = Model("y ~ (bs(x, knots=knots, intercept=False, degree=1)|day)", data=x_check)
model.build()


def test_2d_response_no_shape():
"""
This tests whether a model where there's a single linear predictor and a response with
response.ndim > 1 works well, without Bambi causing any shape problems.
See https://github.com/bambinos/bambi/pull/629
"""

def fn(name, p, observed, **kwargs):
y = observed[:, 0].flatten()
n = observed[:, 1].flatten()
return pm.Binomial(name, p=p, n=n, observed=y, **kwargs)

likelihood = Likelihood("CustomBinomial", params=["p"], parent="p", dist=fn)
link = Link("logit")
family = Family("custom-binomial", likelihood, link)

data = pd.DataFrame(
{
"x": np.array([1.6907, 1.7242, 1.7552, 1.7842, 1.8113, 1.8369, 1.8610, 1.8839]),
"n": np.array([59, 60, 62, 56, 63, 59, 62, 60]),
"y": np.array([6, 13, 18, 28, 52, 53, 61, 60]),
}
)

model = Model("prop(y, n) ~ x", data, family=family)
model.fit(draws=10, tune=10)
5 changes: 2 additions & 3 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,10 @@ def test_multiple_outputs():
y = rng.gamma(shape, np.exp(a + b * x) / shape, N)
data_gamma = pd.DataFrame({"x": x, "y": y})


formula = Formula("y ~ x", "alpha ~ x")
model = Model(formula, data_gamma, family="gamma")
idata = model.fit(tune=100, draws=100, random_seed=1234)
# Test default target
# Test default target
plot_cap(model, idata, "x")
# Test user supplied target argument
plot_cap(model, idata, "x", "alpha")
plot_cap(model, idata, "x", "alpha")

0 comments on commit 5ad17fd

Please sign in to comment.