[Feature] Dispatch for A2C loss module #1223

Merged · 3 commits · Jun 4, 2023
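
In short, this patch decorates `A2CLoss.forward` with `tensordict.nn.dispatch`, adds `in_keys`/`out_keys` properties, and makes the `reward`/`done` keys configurable, so the loss can now be called with plain tensors instead of a TensorDict. Below is a minimal sketch of the resulting call; the module setup simply mirrors the docstring example added in this diff and is not a new API:

```python
# Minimal sketch of the new non-tensordict entry point; the setup mirrors the
# docstring example added in torchrl/objectives/a2c.py below.
import torch
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.a2c import A2CLoss

n_act, n_obs = 4, 3
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
actor = ProbabilisticActor(
    module=SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)
value = ValueOperator(module=nn.Linear(n_obs, 1), in_keys=["observation"])
loss = A2CLoss(actor, value, loss_critic_type="l2")

batch = [2]
# @dispatch maps keyword arguments onto the loss's in_keys; "next_*" names are
# routed to the ("next", *) nested entries of the internal tensordict.
loss_objective, loss_critic, entropy, loss_entropy = loss(
    observation=torch.randn(*batch, n_obs),
    action=spec.rand(batch),
    next_done=torch.zeros(*batch, 1, dtype=torch.bool),
    next_reward=torch.randn(*batch, 1),
)
```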
85 changes: 76 additions & 9 deletions test/test_cost.py
@@ -3213,13 +3213,20 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
class TestA2C(LossModuleTestBase):
seed = 0

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_actor(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
observation_key="observation",
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
module = SafeModule(net, in_keys=[observation_key], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
in_keys=["loc", "scale"],
@@ -3229,12 +3236,18 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
return actor.to(device)

def _create_mock_value(
self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
out_keys=None,
observation_key="observation",
):
module = nn.Linear(obs_dim, 1)
value = ValueOperator(
module=module,
in_keys=["observation"],
in_keys=[observation_key],
out_keys=out_keys,
)
return value.to(device)
@@ -3248,6 +3261,9 @@ def _create_seq_mock_data_a2c(
atoms=None,
device="cpu",
action_key="action",
observation_key="observation",
reward_key="reward",
done_key="done",
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -3267,11 +3283,11 @@
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
observation_key: obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"done": done,
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
observation_key: next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
done_key: done,
reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0),
@@ -3443,6 +3459,8 @@ def test_a2c_tensordict_keys(self, td_est):
"value_target": "value_target",
"value": "state_value",
"action": "action",
"reward": "reward",
"done": "done",
}

self.tensordict_keys_test(
@@ -3459,6 +3477,8 @@ def test_a2c_tensordict_keys(self, td_est):
"advantage": ("advantage", "advantage_test"),
"value_target": ("value_target", "value_target_test"),
"value": ("value", "value_state_test"),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@@ -3471,8 +3491,15 @@ def test_a2c_tensordict_keys_run(self, device):
value_target_key = "value_target_test"
value_key = "state_value_test"
action_key = "action_test"
reward_key = "reward_test"
done_key = ("done", "test")

td = self._create_seq_mock_data_a2c(device=device, action_key=action_key)
td = self._create_seq_mock_data_a2c(
device=device,
action_key=action_key,
reward_key=reward_key,
done_key=done_key,
)

actor = self._create_mock_actor(device=device)
value = self._create_mock_value(device=device, out_keys=[value_key])
@@ -3486,13 +3513,17 @@ def test_a2c_tensordict_keys_run(self, device):
advantage=advantage_key,
value_target=value_target_key,
value=value_key,
reward=reward_key,
done=done_key,
)
loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn.set_keys(
advantage=advantage_key,
value_target=value_target_key,
value=value_key,
action=action_key,
reward=reward_key,
done=done_key,
)

advantage(td)
@@ -3525,6 +3556,42 @@ def test_a2c_tensordict_keys_run(self, device):
# test reset
loss_fn.reset()

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_key):
torch.manual_seed(self.seed)

actor = self._create_mock_actor(observation_key=observation_key)
value = self._create_mock_value(observation_key=observation_key)
td = self._create_seq_mock_data_a2c(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
)

loss = A2CLoss(actor, value)
loss.set_keys(action=action_key, reward=reward_key, done=done_key)

kwargs = {
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": td.get(("next", done_key)),
action_key: td.get(action_key),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

loss_val = loss(**kwargs)
loss_val_td = loss(td)

torch.testing.assert_close(loss_val_td.get("loss_objective"), loss_val[0])
torch.testing.assert_close(loss_val_td.get("loss_critic"), loss_val[1])
# don't test entropy and loss_entropy, since they depend on a random sample
# from distribution
assert len(loss_val) == 4


class TestReinforce(LossModuleTestBase):
@pytest.mark.parametrize("delay_value", [True, False])
124 changes: 120 additions & 4 deletions torchrl/objectives/a2c.py
@@ -7,7 +7,7 @@
from typing import Tuple

import torch
from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import distributions as d
@@ -67,6 +67,88 @@ class A2CLoss(LossModule):
The default is :class:`~torchrl.objectives.value.GAE` with hyperparameters
dictated by :func:`~torchrl.objectives.utils.default_value_kwargs`.

Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.a2c import A2CLoss
>>> from tensordict.tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> loss = A2CLoss(actor, value, loss_critic_type="l2")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... }, batch)
>>> loss(data)
TensorDict(
fields={
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)

This class is compatible with non-tensordict based modules too, and can be
used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done"]`` + the in_keys of the actor and critic.
The return value is a tuple of tensors in the following order:
``["loss_objective"]``
+ ``["loss_critic"]`` if critic_coef is not None
+ ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None

Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.a2c import A2CLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> loss = A2CLoss(actor, value, loss_critic_type="l2")
>>> batch = [2, ]
>>> loss_val = loss(
... observation = torch.randn(*batch, n_obs),
... action = spec.rand(batch),
... next_done = torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward = torch.randn(*batch, 1))
>>> loss_val
(tensor(1.7593, grad_fn=<MeanBackward0>), tensor(0.2344, grad_fn=<MeanBackward0>), tensor(1.5480), tensor(-0.0155, grad_fn=<MulBackward0>))
"""

@dataclass
@@ -85,12 +167,19 @@ class _AcceptedKeys:
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
action: NestedKey = "action"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.GAE
@@ -141,9 +230,11 @@ def __init__(
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self._tensor_keys.advantage,
value_target=self._tensor_keys.value_target,
value=self._tensor_keys.value,
advantage=self.tensor_keys.advantage,
value_target=self.tensor_keys.value_target,
value=self.tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)

def reset(self) -> None:
@@ -198,6 +289,29 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
)
return self.critic_coef * loss_value

@property
def in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
]
keys.extend(self.actor.in_keys)
if self.critic_coef:
keys.extend(self.critic.in_keys)
return list(set(keys))

@property
def out_keys(self):
outs = ["loss_objective"]
if self.critic_coef:
outs.append("loss_critic")
if self.entropy_bonus:
outs.append("entropy")
outs.append("loss_entropy")
return outs

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
@@ -243,5 +357,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
"advantage": self.tensor_keys.advantage,
"value": self.tensor_keys.value,
"value_target": self.tensor_keys.value_target,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)
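
For reference, the new `reward` and `done` entries in `_AcceptedKeys` can be remapped through `set_keys` like the existing ones, and the renamed keys are forwarded to the value estimator (see `_forward_value_estimator_keys` and the `make_value_estimator` hunk above). A short sketch of that usage, under the assumptions stated in the comments:

```python
# Sketch, reusing `actor` and `value` from the example above; `td` is assumed
# to be a tensordict whose ("next", ...) entries use the renamed keys
# "reward_test" and ("done", "test").
from torchrl.objectives.a2c import A2CLoss

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn.set_keys(reward="reward_test", done=("done", "test"))
# make_value_estimator forwards the configured tensor keys (now including
# reward and done) to the value estimator, as the last hunk above shows.
loss_fn.make_value_estimator()  # value_type=None falls back to the default (GAE)
loss_vals = loss_fn(td)
```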