Skip to content

Commit

Permalink
Move logic to its own method
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Jan 26, 2023
1 parent d1cc7f7 commit a4e3646
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,21 @@ def build_response_distribution(self, kwargs, pymc_backend):
if hasattr(self.family, "transform_backend_kwargs"):
kwargs = self.family.transform_backend_kwargs(kwargs)

kwargs = self.robustify_dims(pymc_backend, kwargs)

return dist(self.name, **kwargs)

@property
def name(self):
if self.term.alias:
return self.term.alias
return self.term.name

def robustify_dims(self, pymc_backend, kwargs):
# It's possible the observed for the response is multidimensional, but there's a single
# linear predictor because the family is not multivariate.
# In this case, we add extra dimensions to avoid having shape mismatch
# In this case, we add extra dimensions to avoid having shape mismatch between the data
# and the shape implied by the `dims` we pass.
response_aliased_name = get_aliased_name(self.term)
dims, data = kwargs["dims"], kwargs["observed"]
dims_n = len(dims)
Expand All @@ -288,16 +300,8 @@ def build_response_distribution(self, kwargs, pymc_backend):
values = np.arange(np.size(data, axis=axis))
pymc_backend.model.add_coords({name: values})
dims.append(name)

kwargs["dims"] = dims

return dist(self.name, **kwargs)

@property
def name(self):
if self.term.alias:
return self.term.alias
return self.term.name
return kwargs


def get_linkinv(link, invlinks):
Expand Down

0 comments on commit a4e3646

Please sign in to comment.