Skip to content

Commit

Permalink
Check if there is an attribute before trying to access it (#851)
Browse files Browse the repository at this point in the history
* Check if there is an attribute before trying to access it

* Code format

* Pin versions of jax dependencies

* Make ordinal families work with intercept only models
  • Loading branch information
tomicapretto authored Nov 9, 2024
1 parent 46d5572 commit 7a18fb9
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
22 changes: 20 additions & 2 deletions bambi/families/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,17 @@ def transform_backend_eta(eta, kwargs):
# shape(threshold) = (K, )
# shape(eta) = (n, )
# shape(threshold - shape_padright(eta)) = (n, K)

threshold = kwargs["threshold"]
eta_shifted = threshold - pt.shape_padright(eta)

# When the model does not have any predictors.
# Inference can be slower, as this can potentially build a larger object.
# However, this is needed for consistency with other parts of the codebase
if eta == 0:
eta_shifted = threshold - pt.shape_padright(pt.zeros(len(kwargs["observed"])))
else:
eta_shifted = threshold - pt.shape_padright(eta)

return eta_shifted

@staticmethod
Expand Down Expand Up @@ -393,8 +402,17 @@ def transform_backend_eta(eta, kwargs):
# shape(threshold) = (K, )
# shape(eta) = (n, )
# shape(threshold - shape_padright(eta)) = (n, K)

threshold = kwargs["threshold"]
eta_shifted = threshold - pt.shape_padright(eta)

# When the model does not have any predictors.
# Inference can be slower, as this can potentially build a larger object.
# However, this is needed for consistency with other parts of the codebase
if eta == 0:
eta_shifted = threshold - pt.shape_padright(pt.zeros(len(kwargs["observed"])))
else:
eta_shifted = threshold - pt.shape_padright(eta)

return eta_shifted

@staticmethod
Expand Down
12 changes: 10 additions & 2 deletions bambi/priors/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ def scale_response(self):
# Here we would add cases for other families if we wanted
if isinstance(self.model.family, (Gaussian, StudentT)):
sigma = self.model.components["sigma"]
if isinstance(sigma, ConstantComponent) and sigma.prior.auto_scale:
if (
isinstance(sigma, ConstantComponent)
and hasattr(sigma.prior, "auto_scale") # not available when `.prior` is a scalar
and sigma.prior.auto_scale
):
sigma.prior = Prior("HalfStudentT", nu=4, sigma=self.response_std)
elif isinstance(self.model.family, VonMises):
kappa = self.model.components["kappa"]
if isinstance(kappa, ConstantComponent) and kappa.prior.auto_scale:
if (
isinstance(kappa, ConstantComponent)
and hasattr(kappa.prior, "auto_scale") # not available when `.prior` is a scalar
and kappa.prior.auto_scale
):
kappa.prior = Prior("HalfStudentT", nu=4, sigma=self.response_std)

def scale_intercept(self, term):
Expand Down
22 changes: 13 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"formulae>=0.5.3",
"graphviz",
"pandas>=1.0.0",
"pymc>=5.16.1",
"pymc>=5.18.0",
]

[project.optional-dependencies]
Expand All @@ -38,8 +38,12 @@ dev = [
"seaborn>=0.9.0",
]

# TODO: Unpin this before making a release
jax = [
"bayeux-ml>=0.1.13",
"bayeux-ml==0.1.14",
"blackjax==1.2.3",
"jax<=0.4.33",
"jaxlib<=0.4.33",
]

[project.urls]
Expand All @@ -50,14 +54,14 @@ changelog = "https://github.com/bambinos/bambi/blob/main/docs/CHANGELOG.md"

[tool.setuptools]
packages = [
"bambi",
"bambi.backend",
"bambi.data",
"bambi.defaults",
"bambi",
"bambi.backend",
"bambi.data",
"bambi.defaults",
"bambi.families",
"bambi.interpret",
"bambi.priors",
"bambi.terms",
"bambi.interpret",
"bambi.priors",
"bambi.terms",
]

[tool.black]
Expand Down

0 comments on commit 7a18fb9

Please sign in to comment.