From 89da3df00ccca3b2fafa042c2293ad0699673f07 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 5 May 2023 21:33:10 +0100 Subject: [PATCH] init --- test/test_transforms.py | 60 +++++++++++++++++++++------ torchrl/envs/transforms/transforms.py | 18 +++++--- 2 files changed, 59 insertions(+), 19 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 6d0e2a6648c..ffc6884fd0b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -415,10 +415,14 @@ def test_transform_model(self, dim, N, padding): torch.manual_seed(10) envbase.set_seed(10) tdbase = envbase.rollout(10) + tdbase0 = tdbase.clone() model = nn.Sequential(cat_frames2, nn.Identity()) model(tdbase) assert (td == tdbase).all() + with pytest.raises(ValueError, match="The last dimension of the tensordict"): + tdbase0.names = None + model(tdbase0) @pytest.mark.parametrize("dim", [-2, -1]) @pytest.mark.parametrize("N", [3, 4]) @@ -3757,8 +3761,18 @@ def test_transform_no_env( t = RewardSum() reward = torch.randn(10) td = TensorDict({("next", "reward"): reward}, []) - with pytest.raises(NotImplementedError): + with pytest.raises( + ValueError, match="At least one dimension of the tensordict" + ): t(td) + td.batch_size = [10] + td.names = ["time"] + with pytest.raises(KeyError): + t(td) + t = RewardSum( + in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")] + ) + t(td) def test_transform_compose( self, @@ -3766,8 +3780,18 @@ def test_transform_compose( t = Compose(RewardSum()) reward = torch.randn(10) td = TensorDict({("next", "reward"): reward}, []) - with pytest.raises(NotImplementedError): + with pytest.raises( + ValueError, match="At least one dimension of the tensordict" + ): t(td) + td.batch_size = [10] + td.names = ["time"] + with pytest.raises(KeyError): + t(td) + t = RewardSum( + in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")] + ) + t(td) @pytest.mark.skipif(not _has_gym, reason="No Gym") def test_transform_env( @@ -3789,24 +3813,34 @@ def test_transform_env( def test_transform_model( self, ): - t = RewardSum() + t = RewardSum( + in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")] + ) model = nn.Sequential(t, nn.Identity()) - reward = torch.randn(10) - td = TensorDict({("next", "reward"): reward}, []) - with pytest.raises(NotImplementedError): - model(td) + env = TransformedEnv(ContinuousActionVecMockEnv(), RewardSum()) + data = env.rollout(10) + data_exclude = data.exclude(("next", "episode_reward")) + model(data_exclude) + assert ( + data_exclude["next", "episode_reward"] == data["next", "episode_reward"] + ).all() def test_transform_rb( self, ): - t = RewardSum() + t = RewardSum( + in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")] + ) rb = ReplayBuffer(storage=LazyTensorStorage(10)) - reward = torch.randn(10) - td = TensorDict({("next", "reward"): reward}, []).expand(10) + env = TransformedEnv(ContinuousActionVecMockEnv(), RewardSum()) + data = env.rollout(10) + data_exclude = data.exclude(("next", "episode_reward")) rb.append_transform(t) - rb.extend(td) - with pytest.raises(NotImplementedError): - _ = rb.sample(2) + rb.add(data_exclude) + sample = rb.sample(1).squeeze(0) + assert ( + sample["next", "episode_reward"] == data["next", "episode_reward"] + ).all() @pytest.mark.parametrize( "keys", diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5eece6b11dd..e798e4c0a9f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2072,9 +2072,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # it is assumed that the last dimension of the tensordict is the time dimension - if not tensordict.ndim or ( - tensordict.names[-1] is not None and tensordict.names[-1] != "time" - ): + if not tensordict.ndim or tensordict.names[-1] != "time": raise ValueError( "The last dimension of the tensordict must be marked as 'time'." ) @@ -3368,9 +3366,17 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - raise NotImplementedError( - FORWARD_NOT_IMPLEMENTED.format(self.__class__.__name__) - ) + time_dim = [i for i, name in enumerate(tensordict.names) if name == "time"] + if not time_dim: + raise ValueError( + "At least one dimension of the tensordict must be named 'time' in offline mode" + ) + time_dim = time_dim[0] - 1 + for in_key, out_key in zip(self.in_keys, self.out_keys): + reward = tensordict.get(in_key) + cumsum = reward.cumsum(time_dim) + tensordict.set(out_key, cumsum) + return tensordict class StepCounter(Transform):