From cc2feab6758dde48a2049cf96cd5b6f3957aeae7 Mon Sep 17 00:00:00 2001 From: jan-matthis Date: Tue, 16 Nov 2021 13:22:06 +0100 Subject: [PATCH 1/3] LADJ fix --- sbibm/tasks/two_moons/task.py | 2 +- sbibm/utils/pyro.py | 7 +------ tests/tasks/test_task_interface.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/sbibm/tasks/two_moons/task.py b/sbibm/tasks/two_moons/task.py index 2d5e019f..56a299c9 100644 --- a/sbibm/tasks/two_moons/task.py +++ b/sbibm/tasks/two_moons/task.py @@ -165,7 +165,7 @@ def _get_transforms( *args, **kwargs: Any, ) -> Dict[str, Any]: - return {"parameters": torch.distributions.transforms.identity_transform} + return {"parameters": torch.distributions.transforms.IndependentTransform(torch.distributions.transforms.identity_transform, 1) } def _get_log_prob_fn( self, diff --git a/sbibm/utils/pyro.py b/sbibm/utils/pyro.py index 53483313..bc12a458 100644 --- a/sbibm/utils/pyro.py +++ b/sbibm/utils/pyro.py @@ -81,6 +81,7 @@ def get_log_prob_fn( model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) has_enumerable_sites = False + needs_independent_transform = True for name, node in model_trace.iter_stochastic_nodes(): fn = node["fn"] @@ -99,12 +100,6 @@ def get_log_prob_fn( transforms[name] = biject_to(fn.support).inv else: transforms[name] = dist.transforms.identity_transform - # Reinterpret batch dimensions of transform to get log abs det jac summed over - # parameter dimensions. - if not isinstance(transforms[name], IndependentTransform): - transforms[name] = IndependentTransform( - transforms[name], reinterpreted_batch_ndims=1 - ) if implementation == "pyro": trace_prob_evaluator = TraceEinsumEvaluator( diff --git a/tests/tasks/test_task_interface.py b/tests/tasks/test_task_interface.py index b5540e7d..dc21f971 100644 --- a/tests/tasks/test_task_interface.py +++ b/tests/tasks/test_task_interface.py @@ -101,3 +101,17 @@ def test_reference_posterior_not_called(task_name): reference_samples = task.get_reference_posterior_samples(num_observation=1) assert task is not None + + +@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)]) +def test_transforms_shapes(task_name, batch_size=5): + task = get_task(task_name) + prior = task.get_prior() + samples = prior(num_samples=batch_size) + + transforms = task._get_transforms(True)["parameters"] + + ladj_shape = transforms.log_abs_det_jacobian(transforms(samples), samples).shape + assert ladj_shape == torch.Size([batch_size]) + + assert transforms is not None \ No newline at end of file From 4dc8926b84139df78cd45e2b90ad1b9d00ae0e9e Mon Sep 17 00:00:00 2001 From: jan-matthis Date: Tue, 16 Nov 2021 13:25:02 +0100 Subject: [PATCH 2/3] Remove unneeded variable --- sbibm/utils/pyro.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sbibm/utils/pyro.py b/sbibm/utils/pyro.py index bc12a458..7a8ae485 100644 --- a/sbibm/utils/pyro.py +++ b/sbibm/utils/pyro.py @@ -81,7 +81,6 @@ def get_log_prob_fn( model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) has_enumerable_sites = False - needs_independent_transform = True for name, node in model_trace.iter_stochastic_nodes(): fn = node["fn"] From 8216dd1e2f5b45f913b3db51180945ce6830fc43 Mon Sep 17 00:00:00 2001 From: jan-matthis Date: Tue, 16 Nov 2021 13:34:15 +0100 Subject: [PATCH 3/3] Reinterpret batch dimensions in case of identity transform --- sbibm/utils/pyro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbibm/utils/pyro.py b/sbibm/utils/pyro.py index 7a8ae485..faf9b108 100644 --- a/sbibm/utils/pyro.py +++ b/sbibm/utils/pyro.py @@ -98,7 +98,7 @@ def get_log_prob_fn( if automatic_transform_enabled: transforms[name] = biject_to(fn.support).inv else: - transforms[name] = dist.transforms.identity_transform + transforms[name] = dist.transforms.IndependentTransform(dist.transforms.identity_transform, 1) if implementation == "pyro": trace_prob_evaluator = TraceEinsumEvaluator(