[Feature] Implement ensemble reduce #1360

Open
wants to merge 3 commits into base: main
84 changes: 82 additions & 2 deletions test/test_tensordictmodules.py
@@ -16,17 +16,27 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import LSTMModule, NormalParamWrapper, SafeModule, TanhNormal
from torchrl.modules import (
EnsembleModule,
Contributor: from tensordict
LSTMModule,
NormalParamWrapper,
SafeModule,
TanhNormal,
)
from torchrl.modules.tensordict_module.common import (
ensure_tensordict_compatible,
is_tensordict_compatible,
VmapModule,
)
from torchrl.modules.tensordict_module.ensemble import Reduce
from torchrl.modules.tensordict_module.probabilistic import (
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from torchrl.modules.tensordict_module.sequence import SafeSequential
from torchrl.modules.tensordict_module.sequence import (
SafeSequential,
TensorDictSequential,
)


_has_functorch = False
@@ -41,6 +51,76 @@
pass


class TestEnsembleReduce:
def test_reduce(self):
module = TensorDictModule(nn.Linear(2, 3), in_keys=["bork"], out_keys=["dork"])
m0 = EnsembleModule(module, num_copies=2)
m1 = Reduce(["dork"], ["spork"])
seq = TensorDictSequential(m0, m1)
td = TensorDict({"bork": torch.randn(5, 2)}, batch_size=[5])
out = seq(td)
assert "spork" in out.keys()
assert out.shape == (5,)


class TestEnsembleModule:
def test_init(self):
"""Ensure that we correctly initialize copied weights s.t. they are not identical
to the original weights."""
torch.manual_seed(0)
module = TensorDictModule(
nn.Sequential(
nn.Linear(2, 3),
nn.ReLU(),
nn.Linear(3, 1),
),
in_keys=["a"],
out_keys=["b"],
)
mod = EnsembleModule(module, num_copies=2)
for param in mod.params:
p0, p1 = param.unbind(0)
assert not torch.allclose(
p0, p1
), f"Ensemble params were not initialized correctly {p0}, {p1}"

def test_siso_forward(self):
"""Ensure that forward works for a single input and output"""
module = TensorDictModule(
nn.Sequential(
nn.Linear(2, 3),
nn.ReLU(),
),
in_keys=["bork"],
out_keys=["dork"],
)
mod = EnsembleModule(module, num_copies=2)
td = TensorDict({"bork": torch.randn(5, 2)}, batch_size=[5])
out = mod(td)
assert "dork" in out.keys(), "Ensemble forward failed to write keys"
assert out["dork"].shape == torch.Size(
[2, 5, 3]
), "Ensemble forward failed to expand input"
outs = out["dork"].unbind(0)
assert not torch.allclose(outs[0], outs[1]), "Outputs should be different"

def test_chained_ensembles(self):
"""Ensure that the expand_input argument works"""
module = TensorDictModule(nn.Linear(2, 3), in_keys=["bork"], out_keys=["dork"])
next_module = TensorDictModule(
nn.Linear(3, 1), in_keys=["dork"], out_keys=["spork"]
)
e0 = EnsembleModule(module, num_copies=4, expand_input=True)
e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
seq = TensorDictSequential(e0, e1)
td = TensorDict({"bork": torch.randn(5, 2)}, batch_size=[5])
out = seq(td)
assert "spork" in out.keys(), "Ensemble forward failed to write keys"
assert out["spork"].shape == torch.Size(
[4, 5, 1]
), "Ensemble forward failed to expand input"


class TestTDModule:
def test_multiple_output(self):
class MultiHeadLinear(nn.Module):
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
@@ -45,13 +45,15 @@
DistributionalQValueHook,
DistributionalQValueModule,
EGreedyWrapper,
EnsembleModule,
LMHeadActorValueOperator,
LSTMModule,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
QValueActor,
QValueHook,
QValueModule,
Reduce,
SafeModule,
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
@@ -20,6 +20,7 @@
ValueOperator,
)
from .common import SafeModule, VmapModule
from .ensemble import EnsembleModule, Reduce
from .exploration import (
AdditiveGaussianWrapper,
EGreedyWrapper,
137 changes: 137 additions & 0 deletions torchrl/modules/tensordict_module/ensemble.py
@@ -0,0 +1,137 @@
from typing import Callable

import torch
from tensordict import TensorDict
from tensordict.nn import make_functional, TensorDictModuleBase
from torch import nn


class Reduce(TensorDictModuleBase):
"""A reduction operator that reduces across the ensemble dimension.

Args:
in_keys (list[str]): The input keys to reduce; must have length one.
out_keys (list[str]): The output keys to write the reduced value to; must have length one.
reduce_function (Callable): The function to use to reduce across the ensemble dimension.

Examples:
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torchrl.modules import EnsembleModule, Reduce
>>> from tensordict import TensorDict
>>> module = TensorDictModule(nn.Linear(2, 3), in_keys=["bork"], out_keys=["dork"])
>>> m0 = EnsembleModule(module, num_copies=2)
>>> m1 = Reduce(["dork"], ["spork"])
>>> seq = TensorDictSequential(m0, m1)
>>> td = TensorDict({"bork": torch.randn(5, 2)}, batch_size=[5])
>>> seq(td)
TensorDict(
fields={
bork: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
dork: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
spork: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
>>> m0(td).shape
torch.Size([2, 5])
>>> seq(td).shape
torch.Size([5])
"""
def __init__(
self,
in_keys: list[str],
out_keys: list[str],
reduce_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.min(dim=0).values,
Contributor: The best way to approach this is

Suggested change
reduce_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.min(dim=0).values,
reduce_function: Callable[[torch.Tensor], torch.Tensor] = None,

then

if reduce_function is None:
    reduce_function = lambda x: x.min(dim=0).values

But do we really want to have "min" as a default?

Contributor: What about Enum to select which op?

REDUCE_OP.min
REDUCE_OP.max
REDUCE_OP.sum
REDUCE_OP.any
REDUCE_OP.all

And then we can ask the user along which dim the op should be done.
That dim would also be the dim along which we do the tensordict indexing.
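
A minimal sketch of the Enum-based selection suggested above; the REDUCE_OP class, the dim argument, and the reduce_tensor helper are hypothetical illustrations, not part of this PR:

import enum

import torch


class REDUCE_OP(enum.Enum):
    # Hypothetical sketch of the reviewer's suggestion: each member names a torch reduction.
    min = "min"
    max = "max"
    sum = "sum"
    any = "any"
    all = "all"


def reduce_tensor(x: torch.Tensor, op: REDUCE_OP, dim: int = 0) -> torch.Tensor:
    # Apply the selected reduction along `dim` (the ensemble dimension by default).
    if op in (REDUCE_OP.min, REDUCE_OP.max):
        # torch.min/torch.max with a dim argument return (values, indices); keep the values.
        return getattr(torch, op.value)(x, dim=dim).values
    return getattr(torch, op.value)(x, dim=dim)


# Example: reduce a (num_copies, batch, features) tensor over the ensemble dim.
values = reduce_tensor(torch.randn(2, 5, 3), REDUCE_OP.min, dim=0)  # shape (5, 3)
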
):
super().__init__()
self.in_keys = in_keys
self.out_keys = out_keys
self.reduce_function = reduce_function
assert (
len(in_keys) == len(out_keys) == 1
), "Reduce only supports one input and one output"
Contributor (on lines +52 to +54): This should be a ValueError

def forward(self, tensordict):
reduced = self.reduce_function(tensordict.get(self.in_keys[0]))
# We assume that all inputs are identical across the ensemble dim
# except for the input/output keys
tensordict_reduced = tensordict[0]
tensordict_reduced.set(self.out_keys[0], reduced)
return tensordict_reduced


class EnsembleModule(TensorDictModuleBase):
Contributor: duplicate with tensordict
"""Module that wraps a module and repeats it to form an ensemble.

Args:
module (nn.Module): The nn.Module to duplicate and wrap.
num_copies (int): The number of copies of module to make.
expand_input (bool): Whether to expand the input TensorDict to match the number of copies. This should be
True unless you are chaining ensemble modules together, e.g. EnsembleModule(cnn) -> EnsembleModule(mlp).
If False, EnsembleModule(mlp) expects the previous module(s) to have already expanded the input.

Examples:
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 10]),
device=None,
is_shared=False)

>>> import torch
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> module = TensorDictModule(torch.nn.Linear(2,3), in_keys=['bork'], out_keys=['dork'])
>>> next_module = TensorDictModule(torch.nn.Linear(3,1), in_keys=['dork'], out_keys=['spork'])
>>> e0 = EnsembleModule(module, num_copies=4, expand_input=True)
>>> e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
>>> seq = TensorDictSequential(e0, e1)
>>> data = TensorDict({'bork': torch.randn(5,2)}, batch_size=[5])
>>> seq(data)
TensorDict(
fields={
bork: Tensor(shape=torch.Size([4, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
dork: Tensor(shape=torch.Size([4, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
spork: Tensor(shape=torch.Size([4, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([4, 5]),
device=None,
is_shared=False)
"""

def __init__(
self,
module: TensorDictModuleBase,
num_copies: int,
expand_input: bool = True,
):
super().__init__()
self.in_keys = module.in_keys
self.out_keys = module.out_keys
self.module = module
params_td = make_functional(module).expand(num_copies).to_tensordict()
module.reset_parameters(params_td)

self.params_td = params_td
self.params = nn.ParameterList(list(self.params_td.values(True, True)))
if expand_input:
self.vmapped_forward = torch.vmap(self.module, (None, 0))
else:
self.vmapped_forward = torch.vmap(self.module, 0)

def forward(self, tensordict: TensorDict) -> TensorDict:
return self.vmapped_forward(tensordict, self.params_td)