From f330a00eb40f390eac5dd722aeede220982a5cc8 Mon Sep 17 00:00:00 2001 From: Markus28 Date: Mon, 26 Sep 2022 20:13:01 +0200 Subject: [PATCH 1/3] Added support for new PettingZoo API --- test/pettingzoo/test_pistonball.py | 2 -- test/pettingzoo/test_tic_tac_toe.py | 2 -- tianshou/env/pettingzoo_env.py | 15 ++++++++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 9e9c8bdb2..4a6c59655 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -1,10 +1,8 @@ import pprint -import pytest from pistonball import get_args, train_agent, watch -@pytest.mark.skip(reason="TODO(Markus28): fix later") def test_piston_ball(args=get_args()): if args.watch: watch(args) diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index 29b251b81..524cdb92a 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -1,10 +1,8 @@ import pprint -import pytest from tic_tac_toe import get_args, train_agent, watch -@pytest.mark.skip(reason="TODO(Markus28): fix later") def test_tic_tac_toe(args=get_args()): if args.watch: watch(args) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 1722dc563..1b9812151 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -57,7 +57,8 @@ def __init__(self, env: BaseWrapper): def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: self.env.reset(*args, **kwargs) - observation, _, _, info = self.env.last(self) + last_return = self.env.last(self) + observation, info = last_return[0], last_return[-1] if isinstance(observation, dict) and 'action_mask' in observation: observation_dict = { 'agent_id': self.env.agent_selection, @@ -83,9 +84,13 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: else: return observation_dict - def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: + def step( + self, action: Any + ) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool, + Dict]]: self.env.step(action) - observation, rew, done, info = self.env.last() + last_return = self.env.last() + observation = last_return[0] if isinstance(observation, dict) and 'action_mask' in observation: obs = { 'agent_id': self.env.agent_selection, @@ -105,7 +110,7 @@ def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: for agent_id, reward in self.env.rewards.items(): self.rewards[self.agent_idx[agent_id]] = reward - return obs, self.rewards, done, info + return (obs, self.rewards, *last_return[2:]) # type: ignore def close(self) -> None: self.env.close() @@ -113,7 +118,7 @@ def close(self) -> None: def seed(self, seed: Any = None) -> None: try: self.env.seed(seed) - except NotImplementedError: + except (NotImplementedError, AttributeError): self.env.reset(seed=seed) def render(self, mode: str = "human") -> Any: From 0d084540498bc401beb3a4de80135f83c1784781 Mon Sep 17 00:00:00 2001 From: Markus28 Date: Sat, 1 Oct 2022 21:31:08 +0200 Subject: [PATCH 2/3] Added deprecation warnings --- tianshou/env/pettingzoo_env.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 1b9812151..92ed55ab9 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,10 +1,20 @@ +import warnings from abc import ABC from typing import Any, Dict, List, Tuple, Union import gym.spaces +import pettingzoo +from packaging import version from pettingzoo.utils.env import AECEnv from pettingzoo.utils.wrappers import BaseWrapper +if version.parse(pettingzoo.__version__) < version.parse("1.21.0"): + warnings.warn( + f"You are using PettingZoo {pettingzoo.__version__}. " + f"Future tianshou versions may not support PettingZoo<1.21.0. " + f"Consider upgrading your PettingZoo version.", DeprecationWarning + ) + class PettingZooEnv(AECEnv, ABC): """The interface for petting zoo environments. @@ -58,6 +68,15 @@ def __init__(self, env: BaseWrapper): def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: self.env.reset(*args, **kwargs) last_return = self.env.last(self) + + if len(last_return) == 4: + warnings.warn( + "The PettingZoo environment is using the old step API. " + "This API may not be supported in future versions of tianshou. " + "We recommend that you update the environment code or apply a " + "compatibility wrapper.", DeprecationWarning + ) + observation, info = last_return[0], last_return[-1] if isinstance(observation, dict) and 'action_mask' in observation: observation_dict = { From 0c124f8b03618832b6fb679f243c574cd0c6af34 Mon Sep 17 00:00:00 2001 From: Markus28 Date: Sat, 1 Oct 2022 21:37:41 +0200 Subject: [PATCH 3/3] Added comments --- tianshou/env/pettingzoo_env.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 92ed55ab9..d1ab131de 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -67,6 +67,9 @@ def __init__(self, env: BaseWrapper): def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: self.env.reset(*args, **kwargs) + + # Here, we do not label the return values explicitly to keep compatibility with + # old step API. TODO: Change once PettingZoo>=1.21.0 is required last_return = self.env.last(self) if len(last_return) == 4: @@ -108,6 +111,9 @@ def step( ) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool, Dict]]: self.env.step(action) + + # Here, we do not label the return values explicitly to keep compatibility with + # old step API. TODO: Change once PettingZoo>=1.21.0 is required last_return = self.env.last() observation = last_return[0] if isinstance(observation, dict) and 'action_mask' in observation: