Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor missing discrete dists #4684

Merged
merged 12 commits into from
May 13, 2021
2 changes: 2 additions & 0 deletions docs/source/api/distributions/discrete.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ Discrete
ZeroInflatedNegativeBinomial
DiscreteUniform
Geometric
HyperGeometric
Categorical
DiscreteWeibull
Constant
OrderedLogistic
OrderedProbit

.. automodule:: pymc3.distributions.discrete
:members:
2 changes: 0 additions & 2 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
Binomial,
Categorical,
Constant,
ConstantDist,
DiscreteUniform,
DiscreteWeibull,
Geometric,
Expand Down Expand Up @@ -138,7 +137,6 @@
"Bernoulli",
"Poisson",
"NegativeBinomial",
"ConstantDist",
"Constant",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial",
Expand Down
8 changes: 5 additions & 3 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,6 @@ class Uniform(BoundedContinuous):
def dist(cls, lower=0, upper=1, **kwargs):
lower = at.as_tensor_variable(floatX(lower))
upper = at.as_tensor_variable(floatX(upper))
# mean = (upper + lower) / 2.0
# median = self.mean
return super().dist([lower, upper], **kwargs)

def logp(value, lower, upper):
Expand All @@ -270,7 +268,11 @@ def logp(value, lower, upper):
-------
TensorVariable
"""
return bound(-at.log(upper - lower), value >= lower, value <= upper)
return bound(
at.fill(value, -at.log(upper - lower)),
value >= lower,
value <= upper,
)

def logcdf(value, lower, upper):
"""
Expand Down
Loading