Adding ParallelBlock #1088

Merged (10 commits) on May 22, 2023
1 change: 0 additions & 1 deletion .github/workflows/pytorch.yml
@@ -1,4 +1,3 @@

name: pytorch

on:
4 changes: 2 additions & 2 deletions merlin/models/torch/__init__.py
@@ -15,6 +15,6 @@
#

from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block
from merlin.models.torch.block import Block, ParallelBlock

__all__ = ["Batch", "Block", "Sequence"]
__all__ = ["Batch", "Block", "ParallelBlock", "Sequence"]
239 changes: 238 additions & 1 deletion merlin/models/torch/block.py
@@ -21,7 +21,7 @@
from torch import nn

from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer
from merlin.models.torch.container import BlockContainer, BlockContainerDict


class Block(BlockContainer):
@@ -98,3 +98,240 @@ def copy(self) -> "Block":
The copy of the current block.
"""
return deepcopy(self)


class ParallelBlock(Block):
"""A base-class that calls it's modules in parallel.

A ParallelBlock contains multiple branches that will be executed
in parallel. Each branch can contain multiple modules, and
the outputs of all branches are collected into a dictionary.

If a branch returns a dictionary of tensors instead of a single tensor,
it will be flattened into the output dictionary. This ensures the output type
is always Dict[str, torch.Tensor].

Example usage::
>>> parallel_block = ParallelBlock({"a": nn.LazyLinear(2), "b": nn.LazyLinear(2)})
>>> x = torch.randn(2, 2)
>>> output = parallel_block(x)
>>> # The output is a dictionary containing the output of each branch
>>> print(output)
{
'a': tensor([[-0.0801, 0.0436],
[ 0.1235, -0.0318]]),
'b': tensor([[ 0.0918, -0.0274],
[-0.0652, 0.0381]])
}

Parameters
----------
*inputs : Union[nn.Module, Dict[str, nn.Module]]
Variable length argument list of PyTorch modules, or dictionaries mapping
branch names to modules, to be executed in parallel.
"""

def __init__(
self,
*inputs: Union[nn.Module, Dict[str, nn.Module]],
):
pre = BlockContainer(name="pre")
branches = BlockContainerDict(*inputs)
post = BlockContainer(name="post")

super().__init__()

self.pre = pre
self.branches = branches
self.post = post

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
):
"""Forward pass through the block.

The steps are as follows:
1. Pre-processing stage: Applies each module in the pre-processing stage sequentially.
2. Branching stage: Applies each module in each branch sequentially.
3. Post-processing stage: Applies each module in the post-processing stage sequentially.

If a branch returns a dictionary of tensors instead of a single tensor,
it will be flattened into the output dictionary. This ensures the output type
is always Dict[str, torch.Tensor].

Parameters
----------
inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
The input tensor or dictionary of tensors.
batch : Optional[Batch], default=None
An optional batch of data.

Returns
-------
Dict[str, torch.Tensor]
The output tensors.
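
Example usage (illustrative; ``DictModule`` is a stand-in for any module
that returns a dictionary of tensors, e.g. an embedding collection)::
>>> block = ParallelBlock({"mlp": nn.LazyLinear(2), "features": DictModule()})
>>> out = block(torch.randn(2, 2))
>>> # "mlp" contributes a single tensor under the key "mlp", while the
>>> # dictionary returned by "features" is flattened into the same output dict.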
"""
for module in self.pre.values:
inputs = module(inputs, batch=batch)

outputs = {}
for name, branch_container in self.branches.items():
branch = inputs
for module in branch_container.values:
branch = module(branch, batch=batch)

if isinstance(branch, torch.Tensor):
branch_dict = {name: branch}
elif torch.jit.isinstance(branch, Dict[str, torch.Tensor]):
branch_dict = branch
else:
raise TypeError(
f"Branch output must be a tensor or a dictionary of tensors. Got {type(branch)}"
)

for key in branch_dict.keys():
if key in outputs:
raise RuntimeError(f"Duplicate output name: {key}")

outputs.update(branch_dict)

for module in self.post.values:
outputs = module(outputs, batch=batch)

return outputs

def append(self, module: nn.Module):
"""Appends a module to the post-processing stage.

Parameters
----------
module : nn.Module
The module to append.

Returns
-------
ParallelBlock
The current object itself.
"""

self.post.append(module)

return self

def prepend(self, module: nn.Module):
"""Prepends a module to the pre-processing stage.

Parameters
----------
module : nn.Module
The module to prepend.

Returns
-------
ParallelBlock
The current object itself.
"""

self.pre.prepend(module)

return self

def append_to(self, name: str, module: nn.Module):
"""Appends a module to a specified branch.

Parameters
----------
name : str
The name of the branch.
module : nn.Module
The module to append.

Returns
-------
ParallelBlock
The current object itself.
"""

self.branches[name].append(module)

return self

def prepend_to(self, name: str, module: nn.Module):
"""Prepends a module to a specified branch.

Parameters
----------
name : str
The name of the branch.
module : nn.Module
The module to prepend.

Returns
-------
ParallelBlock
The current object itself.
"""
self.branches[name].prepend(module)

return self

def append_for_each(self, module: nn.Module, shared=False):
"""Appends a module to each branch.

Parameters
----------
module : nn.Module
The module to append.
shared : bool, default=False
If True, the same module is shared across all branches.
Otherwise a deep copy of the module is used in each branch.

Returns
-------
ParallelBlock
The current object itself.
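
Example usage (illustrative)::
>>> pb = ParallelBlock({"a": nn.LazyLinear(2), "b": nn.LazyLinear(2)})
>>> pb = pb.append_for_each(nn.ReLU(), shared=True)  # one ReLU instance shared by both branches
>>> pb = pb.append_for_each(nn.LazyLinear(4))  # each branch receives its own deep copy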
"""

self.branches.append_for_each(module, shared=shared)

return self

def prepend_for_each(self, module: nn.Module, shared=False):
"""Prepends a module to each branch.

Parameters
----------
module : nn.Module
The module to prepend.
shared : bool, default=False
If True, the same module is shared across all branches.
Otherwise a deep copy of the module is used in each branch.

Returns
-------
ParallelBlock
The current object itself.
"""

self.branches.prepend_for_each(module, shared=shared)

return self

def __getitem__(self, idx: Union[slice, int, str]):
if isinstance(idx, str) and idx in self.branches:
return self.branches[idx]

if idx == 0:
return self.pre

if idx == -1 or idx == 2:
return self.post

raise IndexError(f"Index {idx} is out of range for {self.__class__.__name__}")

def __len__(self):
return len(self.branches)

def __contains__(self, name):
return name in self.branches
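
Taken together, the new class provides a pre-processing stage, a set of named parallel branches, and a post-processing stage over the collected outputs. A minimal end-to-end sketch of how it could be exercised (illustrative only; the layer sizes and branch names are made up, and it assumes plain torch modules are accepted in the pre stage the same way the docstring example uses them in the branches):

import torch
from torch import nn

from merlin.models.torch import ParallelBlock

# Two named branches executed in parallel on the same input.
block = ParallelBlock({"left": nn.LazyLinear(8), "right": nn.LazyLinear(8)})

# prepend() adds a module to the pre-processing stage, which runs before the branches;
# append_for_each() adds a (deep-copied) module to the end of every branch.
block = block.prepend(nn.LazyLinear(16))
block = block.append_for_each(nn.ReLU())

x = torch.randn(4, 32)
outputs = block(x)  # Dict[str, torch.Tensor] with keys "left" and "right"
print({name: tensor.shape for name, tensor in outputs.items()})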