Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow sequential transforms to work offline #1136

Merged
Merged 1 commit on
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,14 @@ def test_transform_model(self, dim, N, padding):
torch.manual_seed(10)
envbase.set_seed(10)
tdbase = envbase.rollout(10)
tdbase0 = tdbase.clone()

model = nn.Sequential(cat_frames2, nn.Identity())
model(tdbase)
assert (td == tdbase).all()
with pytest.raises(ValueError, match="The last dimension of the tensordict"):
tdbase0.names = None
model(tdbase0)

@pytest.mark.parametrize("dim", [-2, -1])
@pytest.mark.parametrize("N", [3, 4])
Expand Down Expand Up @@ -3757,17 +3761,37 @@ def test_transform_no_env(
t = RewardSum()
reward = torch.randn(10)
td = TensorDict({("next", "reward"): reward}, [])
with pytest.raises(NotImplementedError):
with pytest.raises(
ValueError, match="At least one dimension of the tensordict"
):
t(td)
td.batch_size = [10]
td.names = ["time"]
with pytest.raises(KeyError):
t(td)
t = RewardSum(
in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")]
)
t(td)

def test_transform_compose(
    self,
):
    """RewardSum inside a Compose, applied offline (no parent env).

    A dimensionless tensordict cannot carry a "time" dim name, so the call
    must fail; once batched and named, default keys are still missing; with
    explicit in/out keys the transform succeeds.
    """
    t = Compose(RewardSum())
    reward = torch.randn(10)
    td = TensorDict({("next", "reward"): reward}, [])
    # No batch dims -> no dim can be named "time" -> offline call is rejected.
    with pytest.raises(
        ValueError, match="At least one dimension of the tensordict"
    ):
        t(td)
    td.batch_size = [10]
    td.names = ["time"]
    # Default in_keys ("reward") are absent from this tensordict.
    with pytest.raises(KeyError):
        t(td)
    # Explicit keys pointing at existing entries make the call succeed.
    t = RewardSum(
        in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")]
    )
    t(td)

@pytest.mark.skipif(not _has_gym, reason="No Gym")
def test_transform_env(
Expand All @@ -3789,24 +3813,34 @@ def test_transform_env(
def test_transform_model(
    self,
):
    """RewardSum used offline inside an nn.Sequential module.

    The offline pass over a rollout must reproduce exactly the
    episode_reward the env-side RewardSum computed online.
    """
    # Explicit keys are required offline: there is no parent env to infer
    # reward specs from.
    t = RewardSum(
        in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")]
    )
    model = nn.Sequential(t, nn.Identity())
    # Reference data: rollout of an env whose RewardSum filled episode_reward.
    env = TransformedEnv(ContinuousActionVecMockEnv(), RewardSum())
    data = env.rollout(10)
    data_exclude = data.exclude(("next", "episode_reward"))
    model(data_exclude)
    # The offline computation must match the online one.
    assert (
        data_exclude["next", "episode_reward"] == data["next", "episode_reward"]
    ).all()

def test_transform_rb(
    self,
):
    """RewardSum as a replay-buffer transform.

    Sampling from the buffer applies the transform and must restore the
    episode_reward that was stripped from the stored rollout.
    """
    # Explicit keys are required offline (no parent env).
    t = RewardSum(
        in_keys=[("next", "reward")], out_keys=[("next", "episode_reward")]
    )
    rb = ReplayBuffer(storage=LazyTensorStorage(10))
    # Reference rollout with the online episode_reward, then strip it.
    env = TransformedEnv(ContinuousActionVecMockEnv(), RewardSum())
    data = env.rollout(10)
    data_exclude = data.exclude(("next", "episode_reward"))
    rb.append_transform(t)
    rb.add(data_exclude)
    # Sampling applies the transform, recomputing the excluded entry.
    sample = rb.sample(1).squeeze(0)
    assert (
        sample["next", "episode_reward"] == data["next", "episode_reward"]
    ).all()

@pytest.mark.parametrize(
"keys",
Expand Down
18 changes: 12 additions & 6 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2072,9 +2072,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# it is assumed that the last dimension of the tensordict is the time dimension
if not tensordict.ndim or (
tensordict.names[-1] is not None and tensordict.names[-1] != "time"
):
if not tensordict.ndim or tensordict.names[-1] != "time":
raise ValueError(
"The last dimension of the tensordict must be marked as 'time'."
)
Expand Down Expand Up @@ -3368,9 +3366,17 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
return observation_spec

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Offline application: cumulative sum of rewards along the "time" dim.

    Requires at least one dimension of ``tensordict`` to be named "time";
    raises ``ValueError`` otherwise. For each ``in_key``/``out_key`` pair,
    writes ``reward.cumsum(...)`` under ``out_key`` and returns the same
    (mutated) tensordict.
    """
    time_dim = [i for i, name in enumerate(tensordict.names) if name == "time"]
    if not time_dim:
        raise ValueError(
            "At least one dimension of the tensordict must be named 'time' in offline mode"
        )
    # NOTE(review): `- 1` resolves to the "time" axis only when "time" is the
    # sole/leading batch dim and the reward tensor has no extra trailing dims
    # beyond it (e.g. names=["time"] with a 1-D reward gives dim -1 == 0).
    # Confirm the intended indexing for multi-dim batches / trailing reward dims.
    time_dim = time_dim[0] - 1
    for in_key, out_key in zip(self.in_keys, self.out_keys):
        reward = tensordict.get(in_key)
        cumsum = reward.cumsum(time_dim)
        tensordict.set(out_key, cumsum)
    return tensordict


class StepCounter(Transform):
Expand Down