Skip to content

Commit

Permalink
Simplify None bound replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 30, 2021
1 parent 74aeafc commit 91a90e7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 47 deletions.
73 changes: 29 additions & 44 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,17 @@ def transform_params(rv_var):
if cls.bound_args_indices[1] is not None:
upper = args[cls.bound_args_indices[1]]

lower = (
at.as_tensor_variable(lower)
if lower is not None
and (isinstance(lower, TensorConstant) and not np.all(lower.value == -np.inf))
else None
)
upper = (
at.as_tensor_variable(upper)
if upper is not None
and (isinstance(upper, TensorConstant) and not np.all(upper.value == np.inf))
else None
)
if lower is not None:
if isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf):
lower = None
else:
lower = at.as_tensor_variable(lower)

if upper is not None:
if isinstance(upper, TensorConstant) and np.all(upper.value == np.inf):
upper = None
else:
upper = at.as_tensor_variable(upper)

return lower, upper

Expand Down Expand Up @@ -691,37 +690,24 @@ def dist(
) -> RandomVariable:
sigma = sd if sd is not None else sigma
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
single_param_length = 0 if not sigma.shape else sigma.shape[0]
sigma = at.as_tensor_variable(sigma)
tau = at.as_tensor_variable(tau)
mu = at.as_tensor_variable(floatX(mu))
assert_negative_support(sigma, "sigma", "TruncatedNormal")
assert_negative_support(tau, "tau", "TruncatedNormal")

if lower is None and upper is None:
initval = mu
elif lower is None and upper is not None:
initval = upper - 1.0
elif lower is not None and upper is None:
initval = lower + 1.0
else:
initval = (lower + upper) / 2
# if lower is None and upper is None:
# initval = mu
# elif lower is None and upper is not None:
# initval = upper - 1.0
# elif lower is not None and upper is None:
# initval = lower + 1.0
# else:
# initval = (lower + upper) / 2

def _handle_missing_bound(bound_value, expected_input_length, replacement_value):
    """Coerce a (possibly absent) bound to a tensor.

    A bound given by the caller is converted directly; a missing bound is
    replaced by ``replacement_value`` (+/-inf), broadcast to the expected
    length when the parameterization is vector-valued.
    """
    # Caller supplied an explicit bound: just coerce it to a tensor.
    if bound_value is not None:
        return at.as_tensor_variable(floatX(bound_value))
    # Missing bound, scalar parameterization: a scalar +/-inf stand-in.
    if expected_input_length == 0:
        return at.as_tensor_variable(floatX(replacement_value))
    # Missing bound, vector parameterization: tile +/-inf to match length.
    filler = np.repeat(replacement_value, expected_input_length)
    return at.as_tensor_variable(floatX(filler))

lower = _handle_missing_bound(lower, single_param_length, -np.inf)
upper = _handle_missing_bound(upper, single_param_length, np.inf)

res = super().dist([mu, sigma, lower, upper], testval=initval, **kwargs)
return res
lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf)
return super().dist([mu, sigma, lower, upper], **kwargs)

def logp(
value,
Expand All @@ -743,8 +729,8 @@ def logp(
-------
TensorVariable
"""
unbounded_lower = isinstance(lower, TensorConstant) and lower.value == -np.inf
unbounded_upper = isinstance(upper, TensorConstant) and upper.value == np.inf
unbounded_lower = isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf)
unbounded_upper = isinstance(upper, TensorConstant) and np.all(upper.value == np.inf)

if not unbounded_lower and not unbounded_upper:
lcdf_a = normal_lcdf(mu, sigma, lower)
Expand All @@ -760,16 +746,15 @@ def logp(
norm = 0.0

logp = Normal.logp(value, mu=mu, sigma=sigma) - norm
bounds = [sigma > 0]
if lower is not None:
bounds = []
if not unbounded_lower:
bounds.append(value >= lower)
if upper is not None:
if not unbounded_upper:
bounds.append(value <= upper)
if not unbounded_lower and not unbounded_upper:
bounds.append(lower <= upper)
return bound(logp, *bounds)

def _distr_parameters_for_repr(self):
return ["mu", "sigma", "lower", "upper"]


class HalfNormal(PositiveContinuous):
r"""
Expand Down
23 changes: 20 additions & 3 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,9 +2769,26 @@ def test_missing_upper_bound_array(self):
)
dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)

assert all(a == b for a, b in zip(dist_params[2].value, np.array([-1, 0])))
assert all(a == b for a, b in zip(dist_params[3].value, np.array([np.inf, np.inf])))
assert all(a == b for a, b in zip(lower.value, np.array([-1, 0])))
assert np.array_equal(dist_params[2].value, [-1, 0])
assert dist_params[3].value == np.inf
assert np.array_equal(lower.value, [-1, 0])
assert upper is None

def test_missing_partial_upper_bound_array(self):
    """Vector ``lower`` containing -inf entries, with ``upper`` omitted.

    The partially-infinite lower bound must be stored verbatim, while the
    absent upper bound is replaced by a scalar +inf constant and produces
    no upper interval bound.
    """
    rv_name = "upper_bounded"
    expected_lower = np.array([-1.0, -np.inf])

    with Model() as model:
        TruncatedNormal(
            rv_name,
            mu=np.array([1, 1]),
            sigma=np.array([2, 3]),
            lower=expected_lower,
            upper=None,
        )

    dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, rv_name)

    # Distribution parameters: lower kept as given, upper filled with +inf.
    assert np.array_equal(dist_params[2].value, expected_lower)
    assert dist_params[3].value == np.inf
    # Interval transform: same lower bound, no upper bound at all.
    assert np.array_equal(lower.value, expected_lower)
    assert upper is None

def test_missing_upper_bound_with_richer_context(self):
Expand Down

0 comments on commit 91a90e7

Please sign in to comment.