From 91a90e7d60e4b08644b791b693b68e8ee66cec07 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 30 Jun 2021 09:37:59 +0200 Subject: [PATCH] Simplify None bound replacement --- pymc3/distributions/continuous.py | 73 ++++++++++++------------------- pymc3/tests/test_distributions.py | 23 ++++++++-- 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 014e7ebade2..39f30b84119 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -162,18 +162,17 @@ def transform_params(rv_var): if cls.bound_args_indices[1] is not None: upper = args[cls.bound_args_indices[1]] - lower = ( - at.as_tensor_variable(lower) - if lower is not None - and (isinstance(lower, TensorConstant) and not np.all(lower.value == -np.inf)) - else None - ) - upper = ( - at.as_tensor_variable(upper) - if upper is not None - and (isinstance(upper, TensorConstant) and not np.all(upper.value == np.inf)) - else None - ) + if lower is not None: + if isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf): + lower = None + else: + lower = at.as_tensor_variable(lower) + + if upper is not None: + if isinstance(upper, TensorConstant) and np.all(upper.value == np.inf): + upper = None + else: + upper = at.as_tensor_variable(upper) return lower, upper @@ -691,37 +690,24 @@ def dist( ) -> RandomVariable: sigma = sd if sd is not None else sigma tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) - single_param_length = 0 if not sigma.shape else sigma.shape[0] sigma = at.as_tensor_variable(sigma) tau = at.as_tensor_variable(tau) mu = at.as_tensor_variable(floatX(mu)) assert_negative_support(sigma, "sigma", "TruncatedNormal") assert_negative_support(tau, "tau", "TruncatedNormal") - if lower is None and upper is None: - initval = mu - elif lower is None and upper is not None: - initval = upper - 1.0 - elif lower is not None and upper is None: - initval = lower + 1.0 - else: - initval = (lower + upper) / 2 + # if lower is None and upper is None: + # initval = mu + # elif lower is None and upper is not None: + # initval = upper - 1.0 + # elif lower is not None and upper is None: + # initval = lower + 1.0 + # else: + # initval = (lower + upper) / 2 - def _handle_missing_bound(bound_value, expected_input_length, replacement_value): - if bound_value is not None: - return at.as_tensor_variable(floatX(bound_value)) - elif expected_input_length == 0: - return at.as_tensor_variable(floatX(replacement_value)) - else: - return at.as_tensor_variable( - floatX(np.repeat(replacement_value, expected_input_length)) - ) - - lower = _handle_missing_bound(lower, single_param_length, -np.inf) - upper = _handle_missing_bound(upper, single_param_length, np.inf) - - res = super().dist([mu, sigma, lower, upper], testval=initval, **kwargs) - return res + lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf) + upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf) + return super().dist([mu, sigma, lower, upper], **kwargs) def logp( value, @@ -743,8 +729,8 @@ def logp( ------- TensorVariable """ - unbounded_lower = isinstance(lower, TensorConstant) and lower.value == -np.inf - unbounded_upper = isinstance(upper, TensorConstant) and upper.value == np.inf + unbounded_lower = isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf) + unbounded_upper = isinstance(upper, TensorConstant) and np.all(upper.value == np.inf) if not unbounded_lower and not unbounded_upper: lcdf_a = normal_lcdf(mu, sigma, lower) @@ -760,16 +746,15 @@ def logp( norm = 0.0 logp = Normal.logp(value, mu=mu, sigma=sigma) - norm - bounds = [sigma > 0] - if lower is not None: + bounds = [] + if not unbounded_lower: bounds.append(value >= lower) - if upper is not None: + if not unbounded_upper: bounds.append(value <= upper) + if not unbounded_lower and not unbounded_upper: + bounds.append(lower <= upper) return bound(logp, *bounds) - def _distr_parameters_for_repr(self): - return ["mu", "sigma", "lower", "upper"] - class HalfNormal(PositiveContinuous): r""" diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index e75dfe9f700..de0fb50323d 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -2769,9 +2769,26 @@ def test_missing_upper_bound_array(self): ) dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) - assert all(a == b for a, b in zip(dist_params[2].value, np.array([-1, 0]))) - assert all(a == b for a, b in zip(dist_params[3].value, np.array([np.inf, np.inf]))) - assert all(a == b for a, b in zip(lower.value, np.array([-1, 0]))) + assert np.array_equal(dist_params[2].value, [-1, 0]) + assert dist_params[3].value == np.inf + assert np.array_equal(lower.value, [-1, 0]) + assert upper is None + + def test_missing_partial_upper_bound_array(self): + bounded_rv_name = "upper_bounded" + with Model() as model: + TruncatedNormal( + bounded_rv_name, + mu=np.array([1, 1]), + sigma=np.array([2, 3]), + lower=np.array([-1.0, -np.inf]), + upper=None, + ) + dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + + assert np.array_equal(dist_params[2].value, [-1, -np.inf]) + assert dist_params[3].value == np.inf + assert np.array_equal(lower.value, [-1, -np.inf]) assert upper is None def test_missing_upper_bound_with_richer_context(self):