Skip to content

Commit

Permalink
Refactor get_tau_sigma and support lists of variables (#7185)
Browse files Browse the repository at this point in the history
  • Loading branch information
tvwenger authored Mar 4, 2024
1 parent 47f6d9e commit 244fb97
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 42 deletions.
39 changes: 14 additions & 25 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,32 +229,21 @@ def get_tau_sigma(tau=None, sigma=None):
-----
If neither tau nor sigma is provided, returns (1., 1.)
"""
if tau is None:
if sigma is None:
sigma = 1.0
tau = 1.0
else:
if isinstance(sigma, Variable):
# Keep tau negative, if sigma was negative, so that it will fail when used
tau = (sigma**-2.0) * pt.sign(sigma)
else:
sigma_ = np.asarray(sigma)
if np.any(sigma_ <= 0):
raise ValueError("sigma must be positive")
tau = sigma_**-2.0

if tau is not None and sigma is not None:
raise ValueError("Can't pass both tau and sigma")
if tau is None and sigma is None:
sigma = pt.as_tensor_variable(1.0)
tau = pt.as_tensor_variable(1.0)
elif tau is None:
sigma = pt.as_tensor_variable(sigma)
# Keep tau negative, if sigma was negative, so that it will
# fail when used
tau = (sigma**-2.0) * pt.sign(sigma)
else:
if sigma is not None:
raise ValueError("Can't pass both tau and sigma")
else:
if isinstance(tau, Variable):
# Keep sigma negative, if tau was negative, so that it will fail when used
sigma = pt.abs(tau) ** (-0.5) * pt.sign(tau)
else:
tau_ = np.asarray(tau)
if np.any(tau_ <= 0):
raise ValueError("tau must be positive")
sigma = tau_**-0.5
tau = pt.as_tensor_variable(tau)
# Keep tau negative, if sigma was negative, so that it will
# fail when used
sigma = pt.abs(tau) ** -0.5 * pt.sign(tau)

return tau, sigma

Expand Down
52 changes: 35 additions & 17 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import functools as ft
import warnings

import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -998,26 +997,45 @@ def scipy_logcdf(value, mu, sigma, lower, upper):
assert np.isinf(logp[2])

def test_get_tau_sigma(self):
# Fail on warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
sigma = np.array(2)
tau, _ = get_tau_sigma(sigma=sigma)
npt.assert_almost_equal(tau.eval(), 1.0 / sigma**2)

sigma = np.array(2)
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])
tau = np.array(2)
_, sigma = get_tau_sigma(tau=tau)
npt.assert_almost_equal(sigma.eval(), tau**-0.5)

tau = np.array(2)
npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])
tau, _ = get_tau_sigma(sigma=pt.constant(-2))
npt.assert_almost_equal(tau.eval(), -0.25)

tau, _ = get_tau_sigma(sigma=pt.constant(-2))
npt.assert_almost_equal(tau.eval(), -0.25)
_, sigma = get_tau_sigma(tau=pt.constant(-2))
npt.assert_almost_equal(sigma.eval(), -1.0 / np.sqrt(2.0))

_, sigma = get_tau_sigma(tau=pt.constant(-2))
npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2))
sigma = [1, 2]
tau, _ = get_tau_sigma(sigma=sigma)
npt.assert_almost_equal(tau.eval(), 1.0 / np.array(sigma) ** 2)

sigma = [1, 2]
npt.assert_almost_equal(
get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
)
# Test null arguments
tau, sigma = get_tau_sigma()
npt.assert_almost_equal(tau.eval(), 1.0)
npt.assert_almost_equal(sigma.eval(), 1.0)

# Test exception upon passing both sigma and tau
msg = "Can't pass both tau and sigma"
with pytest.raises(ValueError, match=msg):
_, _ = get_tau_sigma(sigma=1.0, tau=1.0)

# These are regression test for #6988: Check that get_tau_sigma works
# for lists of tensors
sigma = [pt.constant(2), pt.constant(2)]
expect_tau = np.array([0.25, 0.25])
tau, _ = get_tau_sigma(sigma=sigma)
npt.assert_almost_equal(tau.eval(), expect_tau)

tau = [pt.constant(2), pt.constant(2)]
expect_sigma = np.array([2.0, 2.0]) ** -0.5
_, sigma = get_tau_sigma(tau=tau)
npt.assert_almost_equal(sigma.eval(), expect_sigma)

@pytest.mark.parametrize(
"value,mu,sigma,nu,logp",
Expand Down Expand Up @@ -2042,7 +2060,7 @@ class TestStudentTLam(BaseTestDistributionRandom):
lam, sigma = get_tau_sigma(tau=2.0)
pymc_dist_params = {"nu": 5.0, "mu": -1.0, "lam": lam}
expected_rv_op_params = {"nu": 5.0, "mu": -1.0, "lam": sigma}
reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma}
reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma.eval()}
reference_dist = seeded_scipy_distribution_builder("t")
checks_to_run = ["check_pymc_params_match_rv_op"]

Expand Down

0 comments on commit 244fb97

Please sign in to comment.