Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Link #1091

Merged
merged 6 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.link import Link, LinkType
from merlin.models.torch.registry import registry
from merlin.models.utils.registry import RegistryMixin

Expand Down Expand Up @@ -65,7 +66,7 @@ def forward(

return inputs

def repeat(self, n: int = 1, name=None) -> "Block":
def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block":
"""
Creates a new block by repeating the current block `n` times.
Each repetition is a deep copy of the current block.
Expand All @@ -89,6 +90,9 @@ def repeat(self, n: int = 1, name=None) -> "Block":
raise ValueError("n must be greater than 0")

repeats = [self.copy() for _ in range(n - 1)]
if link:
parsed_link = Link.parse(link)
repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats]

return Block(self, *repeats, name=name)

Expand Down Expand Up @@ -205,7 +209,7 @@ def forward(

return outputs

def append(self, module: nn.Module):
def append(self, module: nn.Module, link: Optional[LinkType] = None):
"""Appends a module to the post-processing stage.

Parameters
Expand All @@ -219,29 +223,16 @@ def append(self, module: nn.Module):
The current object itself.
"""

self.post.append(module)
self.post.append(module, link=link)

return self

def prepend(self, module: nn.Module):
    """Put a module at the front of the pre-processing stage.

    Parameters
    ----------
    module : nn.Module
        The module to prepend.

    Returns
    -------
    ParallelBlock
        The current object itself, so calls can be chained.
    """
    # Delegate to the `pre` container; it owns the ordering of the stage.
    pre_stage = self.pre
    pre_stage.prepend(module)

    return self

def append_to(self, name: str, module: nn.Module):
def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
"""Appends a module to a specified branch.

Parameters
Expand All @@ -257,11 +248,11 @@ def append_to(self, name: str, module: nn.Module):
The current object itself.
"""

self.branches[name].append(module)
self.branches[name].append(module, link=link)

return self

def prepend_to(self, name: str, module: nn.Module):
def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
"""Prepends a module to a specified branch.

Parameters
Expand All @@ -276,11 +267,11 @@ def prepend_to(self, name: str, module: nn.Module):
ParallelBlock
The current object itself.
"""
self.branches[name].prepend(module)
self.branches[name].prepend(module, link=link)

return self

def append_for_each(self, module: nn.Module, shared=False):
def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
"""Appends a module to each branch.

Parameters
Expand All @@ -297,11 +288,11 @@ def append_for_each(self, module: nn.Module, shared=False):
The current object itself.
"""

self.branches.append_for_each(module, shared=shared)
self.branches.append_for_each(module, shared=shared, link=link)

return self

def prepend_for_each(self, module: nn.Module, shared=False):
def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
"""Prepends a module to each branch.

Parameters
Expand All @@ -318,7 +309,7 @@ def prepend_for_each(self, module: nn.Module, shared=False):
The current object itself.
"""

self.branches.prepend_for_each(module, shared=shared)
self.branches.prepend_for_each(module, shared=shared, link=link)

return self

Expand Down
106 changes: 83 additions & 23 deletions merlin/models/torch/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torch import nn
from torch._jit_internal import _copy_to_script_wrapper

from merlin.models.torch.link import Link, LinkType
from merlin.models.torch.utils import torchscript_utils


Expand All @@ -46,37 +47,52 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None):

self._name: str = name

def append(self, module: nn.Module, link: Optional[LinkType] = None):
    """Appends a given module to the end of the list.

    Parameters
    ----------
    module : nn.Module
        The PyTorch module to be appended.
    link : Optional[LinkType]
        The link to use for the module. If None, no link is used.
        This can either be a Module or a string, options are:
            - "residual": Adds a residual connection to the module.
            - "shortcut": Adds a shortcut connection to the module.
            - "shortcut-concat": Adds a shortcut connection by concatenating
              the input and output.

    Returns
    -------
    self
    """
    # `_check_link` returns the module unchanged when no link is given,
    # otherwise the link object that has been set up around the module.
    _module = self._check_link(module, link=link)
    self.values.append(self.wrap_module(_module))

    return self

def prepend(self, module: nn.Module, link: Optional[LinkType] = None):
    """Prepends a given module to the beginning of the list.

    Parameters
    ----------
    module : nn.Module
        The PyTorch module to be prepended.
    link : Optional[LinkType]
        The link to use for the module. If None, no link is used.
        This can either be a Module or a string, options are:
            - "residual": Adds a residual connection to the module.
            - "shortcut": Adds a shortcut connection to the module.
            - "shortcut-concat": Adds a shortcut connection by concatenating
              the input and output.

    Returns
    -------
    self
    """
    # Prepending is insertion at index 0; `insert` handles the link.
    return self.insert(0, module, link=link)

def insert(self, index: int, module: nn.Module):
def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
"""Inserts a given module at the specified index.

Parameters
Expand All @@ -85,13 +101,20 @@ def insert(self, index: int, module: nn.Module):
The index at which the module is to be inserted.
module : nn.Module
The PyTorch module to be inserted.
link : Optional[LinkType]
The link to use for the module. If None, no link is used.
This can either be a Module or a string, options are:
- "residual": Adds a residual connection to the module.
- "shortcut": Adds a shortcut connection to the module.
- "shortcut-concat": Adds a shortcut connection by concatenating
the input and output.

Returns
-------
self
"""

self.values.insert(index, self.wrap_module(module))
_module = self._check_link(module, link=link)
self.values.insert(index, self.wrap_module(_module))

return self

Expand Down Expand Up @@ -152,6 +175,15 @@ def __repr__(self) -> str:
def _get_name(self) -> str:
return super()._get_name() if self._name is None else self._name

def _check_link(self, module: nn.Module, link: Optional[LinkType] = None) -> nn.Module:
    """Optionally wrap `module` in a link.

    Parameters
    ----------
    module : nn.Module
        The module that is about to be added to the container.
    link : Optional[LinkType]
        The link to apply. When falsy (e.g. None), no link is used.

    Returns
    -------
    nn.Module
        `module` itself when no link is given; otherwise the object
        returned by `Link.parse(link)` after `setup_link(module)` has
        been called on it.
    """
    if not link:
        return module

    wrapper: Link = Link.parse(link)
    wrapper.setup_link(module)

    return wrapper


class BlockContainerDict(nn.ModuleDict):
"""A container class for PyTorch `nn.Module` that allows for manipulation and traversal
Expand All @@ -178,7 +210,9 @@ def __init__(
super().__init__(modules)
self._name: str = name

def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
def append_to(
self, name: str, module: nn.Module, link: Optional[LinkType] = None
) -> "BlockContainerDict":
"""Appends a module to a specified name.

Parameters
Expand All @@ -187,18 +221,27 @@ def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
The name of the branch.
module : nn.Module
The module to append.
link : Optional[LinkType]
The link to use for the module. If None, no link is used.
This can either be a Module or a string, options are:
- "residual": Adds a residual connection to the module.
- "shortcut": Adds a shortcut connection to the module.
- "shortcut-concat": Adds a shortcut connection by concatenating
the input and output.

Returns
-------
BlockContainerDict
The current object itself.
"""

self._modules[name].append(module)
self._modules[name].append(module, link=link)

return self

def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
def prepend_to(
self, name: str, module: nn.Module, link: Optional[LinkType] = None
) -> "BlockContainerDict":
"""Prepends a module to a specified name.

Parameters
Expand All @@ -207,19 +250,25 @@ def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
The name of the branch.
module : nn.Module
The module to prepend.
link : Optional[LinkType]
The link to use for the module. If None, no link is used.
This can either be a Module or a string, options are:
- "residual": Adds a residual connection to the module.
- "shortcut": Adds a shortcut connection to the module.
- "shortcut-concat": Adds a shortcut connection by concatenating
the input and output.

Returns
-------
BlockContainerDict
The current object itself.
"""

self._modules[name].prepend(module)
self._modules[name].prepend(module, link=link)

return self

# Append to all branches, optionally copying
def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
def append_for_each(
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
) -> "BlockContainerDict":
"""Appends a module to each branch.

Parameters
Expand All @@ -229,6 +278,13 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
shared : bool, default=False
If True, the same module is shared across all elements.
Otherwise a deep copy of the module is used in each element.
link : Optional[LinkType]
The link to use for the module. If None, no link is used.
This can either be a Module or a string, options are:
- "residual": Adds a residual connection to the module.
- "shortcut": Adds a shortcut connection to the module.
- "shortcut-concat": Adds a shortcut connection by concatenating
the input and output.

Returns
-------
Expand All @@ -238,11 +294,13 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic

for branch in self.values():
_module = module if shared else deepcopy(module)
branch.append(_module)
branch.append(_module, link=link)

return self

def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
def prepend_for_each(
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
) -> "BlockContainerDict":
"""Prepends a module to each branch.

Parameters
Expand All @@ -252,23 +310,25 @@ def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDi
shared : bool, default=False
If True, the same module is shared across all elements.
Otherwise a deep copy of the module is used in each element.
link : Optional[LinkType]
The link to use for the module. If None, no link is used.
This can either be a Module or a string, options are:
- "residual": Adds a residual connection to the module.
- "shortcut": Adds a shortcut connection to the module.
- "shortcut-concat": Adds a shortcut connection by concatenating
the input and output.

Returns
-------
BlockContainerDict
The current object itself.
"""

for branch in self.values():
_module = module if shared else deepcopy(module)
branch.prepend(_module)
branch.prepend(_module, link=link)

return self

# This assumes the same branches; we append its content to each branch
# def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
# ...

def add_module(self, name: str, module: Optional[nn.Module]) -> None:
if module and not isinstance(module, BlockContainer):
module = BlockContainer(module, name=name[0].upper() + name[1:])
Expand Down
Loading