diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index a3ec798e1..8379ac2fb 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -75,7 +75,7 @@ def inv(self): return inv def __call__(self, x): - return NotImplementedError + raise NotImplementedError def _inverse(self, y): raise NotImplementedError diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index a43b6268b..53edf234a 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -88,7 +88,7 @@ def mean(self): elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: - return NotImplementedError("mean only available for Normal and Cauchy") + raise NotImplementedError("mean only available for Normal and Cauchy") @property def var(self): @@ -102,7 +102,7 @@ def var(self): elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: - return NotImplementedError("var only available for Normal and Cauchy") + raise NotImplementedError("var only available for Normal and Cauchy") class RightTruncatedDistribution(Distribution): @@ -152,7 +152,7 @@ def mean(self): elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: - return NotImplementedError("mean only available for Normal and Cauchy") + raise NotImplementedError("mean only available for Normal and Cauchy") @property def var(self): @@ -166,7 +166,7 @@ def var(self): elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: - return NotImplementedError("var only available for Normal and Cauchy") + raise NotImplementedError("var only available for Normal and Cauchy") class TwoSidedTruncatedDistribution(Distribution): @@ -269,7 +269,7 @@ def mean(self): elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: - return NotImplementedError("mean only available for Normal and Cauchy") + raise NotImplementedError("mean only available for Normal and Cauchy") @property def var(self): @@ -285,7 +285,7 @@ def var(self): elif isinstance(self.base_dist, Cauchy): return jnp.full(self.batch_shape, jnp.nan) else: - return NotImplementedError("var only available for Normal and Cauchy") + raise NotImplementedError("var only available for Normal and Cauchy") def TruncatedDistribution(base_dist, low=None, high=None, *, validate_args=None):