From 244fb97b01ad0f3dadf5c3837b65839e2a59a0e8 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Mon, 4 Mar 2024 18:16:32 -0500 Subject: [PATCH] Refactor `get_tau_sigma` and support lists of variables (#7185) --- pymc/distributions/continuous.py | 39 +++++++------------ tests/distributions/test_continuous.py | 52 +++++++++++++++++--------- 2 files changed, 49 insertions(+), 42 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 1ad937b37f1..aa722787da5 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -229,32 +229,21 @@ def get_tau_sigma(tau=None, sigma=None): ----- If neither tau nor sigma is provided, returns (1., 1.) """ - if tau is None: - if sigma is None: - sigma = 1.0 - tau = 1.0 - else: - if isinstance(sigma, Variable): - # Keep tau negative, if sigma was negative, so that it will fail when used - tau = (sigma**-2.0) * pt.sign(sigma) - else: - sigma_ = np.asarray(sigma) - if np.any(sigma_ <= 0): - raise ValueError("sigma must be positive") - tau = sigma_**-2.0 - + if tau is not None and sigma is not None: + raise ValueError("Can't pass both tau and sigma") + if tau is None and sigma is None: + sigma = pt.as_tensor_variable(1.0) + tau = pt.as_tensor_variable(1.0) + elif tau is None: + sigma = pt.as_tensor_variable(sigma) + # Keep tau negative, if sigma was negative, so that it will + # fail when used + tau = (sigma**-2.0) * pt.sign(sigma) else: - if sigma is not None: - raise ValueError("Can't pass both tau and sigma") - else: - if isinstance(tau, Variable): - # Keep sigma negative, if tau was negative, so that it will fail when used - sigma = pt.abs(tau) ** (-0.5) * pt.sign(tau) - else: - tau_ = np.asarray(tau) - if np.any(tau_ <= 0): - raise ValueError("tau must be positive") - sigma = tau_**-0.5 + tau = pt.as_tensor_variable(tau) + # Keep tau negative, if sigma was negative, so that it will + # fail when used + sigma = pt.abs(tau) ** -0.5 * pt.sign(tau) return tau, sigma diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 901618dc28e..c8119d1482a 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -13,7 +13,6 @@ # limitations under the License. import functools as ft -import warnings import numpy as np import numpy.testing as npt @@ -998,26 +997,45 @@ def scipy_logcdf(value, mu, sigma, lower, upper): assert np.isinf(logp[2]) def test_get_tau_sigma(self): - # Fail on warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") + sigma = np.array(2) + tau, _ = get_tau_sigma(sigma=sigma) + npt.assert_almost_equal(tau.eval(), 1.0 / sigma**2) - sigma = np.array(2) - npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma]) + tau = np.array(2) + _, sigma = get_tau_sigma(tau=tau) + npt.assert_almost_equal(sigma.eval(), tau**-0.5) - tau = np.array(2) - npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5]) + tau, _ = get_tau_sigma(sigma=pt.constant(-2)) + npt.assert_almost_equal(tau.eval(), -0.25) - tau, _ = get_tau_sigma(sigma=pt.constant(-2)) - npt.assert_almost_equal(tau.eval(), -0.25) + _, sigma = get_tau_sigma(tau=pt.constant(-2)) + npt.assert_almost_equal(sigma.eval(), -1.0 / np.sqrt(2.0)) - _, sigma = get_tau_sigma(tau=pt.constant(-2)) - npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2)) + sigma = [1, 2] + tau, _ = get_tau_sigma(sigma=sigma) + npt.assert_almost_equal(tau.eval(), 1.0 / np.array(sigma) ** 2) - sigma = [1, 2] - npt.assert_almost_equal( - get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] - ) + # Test null arguments + tau, sigma = get_tau_sigma() + npt.assert_almost_equal(tau.eval(), 1.0) + npt.assert_almost_equal(sigma.eval(), 1.0) + + # Test exception upon passing both sigma and tau + msg = "Can't pass both tau and sigma" + with pytest.raises(ValueError, match=msg): + _, _ = get_tau_sigma(sigma=1.0, tau=1.0) + + # These are regression test for #6988: Check that get_tau_sigma works + # for lists of tensors + sigma = [pt.constant(2), pt.constant(2)] + expect_tau = np.array([0.25, 0.25]) + tau, _ = get_tau_sigma(sigma=sigma) + npt.assert_almost_equal(tau.eval(), expect_tau) + + tau = [pt.constant(2), pt.constant(2)] + expect_sigma = np.array([2.0, 2.0]) ** -0.5 + _, sigma = get_tau_sigma(tau=tau) + npt.assert_almost_equal(sigma.eval(), expect_sigma) @pytest.mark.parametrize( "value,mu,sigma,nu,logp", @@ -2042,7 +2060,7 @@ class TestStudentTLam(BaseTestDistributionRandom): lam, sigma = get_tau_sigma(tau=2.0) pymc_dist_params = {"nu": 5.0, "mu": -1.0, "lam": lam} expected_rv_op_params = {"nu": 5.0, "mu": -1.0, "lam": sigma} - reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma} + reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma.eval()} reference_dist = seeded_scipy_distribution_builder("t") checks_to_run = ["check_pymc_params_match_rv_op"]