diff --git a/pyro/infer/reparam/stable.py b/pyro/infer/reparam/stable.py
index c54aa7d287..9717135ea5 100644
--- a/pyro/infer/reparam/stable.py
+++ b/pyro/infer/reparam/stable.py
@@ -218,7 +218,7 @@ def apply(self, msg):
         t_scale = skew_abs.pow(a_inv)
         s_scale = (1 - skew_abs).pow(a_inv)
         shift = _safe_shift(a, fn.skew, t_scale, skew_abs)
-        loc = fn.loc + fn.scale * (fn.skew.sign() * t_scale * t + shift)
+        loc = fn.loc + fn.scale * (fn.skew.detach().sign() * t_scale * t + shift)
         scale = fn.scale * s_scale * z.sqrt() * (math.pi / 4 * a).cos().pow(a_inv)
         scale = scale.clamp(min=torch.finfo(scale.dtype).tiny)
 
@@ -229,7 +229,7 @@ def apply(self, msg):
 
 def _unsafe_shift(a, skew, t_scale):
     # At a=1 the lhs has a root and the rhs has an asymptote.
-    return (skew.sign() * t_scale - skew) * (math.pi / 2 * a).tan()
+    return (skew.detach().sign() * t_scale - skew) * (math.pi / 2 * a).tan()
 
 
 def _safe_shift(a, skew, t_scale, skew_abs):
diff --git a/tests/infer/reparam/test_stable.py b/tests/infer/reparam/test_stable.py
index efaa810da0..deda2b4ddf 100644
--- a/tests/infer/reparam/test_stable.py
+++ b/tests/infer/reparam/test_stable.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+import logging
+
 import pytest
 import torch
 from scipy.stats import ks_2samp
@@ -9,18 +11,23 @@
 import pyro
 import pyro.distributions as dist
 from pyro import poutine
+from pyro.distributions import constraints
 from pyro.distributions.torch_distribution import MaskedDistribution
-from pyro.infer import Trace_ELBO
+from pyro.infer import SVI, Trace_ELBO
 from pyro.infer.autoguide import AutoNormal
+from pyro.infer.mcmc import MCMC, NUTS
 from pyro.infer.reparam import (
     LatentStableReparam,
     StableReparam,
     SymmetricStableReparam,
 )
-from tests.common import assert_close, xfail_param
+from pyro.optim import ClippedAdam
+from tests.common import assert_close
 
 from .util import check_init_reparam
 
+logger = logging.getLogger(__name__)
+
 
 # Test helper to extract a few absolute moments from univariate samples.
 # This uses abs moments because Stable variance is infinite.
@@ -30,15 +37,7 @@ def get_moments(x):
     return torch.cat([x.mean(0, keepdim=True), (x - points).abs().mean(1)])
 
 
-@pytest.mark.parametrize(
-    "shape",
-    [
-        (),
-        xfail_param(4, reason="flaky, https://github.com/pyro-ppl/pyro/issues/3214"),
-        (2, 3),
-    ],
-    ids=str,
-)
+@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
 @pytest.mark.parametrize("Reparam", [LatentStableReparam, StableReparam])
 def test_stable(Reparam, shape):
     stability = torch.empty(shape).uniform_(1.5, 2.0).requires_grad_()
@@ -165,3 +164,134 @@ def model():
         return pyro.sample("x", dist.Stable(stability, skew))
 
     check_init_reparam(model, Reparam())
+
+
+@pytest.mark.stage("integration", "integration_batch_1")
+@pytest.mark.parametrize(
+    "stability, skew, scale, loc",
+    [
+        (1.9, 0.0, 2.0, 1.0),
+        (0.8, 0.0, 3.0, 2.0),
+    ],
+)
+def test_symmetric_stable_mle(stability, skew, scale, loc):
+    # Regression test for https://github.com/pyro-ppl/pyro/issues/3280
+    assert skew == 0.0
+    data = dist.Stable(stability, skew, scale, loc).sample([10000])
+
+    @poutine.reparam(config={"x": SymmetricStableReparam()})
+    def mle_model():
+        a = pyro.param("a", torch.tensor(1.9), constraint=constraints.interval(0, 2))
+        b = 0.0
+        c = pyro.param("c", torch.tensor(1.0), constraint=constraints.positive)
+        d = pyro.param("d", torch.tensor(0.0), constraint=constraints.real)
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", dist.Stable(a, b, c, d), obs=data)
+
+    num_steps = 1001
+    guide = AutoNormal(mle_model)
+    optim = ClippedAdam({"clip_norm": 100, "lr": 0.05, "lrd": 0.1 ** (1 / num_steps)})
+    svi = SVI(mle_model, guide, optim, Trace_ELBO())
+    for step in range(num_steps):
+        loss = svi.step() / len(data)
+        if step % 100 == 0:
+            logger.info("step %d loss = %g", step, loss)
+
+    # Check loss against a true model.
+    @poutine.reparam(config={"x": SymmetricStableReparam()})
+    def true_model():
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", dist.Stable(stability, skew, scale, loc), obs=data)
+
+    actual_loss = Trace_ELBO().loss(mle_model, guide) / len(data)
+    expected_loss = Trace_ELBO().loss(true_model, guide) / len(data)
+    assert_close(actual_loss, expected_loss, atol=0.33)
+
+    # Check parameter estimates.
+    actual = {name: float(pyro.param(name).data) for name in "acd"}
+    assert_close(actual["a"], stability, atol=0.1)
+    assert_close(actual["c"], scale, atol=0.1, rtol=0.1)
+    assert_close(actual["d"], loc, atol=0.1)
+
+
+@pytest.mark.stage("integration", "integration_batch_1")
+@pytest.mark.parametrize(
+    "stability, skew, scale, loc",
+    [
+        (1.9, 0.0, 2.0, 1.0),
+        (0.8, 0.0, 3.0, 2.0),
+        (1.8, 0.8, 4.0, 3.0),
+    ],
+)
+def test_stable_mle(stability, skew, scale, loc):
+    # Regression test for https://github.com/pyro-ppl/pyro/issues/3280
+    data = dist.Stable(stability, skew, scale, loc).sample([10000])
+
+    @poutine.reparam(config={"x": StableReparam()})
+    def mle_model():
+        a = pyro.param("a", torch.tensor(1.9), constraint=constraints.interval(0, 2))
+        b = pyro.param("b", torch.tensor(0.0), constraint=constraints.interval(-1, 1))
+        c = pyro.param("c", torch.tensor(1.0), constraint=constraints.positive)
+        d = pyro.param("d", torch.tensor(0.0), constraint=constraints.real)
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", dist.Stable(a, b, c, d), obs=data)
+
+    num_steps = 1001
+    guide = AutoNormal(mle_model)
+    optim = ClippedAdam({"clip_norm": 100, "lr": 0.05, "lrd": 0.1 ** (1 / num_steps)})
+    svi = SVI(mle_model, guide, optim, Trace_ELBO())
+    for step in range(num_steps):
+        loss = svi.step() / len(data)
+        if step % 100 == 0:
+            logger.info("step %d loss = %g", step, loss)
+
+    # Check loss against a true model.
+    @poutine.reparam(config={"x": StableReparam()})
+    def true_model():
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", dist.Stable(stability, skew, scale, loc), obs=data)
+
+    actual_loss = Trace_ELBO().loss(mle_model, guide) / len(data)
+    expected_loss = Trace_ELBO().loss(true_model, guide) / len(data)
+    assert_close(actual_loss, expected_loss, atol=0.1)
+
+    # Check parameter estimates.
+    actual = {name: float(pyro.param(name).data) for name in "abcd"}
+    assert_close(actual["a"], stability, atol=0.1)
+    assert_close(actual["b"], skew, atol=0.1)
+    assert_close(actual["c"], scale, atol=0.1, rtol=0.1)
+    assert_close(actual["d"], loc, atol=0.1)
+
+
+@pytest.mark.stage("integration", "integration_batch_1")
+@pytest.mark.parametrize(
+    "stability, skew, scale, loc",
+    [
+        (1.9, 0.0, 2.0, 1.0),
+        (0.8, 0.0, 3.0, 2.0),
+        (1.8, 0.8, 4.0, 3.0),
+    ],
+)
+def test_stable_mcmc(stability, skew, scale, loc):
+    # Regression test for https://github.com/pyro-ppl/pyro/issues/3280
+    data = dist.Stable(stability, skew, scale, loc).sample([1000])
+
+    @poutine.reparam(config={"x": StableReparam()})
+    def model():
+        with poutine.mask(mask=False):  # flat prior
+            a = pyro.sample("a", dist.Uniform(0, 2))
+            b = pyro.sample("b", dist.Uniform(-1, 1))
+            c = pyro.sample("c", dist.Exponential(1))
+            d = pyro.sample("d", dist.Normal(0, 1))
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", dist.Stable(a, b, c, d), obs=data)
+
+    kernel = NUTS(model)
+    mcmc = MCMC(kernel, num_samples=400, warmup_steps=200)
+    mcmc.run()
+    samples = mcmc.get_samples()
+    actual = {k: v.mean().item() for k, v in samples.items()}
+    assert_close(actual["a"], stability, atol=0.1)
+    assert_close(actual["b"], skew, atol=0.1)
+    assert_close(actual["c"], scale, atol=0.1, rtol=0.1)
+    assert_close(actual["d"], loc, atol=0.1)