diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index b147a005173..9853e8d516d 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -390,10 +390,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): n=2, shape=group_action_spec["action"].shape if not self.categorical_actions - else ( - *group_action_spec["action"].shape, - group_action_spec["action"].space.n, - ), + else group_action_spec["action"].to_one_hot_spec().shape, dtype=torch.bool, device=self.device, ) @@ -494,7 +491,7 @@ def _init_env(self): n=2, shape=group_action_spec.shape if not self.categorical_actions - else (*group_action_spec.shape, group_action_spec.space.n), + else group_action_spec.to_one_hot_spec().shape, dtype=torch.bool, device=self.device, )