From 4df9983db7ca03287d47a8fcd76c7649de78b8bb Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 7 Jun 2021 23:53:53 -0500 Subject: [PATCH] fix loc scale reparam with center=1 --- numpyro/infer/reparam.py | 2 +- test/infer/test_reparam.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/reparam.py b/numpyro/infer/reparam.py index e7363a3c4..54bdb729d 100644 --- a/numpyro/infer/reparam.py +++ b/numpyro/infer/reparam.py @@ -89,7 +89,7 @@ def __call__(self, name, fn, obs): assert obs is None, "LocScaleReparam does not support observe statements" centered = self.centered if is_identically_one(centered): - return name, fn, obs + return fn, obs event_shape = fn.event_shape fn, expand_shape, event_dim = self._unwrap(fn) diff --git a/test/infer/test_reparam.py b/test/infer/test_reparam.py index 98e992a86..2f335bdfb 100644 --- a/test/infer/test_reparam.py +++ b/test/infer/test_reparam.py @@ -217,9 +217,9 @@ def get_expected_probe(loc, scale): return get_moments(trace["x"]["value"]) if "dist_type" == "Normal": - reparam = LocScaleReparam() + reparam = LocScaleReparam(centered) else: - reparam = LocScaleReparam(shape_params=["df"]) + reparam = LocScaleReparam(centered, shape_params=["df"]) def get_actual_probe(loc, scale): with numpyro.handlers.trace() as trace: