Skip to content

Commit

Permalink
ensure a is positive
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Nov 10, 2023
1 parent b1ad63a commit 052ff37
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 3 additions & 0 deletions pymc_experimental/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.logprob.utils import CheckParameterValue
from pymc.pytensorf import floatX
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -329,6 +330,8 @@ def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable:
if rv_size_is_none(size):
size = a.shape

a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0)))

return Chi.dist(nu=3, size=size) * a

def __new__(cls, name, a, **kwargs):
Expand Down
2 changes: 0 additions & 2 deletions pymc_experimental/tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def test_logp(self):
Rplus,
{"a": Rplus},
lambda value, a: sp.maxwell.logpdf(value, scale=a),
skip_paramdomain_outside_edge_test=True,
)

def test_logcdf(self):
Expand All @@ -182,5 +181,4 @@ def test_logcdf(self):
Rplus,
{"a": Rplus},
lambda value, a: sp.maxwell.logcdf(value, scale=a),
skip_paramdomain_outside_edge_test=True,
)

0 comments on commit 052ff37

Please sign in to comment.