diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index feb9d023..514fc4df 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -501,13 +501,17 @@ def build(self, spec): phi = phi.eval() # Build weights coefficient - # FIXME: this is a hot-fix, not sure if this is what we want it to do - # Dims of the response variable - response_dims = None + # Handle the case where the outcome is multivariate if isinstance(spec.family, (MultivariateFamily, Categorical)): + # Append the dims of the response variables to the coefficient and contribution dims + # In general: + # coeff_dims: ('weights_dim', ) -> ('weights_dim', f'{response}_dim') + # contribution_dims: ('__obs__', ) -> ('__obs__', f'{response}_dim') response_dims = tuple(spec.response_component.term.coords) coeff_dims = coeff_dims + response_dims contribution_dims = contribution_dims + response_dims + + # Append a dimension to sqrt_psd: ('weights_dim', ) -> ('weights_dim', 1) sqrt_psd = sqrt_psd[:, np.newaxis] if self.term.centered: @@ -516,9 +520,6 @@ def build(self, spec): coeffs_raw = pm.Normal(f"{label}_weights_raw", dims=coeff_dims) coeffs = pm.Deterministic(f"{label}_weights", coeffs_raw * sqrt_psd, dims=coeff_dims) - print("sqrt_psd", sqrt_psd.shape.eval()) - print("coeffs", coeffs.shape.eval()) - # Build deterministic for the HSGP contribution # If there are groups, we do as many dot products as groups if self.term.by_levels is not None: @@ -531,9 +532,6 @@ def build(self, spec): else: contribution = pt.dot(phi, coeffs) # "@" operator is not working as expected - print("coeffs", coeffs.shape.eval()) - print("phi", phi.shape) - print("contribution", contribution.shape.eval()) output = pm.Deterministic(label, contribution, dims=contribution_dims) return output