diff --git a/examples/arena/run_arena.py b/examples/arena/run_arena.py index 061176ab..71bd593f 100644 --- a/examples/arena/run_arena.py +++ b/examples/arena/run_arena.py @@ -20,8 +20,13 @@ from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner -def run_arena(): - render = True +def run_arena( + render: bool = False, + parallel: bool = True, + seed=0, + total_games: int = 10, + max_game_onetime: int = 5, +): env_wrappers = [RecordWinner] if render: from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender @@ -35,13 +40,15 @@ def run_arena(): arena.reset( agents={"agent1": agent1, "agent2": agent2}, - total_games=10, - max_game_onetime=5, + total_games=total_games, + max_game_onetime=max_game_onetime, + seed=seed, ) - result = arena.run(parallel=True) - print(result) + result = arena.run(parallel=parallel) arena.close() + print(result) + return result if __name__ == "__main__": - run_arena() + run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10) diff --git a/examples/arena/test_reproducibility.py b/examples/arena/test_reproducibility.py new file mode 100644 index 00000000..dbfb1c2f --- /dev/null +++ b/examples/arena/test_reproducibility.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +from run_arena import run_arena + + +def test_seed(seed: int): + test_time = 5 + pre_result = None + for parallel in [False, True]: + for i in range(test_time): + result = run_arena(seed=seed, parallel=parallel, total_games=20) + if pre_result is not None: + assert pre_result == result, f"parallel={parallel}, seed={seed}" + pre_result = result + + +if __name__ == "__main__": + test_seed(0) diff --git a/openrl/arena/base_arena.py b/openrl/arena/base_arena.py index 9de1d7d6..5f8217f2 100644 --- a/openrl/arena/base_arena.py +++ b/openrl/arena/base_arena.py @@ -39,13 +39,16 @@ def __init__(self, env_fn: Callable, dispatch_func: Optional[Callable] = None): self.max_game_onetime = None self.agents = None self.game: Optional[BaseGame] = None + self.seed = None def reset( self, agents: Dict[str, BaseAgent], total_games: int, max_game_onetime: int = 5, + seed: int = 0, ): + self.seed = seed if self.pbar: self.pbar.refresh() self.pbar.close() @@ -54,7 +57,7 @@ def reset( self.max_game_onetime = max_game_onetime self.agents = agents assert isinstance(self.game, BaseGame) - self.game.reset(dispatch_func=self.dispatch_func) + self.game.reset(seed=seed, dispatch_func=self.dispatch_func) def close(self): if self.pbar: @@ -67,9 +70,12 @@ def _run_parallel(self): ) as executor: futures = [ executor.submit( - self.game.run, CloudpickleWrapper(self.env_fn), self.agents + self.game.run, + self.seed + run_index, + CloudpickleWrapper(self.env_fn), + self.agents, ) - for _ in range(self.total_games) + for run_index in range(self.total_games) ] for future in as_completed(futures): result = future.result() @@ -77,12 +83,13 @@ def _run_parallel(self): self.pbar.update(1) def _run_serial(self): - for _ in range(self.total_games): - result = self.game.run(self.env_fn, self.agents) + for run_index in range(self.total_games): + result = self.game.run(self.seed + run_index, self.env_fn, self.agents) self._deal_result(result) self.pbar.update(1) def run(self, parallel: bool = True) -> Dict[str, Any]: + assert self.seed is not None, "Please call reset() to set seed first." if parallel: self._run_parallel() else: diff --git a/openrl/arena/games/base_game.py b/openrl/arena/games/base_game.py index 41c01080..8c6acdbc 100644 --- a/openrl/arena/games/base_game.py +++ b/openrl/arena/games/base_game.py @@ -18,35 +18,53 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, List, Optional, Tuple +import numpy as np +from gymnasium.utils import seeding + from openrl.arena.agents.base_agent import BaseAgent from openrl.selfplay.opponents.base_opponent import BaseOpponent class BaseGame(ABC): + _np_random: Optional[np.random.Generator] = None + def __init__(self): self.dispatch_func = None + self.seed = None - def reset(self, dispatch_func: Optional[Callable] = None): - if dispatch_func is not None: - self.dispatch_func = dispatch_func - else: - self.dispatch_func = self.default_dispatch_func + def reset(self, seed: int, dispatch_func: Optional[Callable] = None): + self.seed = seed + self._np_random, seed = seeding.np_random(seed) + if self.dispatch_func is None: + if dispatch_func is not None: + self.dispatch_func = dispatch_func + else: + self.dispatch_func = self.default_dispatch_func def dispatch_agent_to_player( self, players: List[str], agents: Dict[str, BaseAgent] ) -> Tuple[Dict[str, BaseOpponent], Dict[str, str]]: + assert self._np_random is not None player2agent = {} - player2agent_name = self.dispatch_func(players, list(agents.keys())) + player2agent_name = self.dispatch_func( + self._np_random, players, list(agents.keys()) + ) for player in players: player2agent[player] = agents[player2agent_name[player]].new_agent() return player2agent, player2agent_name @staticmethod def default_dispatch_func( - players: List[str], agent_names: List[str] + np_random: np.random.Generator, + players: List[str], + agent_names: List[str], ) -> Dict[str, str]: raise NotImplementedError + def run(self, seed: int, env_fn: Callable, agents: List[BaseAgent]): + self.reset(seed=seed) + return self._run(env_fn, agents) + @abstractmethod - def run(self, env_fn, agents): + def _run(self, env_fn: Callable, agents: List[BaseAgent]): raise NotImplementedError diff --git a/openrl/arena/games/two_player_game.py b/openrl/arena/games/two_player_game.py index 8193efc8..5fe32fd4 100644 --- a/openrl/arena/games/two_player_game.py +++ b/openrl/arena/games/two_player_game.py @@ -16,26 +16,32 @@ """""" import random -from typing import Dict, List +from typing import Callable, Dict, List +import numpy as np + +from openrl.arena.agents.base_agent import BaseAgent from openrl.arena.games.base_game import BaseGame class TwoPlayerGame(BaseGame): @staticmethod def default_dispatch_func( - players: List[str], agent_names: List[str] + np_random: np.random.Generator, + players: List[str], + agent_names: List[str], ) -> Dict[str, str]: assert len(players) == len( agent_names ), "The number of players must be equal to the number of agents." assert len(players) == 2, "The number of players must be equal to 2." - random.shuffle(agent_names) + np_random.shuffle(agent_names) return dict(zip(players, agent_names)) - def run(self, env_fn, agents): + def _run(self, env_fn: Callable, agents: List[BaseAgent]): env = env_fn() - env.reset() + env.reset(seed=self.seed) + player2agent, player2agent_name = self.dispatch_agent_to_player( env.agents, agents ) diff --git a/openrl/envs/PettingZoo/__init__.py b/openrl/envs/PettingZoo/__init__.py index 27804e66..e5111afc 100644 --- a/openrl/envs/PettingZoo/__init__.py +++ b/openrl/envs/PettingZoo/__init__.py @@ -19,6 +19,7 @@ from typing import List, Optional, Union from openrl.envs.common import build_envs +from openrl.envs.wrappers.pettingzoo_wrappers import SeedEnv def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs): @@ -37,7 +38,7 @@ def make_PettingZoo_env( **kwargs, ): env_num = 1 - env_wrappers = [] + env_wrappers = [SeedEnv] env_wrappers += copy.copy(kwargs.pop("env_wrappers", [])) env_fns = build_envs( make=PettingZoo_make, @@ -62,7 +63,7 @@ def make_PettingZoo_envs( Single2MultiAgentWrapper, ) - env_wrappers = copy.copy(kwargs.pop("opponent_wrappers", [])) + env_wrappers = copy.copy(kwargs.pop("opponent_wrappers", [SeedEnv])) env_wrappers += [ Single2MultiAgentWrapper, RemoveTruncated, diff --git a/openrl/envs/wrappers/pettingzoo_wrappers.py b/openrl/envs/wrappers/pettingzoo_wrappers.py index 3d6303e0..687384a4 100644 --- a/openrl/envs/wrappers/pettingzoo_wrappers.py +++ b/openrl/envs/wrappers/pettingzoo_wrappers.py @@ -15,12 +15,22 @@ # limitations under the License. """""" - +from typing import Optional from pettingzoo.utils.env import ActionType, AECEnv from pettingzoo.utils.wrappers import BaseWrapper +class SeedEnv(BaseWrapper): + def reset(self, seed: Optional[int] = None, options: Optional[dict] = None): + super().reset(seed=seed, options=options) + + for i, space in enumerate( + list(self.action_spaces.values()) + list(self.observation_spaces.values()) + ): + space.seed(seed + i * 7891) + + class RecordWinner(BaseWrapper): def __init__(self, env: AECEnv): super().__init__(env)