Skip to content

Commit

Permalink
Handle multivariate responses with HSGP (#856)
Browse files: browse the repository at this point in the history
* Make HSGP terms aware of multivariate responses

* Make sure two-dimensional outputs have two dims

* Remove redundant classes from checks

* Remove prints and add comments

* Remove commented code
  • Loading branch information
tomicapretto authored Dec 16, 2024
1 parent 5d772ff commit 1559a97
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
6 changes: 3 additions & 3 deletions bambi/backend/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def build(self, pymc_backend, bmb_model):
self.build_intercept(bmb_model)
self.build_offsets()
self.build_common_terms(pymc_backend, bmb_model)
self.build_hsgp_terms(pymc_backend)
self.build_hsgp_terms(bmb_model, pymc_backend)
self.build_group_specific_terms(pymc_backend, bmb_model)

def build_intercept(self, bmb_model):
Expand Down Expand Up @@ -109,7 +109,7 @@ def build_common_terms(self, pymc_backend, bmb_model):
# Add term to linear predictor
self.output += pt.dot(data, coefs)

def build_hsgp_terms(self, pymc_backend):
def build_hsgp_terms(self, bmb_model, pymc_backend):
"""Add HSGP (Hilbert-Space Gaussian Process approximation) terms to the PyMC model.
The linear predictor 'X @ b + Z @ u' can be augmented with non-parametric HSGP terms
Expand All @@ -120,7 +120,7 @@ def build_hsgp_terms(self, pymc_backend):
for name, values in hsgp_term.coords.items():
if name not in pymc_backend.model.coords:
pymc_backend.model.add_coords({name: values})
self.output += hsgp_term.build()
self.output += hsgp_term.build(bmb_model)

def build_group_specific_terms(self, pymc_backend, bmb_model):
"""Add group-specific (random or varying) terms to the PyMC model
Expand Down
33 changes: 20 additions & 13 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
make_weighted_distribution,
GP_KERNELS,
)
from bambi.families.multivariate import MultivariateFamily, Multinomial, DirichletMultinomial
from bambi.families.multivariate import MultivariateFamily
from bambi.families.univariate import Categorical, Cumulative, StoppingRatio
from bambi.priors import Prior

Expand Down Expand Up @@ -234,22 +234,16 @@ def build(self, pymc_backend, bmb_model):
# Auxiliary parameters and data
kwargs = {"observed": data, "dims": ("__obs__",)}

if isinstance(
self.family,
(
MultivariateFamily,
Categorical,
Cumulative,
StoppingRatio,
Multinomial,
DirichletMultinomial,
),
):
if isinstance(self.family, (MultivariateFamily, Categorical, Cumulative, StoppingRatio)):
response_term = bmb_model.response_component.term
response_name = response_term.alias or response_term.name
dim_name = response_name + "_dim"
pymc_backend.model.add_coords({dim_name: response_term.levels})
dims = ("__obs__", dim_name)

# For multivariate families, the outcome variable has two dimensions too.
if isinstance(self.family, MultivariateFamily):
kwargs["dims"] = dims
else:
dims = ("__obs__",)

Expand Down Expand Up @@ -447,7 +441,7 @@ def __init__(self, term):
if self.term.by_levels is not None:
self.coords[f"{self.term.alias}_by"] = self.coords.pop(f"{self.term.name}_by")

def build(self):
def build(self, spec):
# Get the name of the term
label = self.name

Expand Down Expand Up @@ -507,6 +501,19 @@ def build(self):
phi = phi.eval()

# Build weights coefficient
# Handle the case where the outcome is multivariate
if isinstance(spec.family, (MultivariateFamily, Categorical)):
# Append the dims of the response variables to the coefficient and contribution dims
# In general:
# coeff_dims: ('weights_dim', ) -> ('weights_dim', f'{response}_dim')
# contribution_dims: ('__obs__', ) -> ('__obs__', f'{response}_dim')
response_dims = tuple(spec.response_component.term.coords)
coeff_dims = coeff_dims + response_dims
contribution_dims = contribution_dims + response_dims

# Append a dimension to sqrt_psd: ('weights_dim', ) -> ('weights_dim', 1)
sqrt_psd = sqrt_psd[:, np.newaxis]

if self.term.centered:
coeffs = pm.Normal(f"{label}_weights", sigma=sqrt_psd, dims=coeff_dims)
else:
Expand Down

0 comments on commit 1559a97

Please sign in to comment.