[Feature] Nested keys in OrnsteinUhlenbeckProcess #1305

Merged · 22 commits · Jul 6, 2023
Changes from 12 commits
140 changes: 116 additions & 24 deletions test/mocking_classes.py
@@ -959,7 +959,6 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
*self.batch_size,
1,
),
dtype=torch.int32,
device=self.device,
),
shape=self.batch_size,
@@ -1011,7 +1010,7 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get("action")
action = tensordict.get(self.action_key)
self.count += action.to(torch.int).to(self.device)
tensordict = TensorDict(
source={
@@ -1025,38 +1024,131 @@
return tensordict.select().set("next", tensordict)


class NestedRewardEnv(CountingEnv):
class NestedCountingEnv(CountingEnv):
# an env with nested reward and done states
def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
def __init__(
self,
max_steps: int = 5,
start_val: int = 0,
nest_obs_action: bool = True,
nest_done: bool = True,
nest_reward: bool = True,
nested_dim: int = 3,
**kwargs,
):
super().__init__(max_steps=max_steps, start_val=start_val, **kwargs)
self.observation_spec = CompositeSpec(
{("data", "states"): self.observation_spec["observation"].clone()},
shape=self.batch_size,
)
self.reward_spec = CompositeSpec(
{("data", "reward"): self.reward_spec.clone()}, shape=self.batch_size
)
self.done_spec = CompositeSpec(
{("data", "done"): self.done_spec.clone()}, shape=self.batch_size
)

self.nested_dim = nested_dim

self.nested_obs_action = nest_obs_action
self.nested_done = nest_done
self.nested_reward = nest_reward

if self.nested_obs_action:
self.observation_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"states": self.observation_spec["observation"]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
},
shape=(self.nested_dim,),
)
},
shape=self.batch_size,
)
self.action_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"action": self.action_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
)
},
shape=(self.nested_dim,),
)
},
shape=self.batch_size,
)

if self.nested_reward:
self.reward_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"reward": self.reward_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
)
},
shape=(self.nested_dim,),
)
},
shape=self.batch_size,
)

if self.nested_done:
self.done_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"done": self.done_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
)
},
shape=(self.nested_dim,),
)
},
shape=self.batch_size,
)

def _reset(self, td):
if self.nested_done and td is not None and "_reset" in td.keys():
td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)
td = super()._reset(td)
td[self.done_key] = td["done"]
del td["done"]
td["data", "states"] = td["observation"]
del td["observation"]
td["observation"] = td["observation"].to(torch.float)
if self.nested_done:
td[self.done_key] = (
td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
)
del td["done"]
if self.nested_obs_action:
td["data", "states"] = (
td["observation"]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
)
del td["observation"]
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
return td

def _step(self, td):
if self.nested_obs_action:
td["data"].batch_size = self.batch_size
td[self.action_key] = td[self.action_key].sum(-2)
td_root = super()._step(td)
td = td_root["next"]
td[self.reward_key] = td["reward"]
del td["reward"]
td[self.done_key] = td["done"]
del td["done"]
td["data", "states"] = td["observation"]
del td["observation"]
td["observation"] = td["observation"].to(torch.float)
if self.nested_done:
td[self.done_key] = (
td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
)
del td["done"]
if self.nested_obs_action:
td["data", "states"] = (
td["observation"]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
)
del td["observation"]
if self.nested_reward:
td[self.reward_key] = (
td["reward"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
)
del td["reward"]
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
return td_root


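For orientation, here is a minimal sketch (not part of the diff) of the tensordict layout `NestedCountingEnv` produces with its defaults. The keys simply mirror the specs defined above, and the attributes and methods used (`done_key`, `reward_key`, `reset()`, `rand_step()`) are the ones exercised by the tests further down; the import path is the one the tests use.

```python
# Illustrative only: assumes the class above with its default arguments
# (nest_obs_action=nest_done=nest_reward=True, nested_dim=3).
from mocking_classes import NestedCountingEnv

env = NestedCountingEnv()
td = env.reset()
# done and observation live under the "data" sub-tensordict after reset
assert env.done_key == ("data", "done")
assert env.reward_key == ("data", "reward")
assert ("data", "done") in td.keys(True)
assert ("data", "states") in td.keys(True)

step = env.rand_step()
# after a step, the nested entries also appear under "next"
assert ("next", "data", "done") in step.keys(True)
assert ("next", "data", "reward") in step.keys(True)
```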
20 changes: 10 additions & 10 deletions test/test_env.py
@@ -35,7 +35,7 @@
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MockSerialEnv,
NestedRewardEnv,
NestedCountingEnv,
)
from packaging import version
from tensordict.nn import TensorDictModuleBase
@@ -1368,22 +1368,22 @@ def test_mp_collector(self, nproc):


class TestNestedSpecs:
@pytest.mark.parametrize("envclass", ["CountingEnv", "NestedRewardEnv"])
def test_nested_reward(self, envclass):
from mocking_classes import NestedRewardEnv
@pytest.mark.parametrize("envclass", ["CountingEnv", "NestedCountingEnv"])
def test_nested_env(self, envclass):
from mocking_classes import NestedCountingEnv

if envclass == "CountingEnv":
env = CountingEnv()
elif envclass == "NestedRewardEnv":
env = NestedRewardEnv()
elif envclass == "NestedCountingEnv":
env = NestedCountingEnv()
else:
raise NotImplementedError
reset = env.reset()
assert not isinstance(env.done_spec, CompositeSpec)
assert not isinstance(env.reward_spec, CompositeSpec)
assert env.done_spec == env.output_spec[("_done_spec", *env.done_key)]
assert env.reward_spec == env.output_spec[("_reward_spec", *env.reward_key)]
if envclass == "NestedRewardEnv":
if envclass == "NestedCountingEnv":
assert env.done_key == ("data", "done")
assert env.reward_key == ("data", "reward")
assert ("data", "done") in reset.keys(True)
@@ -1393,14 +1393,14 @@ def test_nested_reward(self, envclass):
assert env.reward_key not in reset.keys(True)

next_state = env.rand_step()
if envclass == "NestedRewardEnv":
if envclass == "NestedCountingEnv":
assert ("next", "data", "done") in next_state.keys(True)
assert ("next", "data", "states") in next_state.keys(True)
assert ("next", "data", "reward") in next_state.keys(True)
assert ("next", *env.done_key) in next_state.keys(True)
assert ("next", *env.reward_key) in next_state.keys(True)

check_env_specs(env)
# check_env_specs(env)


@pytest.mark.parametrize(
@@ -1420,7 +1420,7 @@ def test_nested_reward(self, envclass):
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MockSerialEnv,
NestedRewardEnv,
NestedCountingEnv,
],
)
def test_mocking_envs(envclass):
39 changes: 37 additions & 2 deletions test/test_exploration.py
@@ -8,9 +8,9 @@
import pytest
import torch
from _utils_internal import get_default_devices
from mocking_classes import ContinuousActionVecMockEnv
from mocking_classes import ContinuousActionVecMockEnv, NestedCountingEnv
from scipy.stats import ttest_1samp
from tensordict.nn import InteractionType
from tensordict.nn import InteractionType, TensorDictModule
from tensordict.tensordict import TensorDict
from torch import nn

@@ -180,6 +180,41 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0):
pass
return

@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
def test_nested(self, device, nested_obs_action, nested_done, seed=0):
env = NestedCountingEnv(
nest_obs_action=nested_obs_action, nest_done=nested_done
)
torch.manual_seed(seed)
# TODO Serial
env = TransformedEnv(env.to(device), InitTracker())

action_spec = env.action_spec
d_act = action_spec.shape[-1]

net = nn.LazyLinear(d_act).to(device)
policy = TensorDictModule(
net,
in_keys=[("data", "states") if nested_obs_action else "observation"],
out_keys=[env.action_key],
)

exploratory_policy = OrnsteinUhlenbeckProcessWrapper(
policy, spec=action_spec, action_key=env.action_key
)
exploratory_policy(env.reset())
collector = SyncDataCollector(
create_env_fn=env,
policy=exploratory_policy,
frames_per_batch=100,
total_frames=1000,
device=device,
)
for _ in collector:
pass
return


@pytest.mark.parametrize("device", get_default_devices())
class TestAdditiveGaussian:
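As a quick reference outside the test harness, the snippet below sketches the behaviour the new `test_nested` exercises: `OrnsteinUhlenbeckProcessWrapper` can now target a nested `action_key`, so the action (and its exploration noise) is written to the nested entry instead of a top-level `"action"`. Only classes already used in the test above appear here; everything else is illustrative.

```python
# Sketch of the nested OU exploration flow (illustrative, not part of the diff).
env = TransformedEnv(NestedCountingEnv(), InitTracker())
policy = TensorDictModule(
    nn.LazyLinear(env.action_spec.shape[-1]),
    in_keys=[("data", "states")],  # nested observation entry
    out_keys=[env.action_key],     # ("data", "action")
)
explorer = OrnsteinUhlenbeckProcessWrapper(
    policy, spec=env.action_spec, action_key=env.action_key
)
td = explorer(env.reset())  # the action is written at td["data", "action"]
```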
9 changes: 6 additions & 3 deletions torchrl/collectors/collectors.py
@@ -240,7 +240,10 @@ def _get_policy_and_device(
# we check if all the mandatory params are there
if not required_params.difference(set(next_observation)):
in_keys = [str(k) for k in sig.parameters if k in next_observation]
out_keys = ["action"]
if not hasattr(self, "env") or self.env is None:
out_keys = ["action"]
else:
out_keys = [self.env.action_key]
output = policy(**next_observation)

if isinstance(output, tuple):
@@ -766,7 +769,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
break

def _step_and_maybe_reset(self) -> None:
done = self._tensordict.get(("next", "done"))
done = self._tensordict.get(("next", *self.env.done_key))
truncated = self._tensordict.get(("next", "truncated"), None)
traj_ids = self._tensordict.get(("collector", "traj_ids"))

@@ -796,7 +799,7 @@ def _step_and_maybe_reset(self) -> None:
else:
self._tensordict.update(td_reset, inplace=True)

done = self._tensordict.get("done")
done = self._tensordict.get(self.env.done_key)
if done.any():
raise RuntimeError(
f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
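The lookups above compose the `"next"` prefix with whatever `env.done_key` happens to be. A tiny sketch of that key arithmetic, using a hypothetical nested done entry like the one `NestedCountingEnv` defines:

```python
import torch
from tensordict.tensordict import TensorDict

# Hypothetical nested done_key, e.g. as produced by NestedCountingEnv.
done_key = ("data", "done")
td = TensorDict(
    {"next": {"data": {"done": torch.zeros(3, 1, dtype=torch.bool)}}}, []
)
# ("next", *done_key) unpacks to ("next", "data", "done"),
# replacing the previously hard-coded ("next", "done").
done = td.get(("next", *done_key))
assert done.shape == (3, 1)
```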
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
@@ -77,7 +77,7 @@ def specs(self, value: CompositeSpec):
@staticmethod
def metadata_from_env(env) -> EnvMetaData:
tensordict = env.fake_tensordict().clone()
tensordict.set("_reset", torch.zeros_like(tensordict.get("done")))
tensordict.set("_reset", torch.zeros_like(tensordict.get(env.done_key)))

specs = env.specs.to("cpu")
