From 7a18fb9afc5b485dcd95f1a421bbd77586106a2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Sat, 9 Nov 2024 13:34:31 -0300 Subject: [PATCH] Check if there is an attribute before trying to access it (#851) * Check if there is an attribute before trying to access it * Code format * Pin versions of jax dependencies * Make ordinal families work with intercept only models --- bambi/families/univariate.py | 22 ++++++++++++++++++++-- bambi/priors/scaler.py | 12 ++++++++++-- pyproject.toml | 22 +++++++++++++--------- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/bambi/families/univariate.py b/bambi/families/univariate.py index ce0f9ee9..0eca7d0a 100644 --- a/bambi/families/univariate.py +++ b/bambi/families/univariate.py @@ -220,8 +220,17 @@ def transform_backend_eta(eta, kwargs): # shape(threshold) = (K, ) # shape(eta) = (n, ) # shape(threshold - shape_padright(eta)) = (n, K) + threshold = kwargs["threshold"] - eta_shifted = threshold - pt.shape_padright(eta) + + # When the model does not have any predictors. + # Inference can be slower, as this can potentially build a larger object. + # However, this is needed for consistency with other parts of the codebase + if eta == 0: + eta_shifted = threshold - pt.shape_padright(pt.zeros(len(kwargs["observed"]))) + else: + eta_shifted = threshold - pt.shape_padright(eta) + return eta_shifted @staticmethod @@ -393,8 +402,17 @@ def transform_backend_eta(eta, kwargs): # shape(threshold) = (K, ) # shape(eta) = (n, ) # shape(threshold - shape_padright(eta)) = (n, K) + threshold = kwargs["threshold"] - eta_shifted = threshold - pt.shape_padright(eta) + + # When the model does not have any predictors. + # Inference can be slower, as this can potentially build a larger object. + # However, this is needed for consistency with other parts of the codebase + if eta == 0: + eta_shifted = threshold - pt.shape_padright(pt.zeros(len(kwargs["observed"]))) + else: + eta_shifted = threshold - pt.shape_padright(eta) + return eta_shifted @staticmethod diff --git a/bambi/priors/scaler.py b/bambi/priors/scaler.py index fc03c6b8..149d9e67 100644 --- a/bambi/priors/scaler.py +++ b/bambi/priors/scaler.py @@ -55,11 +55,19 @@ def scale_response(self): # Here we would add cases for other families if we wanted if isinstance(self.model.family, (Gaussian, StudentT)): sigma = self.model.components["sigma"] - if isinstance(sigma, ConstantComponent) and sigma.prior.auto_scale: + if ( + isinstance(sigma, ConstantComponent) + and hasattr(sigma.prior, "auto_scale") # not available when `.prior` is a scalar + and sigma.prior.auto_scale + ): sigma.prior = Prior("HalfStudentT", nu=4, sigma=self.response_std) elif isinstance(self.model.family, VonMises): kappa = self.model.components["kappa"] - if isinstance(kappa, ConstantComponent) and kappa.prior.auto_scale: + if ( + isinstance(kappa, ConstantComponent) + and hasattr(kappa.prior, "auto_scale") # not available when `.prior` is a scalar + and kappa.prior.auto_scale + ): kappa.prior = Prior("HalfStudentT", nu=4, sigma=self.response_std) def scale_intercept(self, term): diff --git a/pyproject.toml b/pyproject.toml index 596f9801..7d03b483 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "formulae>=0.5.3", "graphviz", "pandas>=1.0.0", - "pymc>=5.16.1", + "pymc>=5.18.0", ] [project.optional-dependencies] @@ -38,8 +38,12 @@ dev = [ "seaborn>=0.9.0", ] +# TODO: Unpin this before making a release jax = [ - "bayeux-ml>=0.1.13", + "bayeux-ml==0.1.14", + "blackjax==1.2.3", + "jax<=0.4.33", + "jaxlib<=0.4.33", ] [project.urls] @@ -50,14 +54,14 @@ changelog = "https://github.com/bambinos/bambi/blob/main/docs/CHANGELOG.md" [tool.setuptools] packages = [ - "bambi", - "bambi.backend", - "bambi.data", - "bambi.defaults", + "bambi", + "bambi.backend", + "bambi.data", + "bambi.defaults", "bambi.families", - "bambi.interpret", - "bambi.priors", - "bambi.terms", + "bambi.interpret", + "bambi.priors", + "bambi.terms", ] [tool.black]