diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 174c913acf7..06062849ba1 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -671,25 +671,39 @@ def dist( *args, **kwargs, ): - mu, sigma = _truncated_normal_prepare_mu_sigma_params(mu, sigma, tau, sd) - lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper( - lower, upper - ) - - if lower_check is None and upper_check is None: - _defaultval = mu - elif lower_check is None and upper_check is not None: - _defaultval = upper - 1.0 - elif lower_check is not None and upper_check is None: - _defaultval = lower + 1.0 + sigma = sd if sd is not None else sigma + tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) + sigma = at.as_tensor_variable(sigma) + tau = at.as_tensor_variable(tau) + mu = at.as_tensor_variable(floatX(mu)) + assert_negative_support(sigma, "sigma", "TruncatedNormal") + assert_negative_support(tau, "tau", "TruncatedNormal") + + if lower is None and upper is None: + initval = mu + elif lower is None and upper is not None: + initval = upper - 1.0 + elif lower is not None and upper is None: + initval = lower + 1.0 else: - _defaultval = (lower + upper) / 2 + initval = (lower + upper) / 2 - return super().dist([mu, sigma, lower, upper], **kwargs) + lower = ( + at.as_tensor_variable(floatX(lower)) + if lower is not None + else at.as_tensor_variable(-np.inf) + ) + upper = ( + at.as_tensor_variable(floatX(upper)) + if upper is not None + else at.as_tensor_variable(np.inf) + ) + res = super().dist([mu, sigma, lower, upper], testval=initval, **kwargs) + return res def logp(value, mu, sigma, lower, upper): """ - Calculate log-probability of TruncatedNormal distribution at specified value. + Calculate log-probaDirichletMultinomialbility of TruncatedNormal distribution at specified value. Parameters ---------- @@ -701,54 +715,32 @@ def logp(value, mu, sigma, lower, upper): ------- TensorVariable """ - mu, sigma = _truncated_normal_prepare_mu_sigma_params(mu, sigma, tau=None, sd=None) - print(lower.eval()) - lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper( - lower, upper - ) - print(lower.eval()) - norm = _truncated_normal_normalization(mu, sigma, lower, upper) + + if lower is not None and upper is not None: + lcdf_a = normal_lcdf(mu, sigma, lower) + lcdf_b = normal_lcdf(mu, sigma, upper) + lsf_a = normal_lccdf(mu, sigma, lower) + lsf_b = normal_lccdf(mu, sigma, upper) + norm = at.switch(lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a)) + elif lower is not None: + norm = normal_lccdf(mu, sigma, lower) + elif upper is not None: + norm = normal_lcdf(mu, sigma, upper) + else: + norm = 0.0 + logp = Normal.logp(value, mu=mu, sigma=sigma) - norm bounds = [sigma > 0] - if lower_check is not None: + if lower is not None: bounds.append(value >= lower) - if upper_check is not None: + if upper is not None: bounds.append(value <= upper) - print([e.eval() for e in bounds]) return bound(logp, *bounds) def _distr_parameters_for_repr(self): return ["mu", "sigma", "lower", "upper"] -def _truncated_normal_prepare_mu_sigma_params(mu, sigma, tau, sd): - sigma = sd if sd is not None else sigma - tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) - sigma = at.as_tensor_variable(sigma) - tau = at.as_tensor_variable(tau) - mu = at.as_tensor_variable(floatX(mu)) - assert_negative_support(sigma, "sigma", "TruncatedNormal") - assert_negative_support(tau, "tau", "TruncatedNormal") - return mu, sigma - - -def _truncated_normal_normalization(mu, sigma, lower, upper): - lower, lower_check, upper, upper_check = _truncated_normal_prepare_lower_and_upper(lower, upper) - if lower_check is None and upper_check is None: - return 0.0 - - if lower_check is not None and upper_check is not None: - lcdf_a = normal_lcdf(mu, sigma, lower) - lcdf_b = normal_lcdf(mu, sigma, upper) - lsf_a = normal_lccdf(mu, sigma, lower) - lsf_b = normal_lccdf(mu, sigma, upper) - return at.switch(lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a)) - if lower_check is not None: - return normal_lccdf(mu, sigma, lower) - else: - return normal_lcdf(mu, sigma, upper) - - def _truncated_normal_prepare_lower_and_upper(lower, upper): lower_check = at.as_tensor_variable(floatX(lower)) if lower is not None else lower upper_check = at.as_tensor_variable(floatX(upper)) if upper is not None else upper