Skip to content

Commit

Permalink
Transform dims after the 'transform_backend_kwargs' is called
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Jan 26, 2023
1 parent 2322c86 commit d1cc7f7
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,7 @@ def build(self, pymc_backend, bmb_model):

# Distributional parameters. A link funciton is used.
response_aliased_name = get_aliased_name(self.term)
dims = [
response_aliased_name + "_obs",
]
dims = [response_aliased_name + "_obs"]
for name, component in pymc_backend.distributional_components.items():
bmb_component = bmb_model.components[name]
if bmb_component.response_term: # The response is added later
Expand All @@ -256,27 +254,14 @@ def build(self, pymc_backend, bmb_model):
# Add parent parameter and observed data
kwargs[parent] = linkinv(nu)
kwargs["observed"] = data

dims_n = len(dims)
ndim_diff = data.ndim - dims_n

# The response has multiple variables, but a single linear predictor
if ndim_diff > 0:
for i in range(ndim_diff):
axis = dims_n + i
name = f"{response_aliased_name}_extra_dim_{i}"
values = np.arange(np.size(data, axis=axis))
pymc_backend.model.add_coords({name: values})
dims.append(name)

kwargs["dims"] = dims

# Build the response distribution
dist = self.build_response_distribution(kwargs)
dist = self.build_response_distribution(kwargs, pymc_backend)

return dist

def build_response_distribution(self, kwargs):
def build_response_distribution(self, kwargs, pymc_backend):
# Get likelihood distribution
if self.family.likelihood.dist:
dist = self.family.likelihood.dist
Expand All @@ -288,6 +273,24 @@ def build_response_distribution(self, kwargs):
if hasattr(self.family, "transform_backend_kwargs"):
kwargs = self.family.transform_backend_kwargs(kwargs)

# It's possible the observed for the response is multidimensional, but there's a single
# linear predictor because the family is not multivariate.
# In this case, we add extra dimensions to avoid having shape mismatch
response_aliased_name = get_aliased_name(self.term)
dims, data = kwargs["dims"], kwargs["observed"]
dims_n = len(dims)
ndim_diff = data.ndim - dims_n

if ndim_diff > 0:
for i in range(ndim_diff):
axis = dims_n + i
name = f"{response_aliased_name}_extra_dim_{i}"
values = np.arange(np.size(data, axis=axis))
pymc_backend.model.add_coords({name: values})
dims.append(name)

kwargs["dims"] = dims

return dist(self.name, **kwargs)

@property
Expand Down

0 comments on commit d1cc7f7

Please sign in to comment.