Skip to content

Commit

Permalink
Address review feedback
Browse files Browse the repository at this point in the history
- remove upper and lower checks
- pass default value for dist as `testval` to Distribution dist method
  • Loading branch information
matteo-pallini committed Jun 13, 2021
1 parent de44767 commit e5cc1f4
Showing 1 changed file with 44 additions and 52 deletions.
96 changes: 44 additions & 52 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,25 +671,39 @@ def dist(
*args,
**kwargs,
):
mu, sigma = _truncated_normal_prepare_mu_sigma_params(mu, sigma, tau, sd)
lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper(
lower, upper
)

if lower_check is None and upper_check is None:
_defaultval = mu
elif lower_check is None and upper_check is not None:
_defaultval = upper - 1.0
elif lower_check is not None and upper_check is None:
_defaultval = lower + 1.0
sigma = sd if sd is not None else sigma
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
sigma = at.as_tensor_variable(sigma)
tau = at.as_tensor_variable(tau)
mu = at.as_tensor_variable(floatX(mu))
assert_negative_support(sigma, "sigma", "TruncatedNormal")
assert_negative_support(tau, "tau", "TruncatedNormal")

if lower is None and upper is None:
initval = mu
elif lower is None and upper is not None:
initval = upper - 1.0
elif lower is not None and upper is None:
initval = lower + 1.0
else:
_defaultval = (lower + upper) / 2
initval = (lower + upper) / 2

return super().dist([mu, sigma, lower, upper], **kwargs)
lower = (
at.as_tensor_variable(floatX(lower))
if lower is not None
else at.as_tensor_variable(-np.inf)
)
upper = (
at.as_tensor_variable(floatX(upper))
if upper is not None
else at.as_tensor_variable(np.inf)
)
res = super().dist([mu, sigma, lower, upper], testval=initval, **kwargs)
return res

def logp(value, mu, sigma, lower, upper):
"""
Calculate log-probability of TruncatedNormal distribution at specified value.
Calculate log-probaDirichletMultinomialbility of TruncatedNormal distribution at specified value.
Parameters
----------
Expand All @@ -701,54 +715,32 @@ def logp(value, mu, sigma, lower, upper):
-------
TensorVariable
"""
mu, sigma = _truncated_normal_prepare_mu_sigma_params(mu, sigma, tau=None, sd=None)
print(lower.eval())
lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper(
lower, upper
)
print(lower.eval())
norm = _truncated_normal_normalization(mu, sigma, lower, upper)

if lower is not None and upper is not None:
lcdf_a = normal_lcdf(mu, sigma, lower)
lcdf_b = normal_lcdf(mu, sigma, upper)
lsf_a = normal_lccdf(mu, sigma, lower)
lsf_b = normal_lccdf(mu, sigma, upper)
norm = at.switch(lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a))
elif lower is not None:
norm = normal_lccdf(mu, sigma, lower)
elif upper is not None:
norm = normal_lcdf(mu, sigma, upper)
else:
norm = 0.0

logp = Normal.logp(value, mu=mu, sigma=sigma) - norm
bounds = [sigma > 0]
if lower_check is not None:
if lower is not None:
bounds.append(value >= lower)
if upper_check is not None:
if upper is not None:
bounds.append(value <= upper)
print([e.eval() for e in bounds])
return bound(logp, *bounds)

def _distr_parameters_for_repr(self):
return ["mu", "sigma", "lower", "upper"]


def _truncated_normal_prepare_mu_sigma_params(mu, sigma, tau, sd):
sigma = sd if sd is not None else sigma
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
sigma = at.as_tensor_variable(sigma)
tau = at.as_tensor_variable(tau)
mu = at.as_tensor_variable(floatX(mu))
assert_negative_support(sigma, "sigma", "TruncatedNormal")
assert_negative_support(tau, "tau", "TruncatedNormal")
return mu, sigma


def _truncated_normal_normalization(mu, sigma, lower, upper):
lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper(lower, upper)
if lower_check is None and upper_check is None:
return 0.0

if lower_check is not None and upper_check is not None:
lcdf_a = normal_lcdf(mu, sigma, lower)
lcdf_b = normal_lcdf(mu, sigma, upper)
lsf_a = normal_lccdf(mu, sigma, lower)
lsf_b = normal_lccdf(mu, sigma, upper)
return at.switch(lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a))
if lower_check is not None:
return normal_lccdf(mu, sigma, lower)
else:
return normal_lcdf(mu, sigma, upper)


def _truncated_normal_prepare_lower_and_upper(lower, upper):
lower_check = at.as_tensor_variable(floatX(lower)) if lower is not None else lower
upper_check = at.as_tensor_variable(floatX(upper)) if upper is not None else upper
Expand Down

0 comments on commit e5cc1f4

Please sign in to comment.