From 360fa100d0662dc5c49b0fa91c038987ca2511ed Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Wed, 6 Mar 2024 23:23:42 +0100 Subject: [PATCH 1/7] add pyro.nn.module.PyroModuleList --- pyro/nn/__init__.py | 9 ++++++++- pyro/nn/module.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/pyro/nn/__init__.py b/pyro/nn/__init__.py index 3642d0411c..6da1eccff9 100644 --- a/pyro/nn/__init__.py +++ b/pyro/nn/__init__.py @@ -9,7 +9,13 @@ MaskedLinear, ) from pyro.nn.dense_nn import ConditionalDenseNN, DenseNN -from pyro.nn.module import PyroModule, PyroParam, PyroSample, pyro_method +from pyro.nn.module import ( + PyroModule, + PyroParam, + PyroSample, + pyro_method, + PyroModuleList, +) __all__ = [ "AutoRegressiveNN", @@ -21,4 +27,5 @@ "PyroParam", "PyroSample", "pyro_method", + "PyroModuleList", ] diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 323fe470a5..a7b3c21ec4 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -14,9 +14,10 @@ """ import functools import inspect +from typing import Union import weakref from collections import OrderedDict, namedtuple - +from torch._jit_internal import _copy_to_script_wrapper import torch from torch.distributions import constraints, transform_to @@ -826,3 +827,30 @@ def __set__(self, obj, value): PyroModule[torch.nn.RNNBase]._flat_weights = _FlatWeightsDescriptor() + + +# pyro module list +# using pyro.nn.PyroModule[torch.nn.ModuleList] can cause issues when +# slice-indexing nested PyroModuleLists, so we define a separate PyroModuleList +# class that overwrites the __getitem__ method to return a torch.nn.ModuleList +# to not use self.__class__ in __getitem__, as that would call the +# PyroModule.__init__ without the parent module context, leading to a loss +# of the parent module's _pyro_name, and eventually, errors during sampling +# as parameter names may not be unique anymore +# The scenario is rare but happend. +# The fix could not be applied in torch directly, which is why we have to deal +# with it here, see https://github.com/pytorch/pytorch/issues/121008 +class PyroModuleList(PyroModule, torch.nn.ModuleList): + def __init__(self, modules): + PyroModule.__init__(self) + torch.nn.ModuleList.__init__(self, modules) + + @_copy_to_script_wrapper + def __getitem__( + self, idx: Union[int, slice] + ) -> Union[torch.nn.Module, "PyroModuleList"]: + if isinstance(idx, slice): + # return self.__class__(list(self._modules.values())[idx]) + return torch.nn.ModuleList(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] From fae8ba64127f88c65396a4c2ad4e65607ff7a3a3 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Mon, 11 Mar 2024 18:35:43 +0100 Subject: [PATCH 2/7] add modulelist tests to test_module.py --- tests/nn/test_module.py | 198 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 196 insertions(+), 2 deletions(-) diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 67c4b98108..9809871c3f 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -2,17 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 import io +import math +from typing import Callable, Iterable import warnings - import pytest import torch from torch import nn from torch.distributions import constraints, transform_to - import pyro import pyro.distributions as dist from pyro import poutine from pyro.infer import SVI, Trace_ELBO +from pyro.infer.autoguide.guides import AutoDiagonalNormal from pyro.nn.module import PyroModule, PyroParam, PyroSample, clear, to_pyro_module_ from pyro.optim import Adam from tests.common import assert_equal, xfail_param @@ -844,3 +845,196 @@ def forward(self, x, y): grad_params_func[k], torch.zeros_like(grad_params_func[k]) ), k assert torch.allclose(grad_params_autograd[k], grad_params_func[k]), k + + +class BNN(PyroModule): + # this is a vanilla Bayesian neural network implementation, nothing new or exiting here + def __init__( + self, + input_size: int, + hidden_layer_sizes: Iterable[int], + output_size: int, + use_new_module_list_type: bool, + ) -> None: + super().__init__() + + layer_sizes = ( + [(input_size, hidden_layer_sizes[0])] + + list(zip(hidden_layer_sizes[:-1], hidden_layer_sizes[1:])) + + [(hidden_layer_sizes[-1], output_size)] + ) + + layers = [ + pyro.nn.module.PyroModule[torch.nn.Linear](in_size, out_size) + for in_size, out_size in layer_sizes + ] + if use_new_module_list_type: + self.layers = pyro.nn.module.PyroModuleList(layers) + else: + self.layers = pyro.nn.module.PyroModule[torch.nn.ModuleList](layers) + + # make the layers Bayesian + for layer_idx, layer in enumerate(self.layers): + layer.weight = pyro.nn.module.PyroSample( + dist.Normal(0.0, 5.0 * math.sqrt(2 / layer_sizes[layer_idx][0])) + .expand( + [ + layer_sizes[layer_idx][1], + layer_sizes[layer_idx][0], + ] + ) + .to_event(2) + ) + layer.bias = pyro.nn.module.PyroSample( + dist.Normal(0.0, 5.0).expand([layer_sizes[layer_idx][1]]).to_event(1) + ) + + self.activation = torch.nn.Tanh() + self.output_size = output_size + + def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: + mean = self.layers[-1](x) + + if obs is not None: + with pyro.plate("data", x.shape[0]): + pyro.sample( + "obs", dist.Normal(mean, 0.1).to_event(self.output_size), obs=obs + ) + + return mean + + +class SliceIndexingModuleListBNN(BNN): + # I claim that it makes a difference whether slice-indexing is used or whether position-indexing is used + # when sub-pyromodule are wrapped in a PyroModule[torch.nn.ModuleList] + def __init__( + self, + input_size: int, + hidden_layer_sizes: Iterable[int], + output_size: int, + use_new_module_list_type: bool, + ) -> None: + super().__init__( + input_size, hidden_layer_sizes, output_size, use_new_module_list_type + ) + + def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: + for layer in self.layers[:-1]: + x = layer(x) + x = self.activation(x) + + return super().forward(x, obs=obs) + + +class PositionIndexingModuleListBNN(BNN): + # I claim that it makes a difference whether slice-indexing is used or whether position-indexing is used + # when sub-pyromodule are wrapped in a PyroModule[torch.nn.ModuleList] + def __init__( + self, + input_size: int, + hidden_layer_sizes: Iterable[int], + output_size: int, + use_new_module_list_type: bool, + ) -> None: + super().__init__( + input_size, hidden_layer_sizes, output_size, use_new_module_list_type + ) + + def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: + for i in range(len(self.layers) - 1): + x = self.layers[i](x) + x = self.activation(x) + + return super().forward(x, obs=obs) + + +class NestedBNN(pyro.nn.module.PyroModule): + # finally, the issue I want to describe occurs after the second "layer of nesting", + # i.e. when a PyroModule[ModuleList] is wrapped in a PyroModule[ModuleList] + def __init__(self, bnns: Iterable[BNN], use_new_module_list_type: bool) -> None: + super().__init__() + if use_new_module_list_type: + self.bnns = pyro.nn.module.PyroModuleList(bnns) + else: + self.bnns = pyro.nn.module.PyroModule[torch.nn.ModuleList](bnns) + + def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: + mean = sum([bnn(x) for bnn in self.bnns]) / len(self.bnns) + + with pyro.plate("data", x.shape[0]): + pyro.sample("obs", dist.Normal(mean, 0.1).to_event(1), obs=obs) + + return mean + + +def train_bnn(model: BNN, input_size: int) -> None: + pyro.clear_param_store() + + # small numbers for demo purposes + num_points = 20 + num_svi_iterations = 100 + + x = torch.linspace(0, 1, num_points).reshape((-1, input_size)) + y = torch.sin(2 * math.pi * x) + torch.randn(x.size()) * 0.1 + + guide = AutoDiagonalNormal(model) + adam = pyro.optim.Adam({"lr": 0.03}) + svi = SVI(model, guide, adam, loss=Trace_ELBO()) + + for _ in range(num_svi_iterations): + svi.step(x, y) + + +class ModuleListTester: + def setup(self, use_new_module_list_type: bool) -> None: + self.input_size = 1 + self.output_size = 1 + self.hidden_size = 3 + self.num_hidden_layers = 3 + self.use_new_module_list_type = use_new_module_list_type + + def get_position_indexing_modulelist_bnn(self) -> PositionIndexingModuleListBNN: + return PositionIndexingModuleListBNN( + self.input_size, + [self.hidden_size] * self.num_hidden_layers, + self.output_size, + self.use_new_module_list_type, + ) + + def get_slice_indexing_modulelist_bnn(self) -> SliceIndexingModuleListBNN: + return SliceIndexingModuleListBNN( + self.input_size, + [self.hidden_size] * self.num_hidden_layers, + self.output_size, + self.use_new_module_list_type, + ) + + def train_nested_bnn(self, module_getter: Callable[[], BNN]) -> None: + train_bnn( + NestedBNN( + [module_getter() for _ in range(2)], + use_new_module_list_type=self.use_new_module_list_type, + ), + self.input_size, + ) + + +class TestTorchModuleList(ModuleListTester): + def test_with_position_indexing(self) -> None: + self.setup(False) + self.train_nested_bnn(self.get_position_indexing_modulelist_bnn) + + def test_with_slice_indexing(self) -> None: + self.setup(False) + with pytest.raises(RuntimeError): + self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn) + + +class TestPyroModuleList(ModuleListTester): + def test_with_position_indexing(self) -> None: + self.setup(True) + self.train_nested_bnn(self.get_position_indexing_modulelist_bnn) + + def test_with_slice_indexing(self) -> None: + self.setup(True) + self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn) From 403f0d449ff3a90a24c50ac0c3bc0b20d8b79f81 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Sun, 17 Mar 2024 18:57:55 +0100 Subject: [PATCH 3/7] guard torch.jit_internal import in a try-catch block --- pyro/nn/module.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index a7b3c21ec4..33fe8b383e 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -17,7 +17,20 @@ from typing import Union import weakref from collections import OrderedDict, namedtuple -from torch._jit_internal import _copy_to_script_wrapper +import warnings + +try: + from torch._jit_internal import _copy_to_script_wrapper +except ImportError: + warnings.warn( + "Cannot find torch._jit_internal._copy_to_script_wrapper", ImportWarning + ) + + # Fall back to trivial decorator. + def _copy_to_script_wrapper(fn): + return fn + + import torch from torch.distributions import constraints, transform_to From 751649390206025ba77017f90b32cae5a72d6815 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Sun, 17 Mar 2024 19:05:26 +0100 Subject: [PATCH 4/7] add PyroModuleList to pyro_mixing_cache and update tests --- pyro/nn/module.py | 3 +++ tests/nn/test_module.py | 9 +++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 33fe8b383e..7c4736b753 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -867,3 +867,6 @@ def __getitem__( return torch.nn.ModuleList(list(self._modules.values())[idx]) else: return self._modules[self._get_abs_string_index(idx)] + + +_PyroModuleMeta._pyro_mixin_cache[torch.nn.ModuleList] = PyroModuleList diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 9809871c3f..5c3c41c518 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -1026,8 +1026,9 @@ def test_with_position_indexing(self) -> None: def test_with_slice_indexing(self) -> None: self.setup(False) - with pytest.raises(RuntimeError): - self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn) + # with pytest.raises(RuntimeError): + # error no longer gets raised + self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn) class TestPyroModuleList(ModuleListTester): @@ -1038,3 +1039,7 @@ def test_with_position_indexing(self) -> None: def test_with_slice_indexing(self) -> None: self.setup(True) self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn) + + +def test_module_list() -> None: + assert PyroModule[torch.nn.ModuleList] is pyro.nn.PyroModuleList From 5ffd8f7945800d01ce574e2a972cd71b82a5b692 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Sun, 17 Mar 2024 19:12:55 +0100 Subject: [PATCH 5/7] update inheritance sequence in PyroModuleList --- pyro/nn/module.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 7c4736b753..80dc833cab 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -14,7 +14,7 @@ """ import functools import inspect -from typing import Union +from typing import Iterable, Union import weakref from collections import OrderedDict, namedtuple import warnings @@ -853,10 +853,9 @@ def __set__(self, obj, value): # The scenario is rare but happend. # The fix could not be applied in torch directly, which is why we have to deal # with it here, see https://github.com/pytorch/pytorch/issues/121008 -class PyroModuleList(PyroModule, torch.nn.ModuleList): +class PyroModuleList(torch.nn.ModuleList, PyroModule): def __init__(self, modules): - PyroModule.__init__(self) - torch.nn.ModuleList.__init__(self, modules) + super().__init__(modules) @_copy_to_script_wrapper def __getitem__( From b23177fe03fe98ea587796b037f1482c6238259d Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Sun, 17 Mar 2024 21:07:20 +0100 Subject: [PATCH 6/7] black formatting --- pyro/nn/__init__.py | 2 +- pyro/nn/module.py | 4 ++-- tests/nn/test_module.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pyro/nn/__init__.py b/pyro/nn/__init__.py index 6da1eccff9..e55e7356f6 100644 --- a/pyro/nn/__init__.py +++ b/pyro/nn/__init__.py @@ -11,10 +11,10 @@ from pyro.nn.dense_nn import ConditionalDenseNN, DenseNN from pyro.nn.module import ( PyroModule, + PyroModuleList, PyroParam, PyroSample, pyro_method, - PyroModuleList, ) __all__ = [ diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 80dc833cab..ba71bfca3c 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -14,10 +14,10 @@ """ import functools import inspect -from typing import Iterable, Union +import warnings import weakref from collections import OrderedDict, namedtuple -import warnings +from typing import Union try: from torch._jit_internal import _copy_to_script_wrapper diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 5c3c41c518..5508bf9cbb 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -3,12 +3,14 @@ import io import math -from typing import Callable, Iterable import warnings +from typing import Callable, Iterable + import pytest import torch from torch import nn from torch.distributions import constraints, transform_to + import pyro import pyro.distributions as dist from pyro import poutine From 024635872e1fbe10a99106ba06c2e82462cd21d9 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Sun, 17 Mar 2024 23:44:21 +0100 Subject: [PATCH 7/7] black formatting --- pyro/nn/module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index ff3a3cf117..553b33d95d 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -28,6 +28,7 @@ def _copy_to_script_wrapper(fn): return fn + from collections import OrderedDict from dataclasses import dataclass from types import TracebackType @@ -916,6 +917,7 @@ def __set__(self, obj: object, value: Any) -> None: PyroModule[torch.nn.RNNBase]._flat_weights = _FlatWeightsDescriptor() # type: ignore[attr-defined] + # pyro module list # using pyro.nn.PyroModule[torch.nn.ModuleList] can cause issues when # slice-indexing nested PyroModuleLists, so we define a separate PyroModuleList