Commit

[Feature] Dispatch for DDPG loss module (#1215)
Blonck authored Jun 4, 2023
1 parent e955cfc commit 331f677
Showing 2 changed files with 178 additions and 21 deletions.
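
For orientation, the calling pattern this commit enables is invoking DDPGLoss with plain keyword tensors instead of a TensorDict; the sketch below is condensed from the docstring example added in torchrl/objectives/ddpg.py further down (all names come from that example).

import torch
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
from torchrl.objectives.ddpg import DDPGLoss

n_act, n_obs = 4, 3
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))


class ValueClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))


value = ValueOperator(module=ValueClass(), in_keys=["observation", "action"])
loss = DDPGLoss(actor, value)

# Keyword-argument call enabled by the new @dispatch decorator: no TensorDict
# needed. The result is a tuple in the order (loss_actor, loss_value,
# pred_value, target_value, pred_value_max, target_value_max).
loss_actor, loss_value, *rest = loss(
    observation=torch.randn(n_obs),
    action=spec.rand(),
    next_done=torch.zeros(1, dtype=torch.bool),
    next_reward=torch.randn(1),
)
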
62 changes: 54 additions & 8 deletions test/test_cost.py
@@ -695,7 +695,14 @@ def _create_mock_distributional_actor(
raise NotImplementedError

def _create_mock_data_ddpg(
self, batch=8, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
batch=8,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
reward_key="reward",
done_key="done",
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -712,8 +719,8 @@ def _create_mock_data_ddpg(
"observation": obs,
"next": {
"observation": next_obs,
"done": done,
"reward": reward,
done_key: done,
reward_key: reward,
},
"action": action,
},
@@ -722,7 +729,15 @@ def _create_seq_mock_data_ddpg(
return td

def _create_seq_mock_data_ddpg(
self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
batch=8,
T=4,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
reward_key="reward",
done_key="done",
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -743,8 +758,8 @@ def _create_seq_mock_data_ddpg(
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"done": done,
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
done_key: done,
reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
@@ -901,6 +916,8 @@ def test_ddpg_tensordict_keys(self, td_est):
)

default_keys = {
"reward": "reward",
"done": "done",
"state_action_value": "state_action_value",
"priority": "td_error",
}
@@ -917,7 +934,11 @@ def test_ddpg_tensordict_keys(self, td_est):
value,
loss_function="l2",
)
key_mapping = {"state_action_value": ("value", "state_action_value_test")}
key_mapping = {
"state_action_value": ("value", "state_action_value_test"),
"reward": ("reward", "reward2"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize(
@@ -930,11 +951,15 @@ def test_ddpg_tensordict_run(self, td_est):
tensor_keys = {
"state_action_value": "state_action_value_test",
"priority": "td_error_test",
"reward": "reward_test",
"done": ("done", "test"),
}

actor = self._create_mock_actor()
value = self._create_mock_value(out_keys=[tensor_keys["state_action_value"]])
td = self._create_mock_data_ddpg()
td = self._create_mock_data_ddpg(
reward_key="reward_test", done_key=("done", "test")
)
loss_fn = DDPGLoss(
actor,
value,
@@ -948,6 +973,27 @@ def test_ddpg_tensordict_run(self, td_est):
with _check_td_steady(td):
_ = loss_fn(td)

def test_ddpg_notensordict(self):
torch.manual_seed(self.seed)
actor = self._create_mock_actor()
value = self._create_mock_value()
td = self._create_mock_data_ddpg()
loss = DDPGLoss(actor, value)
loss.make_value_estimator(ValueEstimators.TD1)

kwargs = {
"observation": td.get("observation"),
"next_reward": td.get(("next", "reward")),
"next_done": td.get(("next", "done")),
"action": td.get("action"),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

loss_val_td = loss(td)
loss_val = loss(**kwargs)
for i, key in enumerate(loss_val_td.keys()):
torch.testing.assert_close(loss_val_td.get(key), loss_val[i])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
137 changes: 124 additions & 13 deletions torchrl/objectives/ddpg.py
@@ -11,7 +11,7 @@
from typing import Tuple

import torch
from tensordict.nn import make_functional, repopulate_module, TensorDictModule
from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey

@@ -38,6 +38,84 @@ class DDPGLoss(LossModule):
data collection. Default is ``False``.
delay_value (bool, optional): whether to separate the target value networks from the value networks used for
data collection. Default is ``True``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.ddpg import DDPGLoss
>>> from tensordict.tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> value = ValueOperator(
... module=module,
... in_keys=["observation", "action"])
>>> loss = DDPGLoss(actor, value)
>>> batch = [2, ]
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": spec.rand(batch),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... }, batch)
>>> loss(data)
TensorDict(
fields={
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
    This class is also compatible with non-tensordict-based modules and can be
    used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["next_reward", "next_done"]`` + in_keys of the actor_network and value_network.
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]``
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.ddpg import DDPGLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> value = ValueOperator(
... module=module,
... in_keys=["observation", "action"])
>>> loss = DDPGLoss(actor, value)
>>> loss_val = loss(
... observation=torch.randn(n_obs),
... action=spec.rand(),
... next_done=torch.zeros(1, dtype=torch.bool),
... next_reward=torch.randn(1))
>>> loss_val
(tensor(-0.8247, grad_fn=<MeanBackward0>), tensor(1.3344, grad_fn=<MeanBackward0>), tensor(0.6193), tensor(1.7744), tensor(0.6193), tensor(1.7744))
"""

@dataclass
@@ -49,14 +127,22 @@ class _AcceptedKeys:
Attributes:
state_action_value (NestedKey): The input tensordict key where the
state action value is expected. Will be used for the underlying
value estimator as value key. Defaults to ``"state_action_value"``.
priority (NestedKey): The input tensordict key where the target
priority is written to. Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

state_action_value: NestedKey = "state_action_value"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.TD0
@@ -107,35 +193,56 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.state_action_value,
reward=self._tensor_keys.reward,
done=self._tensor_keys.done,
)

def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
@property
def in_keys(self):
keys = [
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
]
keys += self.value_network.in_keys
keys += self.actor_in_keys
keys = list(set(keys))
return keys

@dispatch(
dest=[
"loss_actor",
"loss_value",
"pred_value",
"target_value",
"pred_value_max",
"target_value_max",
]
)
def forward(self, tensordict: TensorDictBase) -> TensorDict:
"""Computes the DDPG losses given a tensordict sampled from the replay buffer.
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
a priority to items in the tensordict.
Args:
input_tensordict (TensorDictBase): a tensordict with keys ["done", "reward"] and the in_keys of the actor
tensordict (TensorDictBase): a tensordict with keys ["done", "reward"] and the in_keys of the actor
and value networks.
Returns:
a TensorDict containing the DDPG losses (actor and value) along with the predicted and target values.
"""
loss_value, td_error, pred_val, target_value = self._loss_value(
input_tensordict,
)
loss_value, td_error, pred_val, target_value = self._loss_value(tensordict)
td_error = td_error.detach()
td_error = td_error.unsqueeze(input_tensordict.ndimension())
if input_tensordict.device is not None:
td_error = td_error.to(input_tensordict.device)
input_tensordict.set(
td_error = td_error.unsqueeze(tensordict.ndimension())
if tensordict.device is not None:
td_error = td_error.to(tensordict.device)
tensordict.set(
self.tensor_keys.priority,
td_error,
inplace=True,
)
loss_actor = self._loss_actor(input_tensordict)
loss_actor = self._loss_actor(tensordict)
return TensorDict(
source={
"loss_actor": loss_actor.mean(),
@@ -220,5 +327,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
else:
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {"value": self.tensor_keys.state_action_value}
tensor_keys = {
"value": self.tensor_keys.state_action_value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)
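
The new reward and done entries of _AcceptedKeys can be remapped like the existing keys; a minimal sketch of overriding them is shown below. It assumes the standard LossModule.set_keys interface (not shown in this diff, but exercised by test_ddpg_tensordict_keys above), and reuses the key names from the tests.

# Read reward/done from custom (possibly nested) keys. The override is
# forwarded to the underlying value estimator via _forward_value_estimator_keys.
loss = DDPGLoss(actor, value)
loss.set_keys(reward="reward_test", done=("done", "test"))
# The tensordict passed to the loss must then carry
# ("next", "reward_test") and ("next", "done", "test").
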

1 comment on commit 331f677

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result, exceeding the threshold of 2.

| Benchmark suite | Current: 331f677 | Previous: e955cfc | Ratio |
| --- | --- | --- | --- |
| benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 167.6314631668876 iter/sec (stddev: 0.00027796489327458284) | 360.23891656436865 iter/sec (stddev: 0.00018627331744418822) | 2.15 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
