Hi,
I have already opened a PR, but it does not seem to get any attention without a related issue. So here is the issue for it:
While using "nested" PyroModules, i.e. a PyroModule with a PyroModule[torch.nn.ModuleList] attribute that contains another such module (see the test implementation in the PR for a more detailed example), I encountered errors caused by multiple sample sites having the same name.
I spent some time tracing the issue and found that the RuntimeError was caused by an unlucky combination of PyroModule and torch.nn.ModuleList:
First: why use torch.nn.ModuleList? Well, I wanted to model a PyroModule that owns a list of sub-PyroModules. My idea was to replace linear = PyroModule[Linear](5, 2) from the modules example with PyroModule[torch.nn.ModuleList](...).
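That works fine if there is a single MyPyroModule(), but can fail if there is a nested structure, like MyPyroModule(MyPyroModule(...)).
Whether it fails depends on how the self.my_submodule attribute is accessed: self.my_submodule[0](...) worked fine, even for nested modules, whereas self.my_submodule[:-1](...) only works if there is a single MyPyroModule() and not a MyPyroModule(MyPyroModule(...)).
The cause is the slice handling in the __getitem__ function of the torch.nn.ModuleList class, which calls self.__class__. This means that for an object of type PyroModule[torch.nn.ModuleList], it calls the PyroModule.__init__ function without the context of the parent module. This overwrites the names of the module's submodules and their ._pyro_name attributes, and because of that, sample sites may no longer be unique during sampling.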
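To make the mechanism easier to follow, here is a minimal plain-Python sketch of it. It does not import torch or pyro: ToyModuleList, ToyLeaf, ToyPyroList, set_name and pyro_name are hypothetical stand-ins that I made up for illustration, and the naming scheme is only assumed to resemble what PyroModule does.

```python
class ToyModuleList:
    """Toy stand-in for torch.nn.ModuleList (not the real class)."""

    def __init__(self, modules=()):
        self._modules = list(modules)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # A slice builds a new container of the *same* class.
            return self.__class__(self._modules[idx])
        return self._modules[idx]


class ToyLeaf:
    """Toy stand-in for a submodule whose name prefixes its sample sites."""

    def __init__(self):
        self.pyro_name = ""


class ToyPyroList(ToyModuleList):
    """Toy stand-in for PyroModule[torch.nn.ModuleList] (assumed behavior)."""

    def __init__(self, modules=()):
        super().__init__(modules)
        # A freshly constructed container has no parent, so its own name is
        # empty and its children are renamed relative to that empty name.
        self.set_name("")

    def set_name(self, name):
        self.pyro_name = name
        for i, child in enumerate(self._modules):
            child.pyro_name = f"{name}.{i}" if name else str(i)


def child_names(container):
    return [child.pyro_name for child in container._modules]


first = ToyPyroList([ToyLeaf(), ToyLeaf()])
first.set_name("model.first")
second = ToyPyroList([ToyLeaf(), ToyLeaf()])
second.set_name("model.second")
names_initial = child_names(first) + child_names(second)

# Integer indexing returns the child itself and leaves all names intact.
first[0], second[0]
names_after_int = child_names(first) + child_names(second)

# Slicing re-runs ToyPyroList.__init__ on the selected (shared) children,
# which renames them without any parent prefix.
first[:-1], second[:-1]
names_after_slice = child_names(first) + child_names(second)

print(names_initial)
print(names_after_int)
print(names_after_slice)
```

In this toy, the first child of each container ends up with the same bare name after slicing, which is the kind of collision that would show up as duplicate sample sites.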
I see two possible fixes for this:
I have introduced a pyro.nn.PyroModuleList class in this PR that inherits from torch.nn.ModuleList and overrides the problematic __getitem__ function.
We could add an example or documentation on this issue, explicitly warning against using PyroModule[torch.nn.ModuleList] in combination with slice indexing (this feels a little unsafe to me).
Maybe someone has another (potentially better) idea of how to deal with this?
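The idea behind the first fix can be sketched in plain Python as follows. Again, this is only a toy under stated assumptions: ToyModuleList, ToyLeaf, ToyPyroList and ToyPyroModuleList are hypothetical stand-ins (no torch or pyro involved), and the actual implementation in the PR may differ in detail.

```python
class ToyModuleList:
    """Toy stand-in for torch.nn.ModuleList (not the real class)."""

    def __init__(self, modules=()):
        self._modules = list(modules)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # Slices rebuild a container of the same class.
            return self.__class__(self._modules[idx])
        return self._modules[idx]


class ToyLeaf:
    """Toy stand-in for a submodule that carries a unique name."""

    def __init__(self):
        self.pyro_name = ""


class ToyPyroList(ToyModuleList):
    """Toy stand-in for PyroModule[torch.nn.ModuleList] (assumed behavior)."""

    def __init__(self, modules=()):
        super().__init__(modules)
        # Constructing without a parent renames children with no prefix.
        self.set_name("")

    def set_name(self, name):
        self.pyro_name = name
        for i, child in enumerate(self._modules):
            child.pyro_name = f"{name}.{i}" if name else str(i)


class ToyPyroModuleList(ToyPyroList):
    """Toy stand-in for the proposed pyro.nn.PyroModuleList."""

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # Return the plain base container instead of self.__class__;
            # its constructor does not touch the children's names.
            return ToyModuleList(self._modules[idx])
        return self._modules[idx]


layers = ToyPyroModuleList([ToyLeaf(), ToyLeaf(), ToyLeaf()])
layers.set_name("model.layers")

sliced = layers[:-1]
sliced_names = [child.pyro_name for child in sliced._modules]
all_names = [child.pyro_name for child in layers._modules]

print(type(sliced).__name__)
print(sliced_names)
```

The trade-off in this sketch is that a slice is no longer an instance of the subclass, but the children keep the names they received from their parent.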
I know that this is kind of a "special purpose" usage and may not affect basic pyro usage, but in particular the way the modules example motivates the usage of PyroModule[torch.nn.Something] could, imho, quickly lure other users into this (and it took me quite some time to find the root issue).
Please let me know what you think about this PR, and whether it needs updates or clarification.
Best regards
Martin
Quick follow-up to your question above, @fritzo: I don't think slice indexing is a very common use case. Also, the line that causes the loss of the parent module's ._pyro_name is necessary in a plain torch context, and replacing it with the torch.nn.ModuleList constructor, as in this PR, is not an option for general torch, as noted here. I guess that with the fix applied in the PR, we should be safe for Pyro use cases.