Adding ParallelBlock
marcromeyn committed May 12, 2023
1 parent cbf7e95 commit 77ca69b
Showing 5 changed files with 221 additions and 11 deletions.
4 changes: 2 additions & 2 deletions merlin/models/torch/__init__.py
@@ -15,6 +15,6 @@
#

from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block
from merlin.models.torch.block import Block, ParallelBlock

__all__ = ["Batch", "Block", "Sequence"]
__all__ = ["Batch", "Block", "ParallelBlock", "Sequence"]
99 changes: 98 additions & 1 deletion merlin/models/torch/block.py
Expand Up @@ -21,7 +21,7 @@
from torch import nn

from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer
from merlin.models.torch.container import BlockContainer, BlockContainerDict


class Block(BlockContainer):
@@ -98,3 +98,100 @@ def copy(self) -> "Block":
The copy of the current block.
"""
return deepcopy(self)


class ParallelBlock(Block):
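    """Runs multiple named branches in parallel on the same input.

    Modules in pre are applied to the input first, every branch in branches
    then receives that result, and the per-branch outputs are merged into a
    single dictionary keyed by branch name (duplicate keys raise a
    RuntimeError). Modules in post are applied to the merged dictionary.
    """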
def __init__(
self,
*inputs: Union[nn.Module, Dict[str, nn.Module]],
# TODO: Add agg
):
pre = BlockContainer(name="pre")
branches = BlockContainerDict(*inputs)
post = BlockContainer(name="post")

super().__init__()

self.pre = pre
self.branches = branches
self.post = post

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
):
for module in self.pre.values:
inputs = module(inputs, batch=batch)

outputs = {}
for name, branch_container in self.branches.items():
branch = inputs
for module in branch_container.values:
branch = module(branch, batch=batch)

if isinstance(branch, torch.Tensor):
branch_dict = {name: branch}
elif torch.jit.isinstance(branch, Dict[str, torch.Tensor]):
branch_dict = branch
else:
raise TypeError(
f"Branch output must be a tensor or a dictionary of tensors. Got {type(branch)}"
)

for key in branch_dict.keys():
if key in outputs:
raise RuntimeError(f"Duplicate output name: {key}")

outputs.update(branch_dict)

for module in self.post.values:
outputs = module(outputs, batch=batch)

return outputs

def append(self, module: nn.Module):
self.post.append(module)

return self

def prepend(self, module: nn.Module):
self.pre.prepend(module)

return self

def append_to(self, name: str, module: nn.Module):
self.branches[name].append(module)

return self

def prepend_to(self, name: str, module: nn.Module):
self.branches[name].prepend(module)

return self

def append_for_each(self, module: nn.Module, shared=False):
self.branches.append_for_each(module, shared=shared)

return self

def prepend_for_each(self, module: nn.Module, shared=False):
self.branches.prepend_for_each(module, shared=shared)

return self

    def __getitem__(self, idx: Union[str, int]):
if isinstance(idx, str) and idx in self.branches:
return self.branches[idx]

if idx == 0:
return self.pre

if idx == -1 or idx == 2:
return self.post

raise IndexError(f"Index {idx} is out of range for {self.__class__.__name__}")

def __len__(self):
return len(self.branches)

def __contains__(self, name):
return name in self.branches
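
For orientation, here is a minimal usage sketch of the new ParallelBlock, reusing the PlusOne helper defined in the unit tests further down; the shapes and branch names are illustrative and not part of this diff:

import torch
from torch import nn

from merlin.models.torch import ParallelBlock


class PlusOne(nn.Module):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return inputs + 1


# Every named branch receives the same input; tensor outputs are collected
# into a dict keyed by branch name.
pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
outputs = pb(torch.randn(2, 3))  # {"a": tensor(...), "b": tensor(...)}

# Pre/post hooks and per-branch edits, mirroring the unit tests below.
pb.prepend(PlusOne())          # runs on the raw input before any branch
pb.append_for_each(PlusOne())  # adds an independent copy to every branch
outputs = pb(torch.randn(2, 3))

Branches that return dictionaries are merged key by key, which is why duplicate output names raise a RuntimeError (see test_forward_dict_duplicate below).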
2 changes: 0 additions & 2 deletions merlin/models/torch/container.py
@@ -160,8 +160,6 @@ def __init__(
if not inputs:
inputs = [{}]

if isinstance(inputs, tuple) and len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
modules = inputs[0]
if all(isinstance(x, dict) for x in inputs):
modules = reduce(lambda a, b: dict(a, **b), inputs) # type: ignore

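The removed branch above means a single list or tuple of dicts is no longer unwrapped by BlockContainerDict; multiple dict arguments are still merged. A small sketch of the surviving behaviour, based on the code shown above (illustrative only):

from torch import nn

from merlin.models.torch.container import BlockContainerDict

# Several dicts passed as separate arguments are merged into one mapping
# of named branches.
branches = BlockContainerDict({"a": nn.Identity()}, {"b": nn.Identity()})
assert "a" in branches and "b" in branches

# A single wrapped list, e.g. BlockContainerDict([{"a": nn.Identity()}]),
# is no longer special-cased after this commit.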
122 changes: 121 additions & 1 deletion tests/unit/torch/test_block.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Tuple

import pytest
import torch
from torch import nn

from merlin.models.torch.batch import Batch
from merlin.models.torch.block import Block
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.utils import module_utils


@@ -28,6 +30,16 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return inputs + 1


class PlusOneDict(nn.Module):
def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {k: v + 1 for k, v in inputs.items()}


class PlusOneTuple(nn.Module):
def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return inputs + 1, inputs + 1


class TestBlock:
def test_identity(self):
block = Block()
@@ -69,3 +81,111 @@ def test_repeat(self):

with pytest.raises(ValueError, match="n must be greater than 0"):
block.repeat(0)


class TestParallelBlock:
def test_init(self):
pb = ParallelBlock({"test": PlusOne()})
assert isinstance(pb, ParallelBlock)
assert isinstance(pb.pre, BlockContainer)
assert isinstance(pb.branches, BlockContainerDict)
assert isinstance(pb.post, BlockContainer)

def test_init_list_of_dict(self):
pb = ParallelBlock(({"test": PlusOne()}))
assert len(pb) == 1
assert "test" in pb

def test_forward(self):
inputs = torch.randn(1, 3)
pb = ParallelBlock({"test": PlusOne()})
outputs = module_utils.module_test(pb, inputs)
assert isinstance(outputs, dict)
assert "test" in outputs

def test_forward_dict(self):
inputs = {"a": torch.randn(1, 3)}
pb = ParallelBlock({"test": PlusOneDict()})
outputs = module_utils.module_test(pb, inputs)
assert isinstance(outputs, dict)
assert "a" in outputs

def test_forward_dict_duplicate(self):
inputs = {"a": torch.randn(1, 3)}
pb = ParallelBlock({"1": PlusOneDict(), "2": PlusOneDict()})

with pytest.raises(RuntimeError):
pb(inputs)

def test_forward_tuple(self):
inputs = torch.randn(1, 3)
pb = ParallelBlock({"test": PlusOneTuple()})
with pytest.raises(RuntimeError):
module_utils.module_test(pb, inputs)

def test_append(self):
module = PlusOneDict()
pb = ParallelBlock({"test": PlusOne()})
pb.append(module)
assert len(pb.post._modules) == 1

assert pb[-1][0] == module
assert pb[2][0] == module

module_utils.module_test(pb, torch.randn(1, 3))

def test_prepend(self):
module = PlusOne()
pb = ParallelBlock({"test": module})
pb.prepend(module)
assert len(pb.pre._modules) == 1

assert pb[0][0] == module

module_utils.module_test(pb, torch.randn(1, 3))

def test_append_to(self):
module = nn.Module()
pb = ParallelBlock({"test": module})
pb.append_to("test", module)
assert len(pb["test"]) == 2

def test_prepend_to(self):
module = nn.Module()
pb = ParallelBlock({"test": module})
pb.prepend_to("test", module)
assert len(pb["test"]) == 2

def test_append_for_each(self):
module = nn.Module()
pb = ParallelBlock({"a": module, "b": module})
pb.append_for_each(module)
assert len(pb["a"]) == 2
assert len(pb["b"]) == 2
assert pb["a"][-1] != pb["b"][-1]

pb.append_for_each(module, shared=True)
assert len(pb["a"]) == 3
assert len(pb["b"]) == 3
assert pb["a"][-1] == pb["b"][-1]

def test_prepend_for_each(self):
module = nn.Module()
pb = ParallelBlock({"a": module, "b": module})
pb.prepend_for_each(module)
assert len(pb["a"]) == 2
assert len(pb["b"]) == 2
assert pb["a"][0] != pb["b"][0]

pb.prepend_for_each(module, shared=True)
assert len(pb["a"]) == 3
assert len(pb["b"]) == 3
assert pb["a"][0] == pb["b"][0]

def test_getitem(self):
module = nn.Module()
pb = ParallelBlock({"test": module})
assert isinstance(pb["test"], BlockContainer)

with pytest.raises(IndexError):
pb["invalid_key"]
5 changes: 0 additions & 5 deletions tests/unit/torch/test_container.py
@@ -146,11 +146,6 @@ def test_empty(self):
container = BlockContainerDict()
assert len(container) == 0

def test_list_of_dict(self):
container = BlockContainerDict(({"test": self.module}))
assert len(container) == 1
assert "test" in container

def test_not_module(self):
with pytest.raises(ValueError):
BlockContainerDict({"test": "not a module"})
