Skip to content

Commit dc20e3e

Browse files
committed
add support for priors from data
1 parent 91aee00 commit dc20e3e

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

causalpy/pymc_models.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class PyMCModel(pm.Model):
7373
def default_priors(self):
7474
return {}
7575

76+
def priors_from_data(self, X, y) -> Dict[str, Any]:
77+
return {}
78+
7679
def __init__(
7780
self,
7881
sample_kwargs: Optional[Dict[str, Any]] = None,
@@ -122,6 +125,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
122125
# sample_posterior_predictive() if provided in sample_kwargs.
123126
random_seed = self.sample_kwargs.get("random_seed", None)
124127

128+
self.priors = {**self.priors_from_data(X, y), **self.priors}
129+
125130
self.build_model(X, y, coords)
126131
with self:
127132
self.idata = pm.sample(**self.sample_kwargs)
@@ -295,16 +300,22 @@ class WeightedSumFitter(PyMCModel):
295300
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
296301
}
297302

303+
def priors_from_data(self, X, y) -> Dict[str, Any]:
304+
n_predictors = X.shape[1]
305+
306+
return {
307+
"beta": Prior("Dirichlet", a=np.ones(n_predictors), dims="coeffs"),
308+
}
309+
298310
def build_model(self, X, y, coords):
299311
"""
300312
Defines the PyMC model
301313
"""
302314
with self:
303315
self.add_coords(coords)
304-
n_predictors = X.shape[1]
305316
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
306317
y = pm.Data("y", y[:, 0], dims="obs_ind")
307-
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
318+
beta = self.priors["beta"].create_variable("beta")
308319
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
309320
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
310321

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)