Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Feb 1, 2023
1 parent ea1a28c commit 27bd9ea
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
30 changes: 30 additions & 0 deletions tests/test_built_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import numpy as np
import pandas as pd
import pymc as pm

from bambi import math
from bambi.families import Family, Likelihood, Link
from bambi.models import Model
from bambi.priors import Prior
from bambi.terms import GroupSpecificTerm
Expand Down Expand Up @@ -769,3 +771,31 @@ def test_group_specific_splines():

model = Model("y ~ (bs(x, knots=knots, intercept=False, degree=1)|day)", data=x_check)
model.build()


def test_2d_response_no_shape():
"""
This tests whether a model where there's a single linear predictor and a response with
response.ndim > 1 works well, without Bambi causing any shape problems.
See https://github.com/bambinos/bambi/pull/629
"""

def fn(name, p, observed, **kwargs):
y = observed[:, 0].flatten()
n = observed[:, 1].flatten()
return pm.Binomial(name, p=p, n=n, observed=y, **kwargs)

likelihood = Likelihood("CustomBinomial", params=["p"], parent="p", dist=fn)
link = Link("logit")
family = Family("custom-binomial", likelihood, link)

data = pd.DataFrame(
{
"x": np.array([1.6907, 1.7242, 1.7552, 1.7842, 1.8113, 1.8369, 1.8610, 1.8839]),
"n": np.array([59, 60, 62, 56, 63, 59, 62, 60]),
"y": np.array([6, 13, 18, 28, 52, 53, 61, 60]),
}
)

model = Model("prop(y, n) ~ x", data, family=family)
model.fit(draws=10, tune=10)
5 changes: 2 additions & 3 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,10 @@ def test_multiple_outputs():
y = rng.gamma(shape, np.exp(a + b * x) / shape, N)
data_gamma = pd.DataFrame({"x": x, "y": y})


formula = Formula("y ~ x", "alpha ~ x")
model = Model(formula, data_gamma, family="gamma")
idata = model.fit(tune=100, draws=100, random_seed=1234)
# Test default target
# Test default target
plot_cap(model, idata, "x")
# Test user supplied target argument
plot_cap(model, idata, "x", "alpha")
plot_cap(model, idata, "x", "alpha")

0 comments on commit 27bd9ea

Please sign in to comment.