diff --git a/.github/workflows/pytorch.yml b/.github/workflows/pytorch.yml
index e45fc19413..07428121cd 100644
--- a/.github/workflows/pytorch.yml
+++ b/.github/workflows/pytorch.yml
@@ -1,4 +1,3 @@
-
 name: pytorch
 
 on:
diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index fc738652c4..15e6c9f0b4 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -15,6 +15,6 @@
 #
 
 from merlin.models.torch.batch import Batch, Sequence
-from merlin.models.torch.block import Block
+from merlin.models.torch.block import Block, ParallelBlock
 
-__all__ = ["Batch", "Block", "Sequence"]
+__all__ = ["Batch", "Block", "ParallelBlock", "Sequence"]
diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py
index e1ae80cfa9..4848d6cbe7 100644
--- a/merlin/models/torch/block.py
+++ b/merlin/models/torch/block.py
@@ -21,7 +21,7 @@
 from torch import nn
 
 from merlin.models.torch.batch import Batch
-from merlin.models.torch.container import BlockContainer
+from merlin.models.torch.container import BlockContainer, BlockContainerDict
 
 
 class Block(BlockContainer):
@@ -98,3 +98,240 @@ def copy(self) -> "Block":
             The copy of the current block.
         """
         return deepcopy(self)
+
+
+class ParallelBlock(Block):
+    """A base class that calls its modules in parallel.
+
+    A ParallelBlock contains multiple branches that will be executed
+    in parallel. Each branch can contain multiple modules, and
+    the outputs of all branches are collected into a dictionary.
+
+    If a branch returns a dictionary of tensors instead of a single tensor,
+    it will be flattened into the output dictionary. This ensures the output type
+    is always Dict[str, torch.Tensor].
+
+    Example usage::
+        >>> parallel_block = ParallelBlock({"a": nn.LazyLinear(2), "b": nn.LazyLinear(2)})
+        >>> x = torch.randn(2, 2)
+        >>> output = parallel_block(x)
+        >>> # The output is a dictionary containing the output of each branch
+        >>> print(output)
+        {
+            'a': tensor([[-0.0801,  0.0436],
+                         [ 0.1235, -0.0318]]),
+            'b': tensor([[ 0.0918, -0.0274],
+                         [-0.0652,  0.0381]])
+        }
+
+    Parameters
+    ----------
+    *inputs : Union[nn.Module, Dict[str, nn.Module]]
+        Variable length argument list of branches to run in parallel. Each branch can
+        be a single module or a dictionary mapping branch names to modules; dictionaries
+        are merged into the named branches of the block.
+    """
+
+    def __init__(
+        self,
+        *inputs: Union[nn.Module, Dict[str, nn.Module]],
+    ):
+        pre = BlockContainer(name="pre")
+        branches = BlockContainerDict(*inputs)
+        post = BlockContainer(name="post")
+
+        super().__init__()
+
+        self.pre = pre
+        self.branches = branches
+        self.post = post
+
+    def forward(
+        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
+    ):
+        """Forward pass through the block.
+
+        The steps are as follows:
+        1. Pre-processing stage: Applies each module in the pre-processing stage sequentially.
+        2. Branching stage: Applies each module in each branch sequentially.
+        3. Post-processing stage: Applies each module in the post-processing stage sequentially.
+
+        If a branch returns a dictionary of tensors instead of a single tensor,
+        it will be flattened into the output dictionary. This ensures the output type
+        is always Dict[str, torch.Tensor].
+
+        Parameters
+        ----------
+        inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
+            The input tensor or dictionary of tensors.
+        batch : Optional[Batch], default=None
+            An optional batch of data.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            The output tensors.
+ """ + for module in self.pre.values: + inputs = module(inputs, batch=batch) + + outputs = {} + for name, branch_container in self.branches.items(): + branch = inputs + for module in branch_container.values: + branch = module(branch, batch=batch) + + if isinstance(branch, torch.Tensor): + branch_dict = {name: branch} + elif torch.jit.isinstance(branch, Dict[str, torch.Tensor]): + branch_dict = branch + else: + raise TypeError( + f"Branch output must be a tensor or a dictionary of tensors. Got {type(branch)}" + ) + + for key in branch_dict.keys(): + if key in outputs: + raise RuntimeError(f"Duplicate output name: {key}") + + outputs.update(branch_dict) + + for module in self.post.values: + outputs = module(outputs, batch=batch) + + return outputs + + def append(self, module: nn.Module): + """Appends a module to the post-processing stage. + + Parameters + ---------- + module : nn.Module + The module to append. + + Returns + ------- + ParallelBlock + The current object itself. + """ + + self.post.append(module) + + return self + + def prepend(self, module: nn.Module): + """Prepends a module to the pre-processing stage. + + Parameters + ---------- + module : nn.Module + The module to prepend. + + Returns + ------- + ParallelBlock + The current object itself. + """ + + self.pre.prepend(module) + + return self + + def append_to(self, name: str, module: nn.Module): + """Appends a module to a specified branch. + + Parameters + ---------- + name : str + The name of the branch. + module : nn.Module + The module to append. + + Returns + ------- + ParallelBlock + The current object itself. + """ + + self.branches[name].append(module) + + return self + + def prepend_to(self, name: str, module: nn.Module): + """Prepends a module to a specified branch. + + Parameters + ---------- + name : str + The name of the branch. + module : nn.Module + The module to prepend. + + Returns + ------- + ParallelBlock + The current object itself. + """ + self.branches[name].prepend(module) + + return self + + def append_for_each(self, module: nn.Module, shared=False): + """Appends a module to each branch. + + Parameters + ---------- + module : nn.Module + The module to append. + shared : bool, default=False + If True, the same module is shared across all branches. + Otherwise a deep copy of the module is used in each branch. + + Returns + ------- + ParallelBlock + The current object itself. + """ + + self.branches.append_for_each(module, shared=shared) + + return self + + def prepend_for_each(self, module: nn.Module, shared=False): + """Prepends a module to each branch. + + Parameters + ---------- + module : nn.Module + The module to prepend. + shared : bool, default=False + If True, the same module is shared across all branches. + Otherwise a deep copy of the module is used in each branch. + + Returns + ------- + ParallelBlock + The current object itself. 
+ """ + + self.branches.prepend_for_each(module, shared=shared) + + return self + + def __getitem__(self, idx: Union[slice, int, str]): + if isinstance(idx, str) and idx in self.branches: + return self.branches[idx] + + if idx == 0: + return self.pre + + if idx == -1 or idx == 2: + return self.post + + raise IndexError(f"Index {idx} is out of range for {self.__class__.__name__}") + + def __len__(self): + return len(self.branches) + + def __contains__(self, name): + return name in self.branches diff --git a/merlin/models/torch/container.py b/merlin/models/torch/container.py index 458183603b..eca39043bb 100644 --- a/merlin/models/torch/container.py +++ b/merlin/models/torch/container.py @@ -14,7 +14,9 @@ # limitations under the License. # -from typing import Iterator, Optional, Union +from copy import deepcopy +from functools import reduce +from typing import Dict, Iterator, Optional, Union from torch import nn from torch._jit_internal import _copy_to_script_wrapper @@ -149,3 +151,132 @@ def __repr__(self) -> str: def _get_name(self) -> str: return super()._get_name() if self._name is None else self._name + + +class BlockContainerDict(nn.ModuleDict): + """A container class for PyTorch `nn.Module` that allows for manipulation and traversal + of multiple sub-modules as if they were a dictionary. The modules are automatically wrapped + in a TorchScriptWrapper for TorchScript compatibility. + + Parameters + ---------- + *inputs : nn.Module + One or more PyTorch modules to be added to the container. + name : Optional[str] + An optional name for the BlockContainer. + """ + + def __init__( + self, *inputs: Union[nn.Module, Dict[str, nn.Module]], name: Optional[str] = None + ) -> None: + if not inputs: + inputs = [{}] + + if all(isinstance(x, dict) for x in inputs): + modules = reduce(lambda a, b: dict(a, **b), inputs) # type: ignore + + super().__init__(modules) + self._name: str = name + + def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict": + """Appends a module to a specified name. + + Parameters + ---------- + name : str + The name of the branch. + module : nn.Module + The module to append. + + Returns + ------- + BlockContainerDict + The current object itself. + """ + + self._modules[name].append(module) + + return self + + def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict": + """Prepends a module to a specified name. + + Parameters + ---------- + name : str + The name of the branch. + module : nn.Module + The module to prepend. + + Returns + ------- + BlockContainerDict + The current object itself. + """ + + self._modules[name].prepend(module) + + return self + + # Append to all branches, optionally copying + def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict": + """Appends a module to each branch. + + Parameters + ---------- + module : nn.Module + The module to append to each branch. + shared : bool, default=False + If True, the same module is shared across all elements. + Otherwise a deep copy of the module is used in each element. + + Returns + ------- + BlockContainerDict + The current object itself. + """ + + for branch in self.values(): + _module = module if shared else deepcopy(module) + branch.append(_module) + + return self + + def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict": + """Prepends a module to each branch. + + Parameters + ---------- + module : nn.Module + The module to prepend to each branch. 
+        shared : bool, default=False
+            If True, the same module is shared across all elements.
+            Otherwise a deep copy of the module is used in each element.
+
+        Returns
+        -------
+        BlockContainerDict
+            The current object itself.
+        """
+
+        for branch in self.values():
+            _module = module if shared else deepcopy(module)
+            branch.prepend(_module)
+
+        return self
+
+    # This assumes same branches, we append its content to each branch
+    # def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
+    #     ...
+
+    def add_module(self, name: str, module: Optional[nn.Module]) -> None:
+        if module and not isinstance(module, BlockContainer):
+            module = BlockContainer(module, name=name[0].upper() + name[1:])
+
+        return super().add_module(name, module)
+
+    def unwrap(self) -> Dict[str, nn.ModuleList]:
+        return {k: v.unwrap() for k, v in self.items()}
+
+    def _get_name(self) -> str:
+        return super()._get_name() if self._name is None else self._name
diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py
index affeed200e..63ac6d36c5 100644
--- a/tests/unit/torch/test_block.py
+++ b/tests/unit/torch/test_block.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Dict, Tuple
 
 import pytest
 import torch
 from torch import nn
 
 from merlin.models.torch.batch import Batch
-from merlin.models.torch.block import Block
+from merlin.models.torch.block import Block, ParallelBlock
+from merlin.models.torch.container import BlockContainer, BlockContainerDict
 from merlin.models.torch.utils import module_utils
 
 
@@ -28,6 +30,16 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         return inputs + 1
 
 
+class PlusOneDict(nn.Module):
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        return {k: v + 1 for k, v in inputs.items()}
+
+
+class PlusOneTuple(nn.Module):
+    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        return inputs + 1, inputs + 1
+
+
 class TestBlock:
     def test_identity(self):
         block = Block()
@@ -69,3 +81,111 @@ def test_repeat(self):
 
         with pytest.raises(ValueError, match="n must be greater than 0"):
             block.repeat(0)
+
+
+class TestParallelBlock:
+    def test_init(self):
+        pb = ParallelBlock({"test": PlusOne()})
+        assert isinstance(pb, ParallelBlock)
+        assert isinstance(pb.pre, BlockContainer)
+        assert isinstance(pb.branches, BlockContainerDict)
+        assert isinstance(pb.post, BlockContainer)
+
+    def test_init_list_of_dict(self):
+        pb = ParallelBlock(({"test": PlusOne()}))
+        assert len(pb) == 1
+        assert "test" in pb
+
+    def test_forward(self):
+        inputs = torch.randn(1, 3)
+        pb = ParallelBlock({"test": PlusOne()})
+        outputs = module_utils.module_test(pb, inputs)
+        assert isinstance(outputs, dict)
+        assert "test" in outputs
+
+    def test_forward_dict(self):
+        inputs = {"a": torch.randn(1, 3)}
+        pb = ParallelBlock({"test": PlusOneDict()})
+        outputs = module_utils.module_test(pb, inputs)
+        assert isinstance(outputs, dict)
+        assert "a" in outputs
+
+    def test_forward_dict_duplicate(self):
+        inputs = {"a": torch.randn(1, 3)}
+        pb = ParallelBlock({"1": PlusOneDict(), "2": PlusOneDict()})
+
+        with pytest.raises(RuntimeError):
+            pb(inputs)
+
+    def test_forward_tuple(self):
+        inputs = torch.randn(1, 3)
+        pb = ParallelBlock({"test": PlusOneTuple()})
+        with pytest.raises(RuntimeError):
+            module_utils.module_test(pb, inputs)
+
+    def test_append(self):
+        module = PlusOneDict()
+        pb = ParallelBlock({"test": PlusOne()})
+        pb.append(module)
+        assert len(pb.post._modules) == 1
+
+        assert pb[-1][0] == module
+        assert pb[2][0] == module
+
+        module_utils.module_test(pb, torch.randn(1, 3))
+
+    def test_prepend(self):
+        module = PlusOne()
+        pb = ParallelBlock({"test": module})
+        pb.prepend(module)
+        assert len(pb.pre._modules) == 1
+
+        assert pb[0][0] == module
+
+        module_utils.module_test(pb, torch.randn(1, 3))
+
+    def test_append_to(self):
+        module = nn.Module()
+        pb = ParallelBlock({"test": module})
+        pb.append_to("test", module)
+        assert len(pb["test"]) == 2
+
+    def test_prepend_to(self):
+        module = nn.Module()
+        pb = ParallelBlock({"test": module})
+        pb.prepend_to("test", module)
+        assert len(pb["test"]) == 2
+
+    def test_append_for_each(self):
+        module = nn.Module()
+        pb = ParallelBlock({"a": module, "b": module})
+        pb.append_for_each(module)
+        assert len(pb["a"]) == 2
+        assert len(pb["b"]) == 2
+        assert pb["a"][-1] != pb["b"][-1]
+
+        pb.append_for_each(module, shared=True)
+        assert len(pb["a"]) == 3
+        assert len(pb["b"]) == 3
+        assert pb["a"][-1] == pb["b"][-1]
+
+    def test_prepend_for_each(self):
+        module = nn.Module()
+        pb = ParallelBlock({"a": module, "b": module})
+        pb.prepend_for_each(module)
+        assert len(pb["a"]) == 2
+        assert len(pb["b"]) == 2
+        assert pb["a"][0] != pb["b"][0]
+
+        pb.prepend_for_each(module, shared=True)
+        assert len(pb["a"]) == 3
+        assert len(pb["b"]) == 3
+        assert pb["a"][0] == pb["b"][0]
+
+    def test_getitem(self):
+        module = nn.Module()
+        pb = ParallelBlock({"test": module})
+        assert isinstance(pb["test"], BlockContainer)
+
+        with pytest.raises(IndexError):
+            pb["invalid_key"]
diff --git a/tests/unit/torch/test_container.py b/tests/unit/torch/test_container.py
index c005c0bfb1..ddf2716a6d 100644
--- a/tests/unit/torch/test_container.py
+++ b/tests/unit/torch/test_container.py
@@ -17,7 +17,7 @@
 import pytest
 import torch.nn as nn
 
-from merlin.models.torch.container import BlockContainer
+from merlin.models.torch.container import BlockContainer, BlockContainerDict
 from merlin.models.torch.utils import torchscript_utils
 
 
@@ -129,3 +129,59 @@ def test_add_module(self):
 
     def test_get_name(self):
         assert self.block_container._get_name() == "test_container"
+
+
+class TestBlockContainerDict:
+    def setup_method(self):
+        self.module = nn.Module()
+        self.container = BlockContainerDict({"test": self.module}, name="test")
+        self.block_container = BlockContainer(name="test_container")
+
+    def test_init(self):
+        assert isinstance(self.container, BlockContainerDict)
+        assert self.container._get_name() == "test"
+        assert isinstance(self.container.unwrap()["test"], nn.ModuleList)
+
+    def test_empty(self):
+        container = BlockContainerDict()
+        assert len(container) == 0
+
+    def test_not_module(self):
+        with pytest.raises(ValueError):
+            BlockContainerDict({"test": "not a module"})
+
+    def test_append_to(self):
+        self.container.append_to("test", self.module)
+        assert "test" in self.container._modules
+
+    def test_prepend_to(self):
+        self.container.prepend_to("test", self.module)
+        assert "test" in self.container._modules
+
+    def test_append_for_each(self):
+        container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()})
+
+        to_add = nn.Module()
+        container.append_for_each(to_add)
+        assert len(container["a"]) == 2
+        assert len(container["b"]) == 2
+        assert container["a"][-1] != container["b"][-1]
+
+        container.append_for_each(to_add, shared=True)
+        assert len(container["a"]) == 3
+        assert len(container["b"]) == 3
+        assert container["a"][-1] == container["b"][-1]
+
+    def test_prepend_for_each(self):
+        container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()})
BlockContainerDict({"a": nn.Module(), "b": nn.Module()}) + + to_add = nn.Module() + container.prepend_for_each(to_add) + assert len(container["a"]) == 2 + assert len(container["b"]) == 2 + assert container["a"][0] != container["b"][0] + + container.prepend_for_each(to_add, shared=True) + assert len(container["a"]) == 3 + assert len(container["b"]) == 3 + assert container["a"][0] == container["b"][0]