Skip to content

Commit

Permalink
Exponential scale default to 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 4, 2024
1 parent 3f3aeb9 commit 394ebe4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
11 changes: 5 additions & 6 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,13 +1373,12 @@ class Exponential(PositiveContinuous):
rv_op = exponential

@classmethod
def dist(cls, lam=None, scale=None, *args, **kwargs):
if lam is not None and scale is not None:
def dist(cls, lam=None, *, scale=None, **kwargs):
if lam is None and scale is None:
scale = 1.0
elif lam is not None and scale is not None:
raise ValueError("Incompatible parametrization. Can't specify both lam and scale.")
elif lam is None and scale is None:
raise ValueError("Incompatible parametrization. Must specify either lam or scale.")

if scale is None:
elif lam is not None:
scale = pt.reciprocal(lam)

scale = pt.as_tensor_variable(scale)
Expand Down
21 changes: 12 additions & 9 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,6 @@ def test_exponential(self):
lambda q, lam: st.expon.ppf(q, loc=0, scale=1 / lam),
)

def test_exponential_wrong_arguments(self):
msg = "Incompatible parametrization. Can't specify both lam and scale"
with pytest.raises(ValueError, match=msg):
pm.Exponential.dist(lam=0.5, scale=5)

msg = "Incompatible parametrization. Must specify either lam or scale"
with pytest.raises(ValueError, match=msg):
pm.Exponential.dist()

def test_laplace(self):
check_logp(
pm.Laplace,
Expand Down Expand Up @@ -2274,8 +2265,20 @@ class TestExponential(BaseTestDistributionRandom):
checks_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_both_lam_scale_raises",
"check_default_scale",
]

def check_both_lam_scale_raises(self):
msg = "Incompatible parametrization. Can't specify both lam and scale"
with pytest.raises(ValueError, match=msg):
pm.Exponential.dist(lam=0.5, scale=5)

def check_default_scale(self):
rv = self.pymc_dist.dist()
[scale] = rv.owner.op.dist_params(rv.owner)
assert scale.data == 1.0


class TestExponentialScale(BaseTestDistributionRandom):
pymc_dist = pm.Exponential
Expand Down

0 comments on commit 394ebe4

Please sign in to comment.