Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 11, 2024
1 parent 69a672c commit 04c327d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 8 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,13 @@ def full_action_spec(self) -> Composite:
domain=continuous), device=cpu, shape=torch.Size([]))
"""
return self.input_spec["full_action_spec"]
full_action_spec = self.input_spec.get("full_action_spec", None)
if full_action_spec is None:
full_action_spec = CompositeSpec(shape=self.batch_size, device=self.device)
self.input_spec.unlock_()
self.input_spec["full_action_spec"] = full_action_spec
self.input_spec.lock_()
return full_action_spec

@full_action_spec.setter
def full_action_spec(self, spec: Composite) -> None:
Expand Down Expand Up @@ -1334,7 +1340,7 @@ def observation_spec(self) -> Composite:
domain=continuous), device=cpu, shape=torch.Size([]))
"""
observation_spec = self.output_spec["full_observation_spec"]
observation_spec = self.output_spec.get("full_observation_spec", default=None)
if observation_spec is None:
observation_spec = Composite(shape=self.batch_size, device=self.device)
self.output_spec.unlock_()
Expand Down
4 changes: 1 addition & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4767,9 +4767,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
new_state_spec = self.transform_observation_spec(
input_spec["full_state_spec"]
)
new_state_spec = self.transform_observation_spec(input_spec["full_state_spec"])
for action_key in list(input_spec["full_action_spec"].keys(True, True)):
if action_key in new_state_spec.keys(True, True):
input_spec["full_action_spec", action_key] = new_state_spec[action_key]
Expand Down

0 comments on commit 04c327d

Please sign in to comment.