Skip to content

Commit

Permalink
add set seed to arena, and test reproducibility
Browse files Browse the repository at this point in the history
add set seed to arena, and test reproducibility
  • Loading branch information
huangshiyu13 authored Aug 11, 2023
2 parents 5d16780 + 38dd261 commit 5c78ecf
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 28 deletions.
21 changes: 14 additions & 7 deletions examples/arena/run_arena.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner


def run_arena():
render = True
def run_arena(
render: bool = False,
parallel: bool = True,
seed=0,
total_games: int = 10,
max_game_onetime: int = 5,
):
env_wrappers = [RecordWinner]
if render:
from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender
Expand All @@ -35,13 +40,15 @@ def run_arena():

arena.reset(
agents={"agent1": agent1, "agent2": agent2},
total_games=10,
max_game_onetime=5,
total_games=total_games,
max_game_onetime=max_game_onetime,
seed=seed,
)
result = arena.run(parallel=True)
print(result)
result = arena.run(parallel=parallel)
arena.close()
print(result)
return result


if __name__ == "__main__":
run_arena()
run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)
34 changes: 34 additions & 0 deletions examples/arena/test_reproducibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

from run_arena import run_arena


def test_seed(seed: int):
test_time = 5
pre_result = None
for parallel in [False, True]:
for i in range(test_time):
result = run_arena(seed=seed, parallel=parallel, total_games=20)
if pre_result is not None:
assert pre_result == result, f"parallel={parallel}, seed={seed}"
pre_result = result


if __name__ == "__main__":
test_seed(0)
17 changes: 12 additions & 5 deletions openrl/arena/base_arena.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ def __init__(self, env_fn: Callable, dispatch_func: Optional[Callable] = None):
self.max_game_onetime = None
self.agents = None
self.game: Optional[BaseGame] = None
self.seed = None

def reset(
self,
agents: Dict[str, BaseAgent],
total_games: int,
max_game_onetime: int = 5,
seed: int = 0,
):
self.seed = seed
if self.pbar:
self.pbar.refresh()
self.pbar.close()
Expand All @@ -54,7 +57,7 @@ def reset(
self.max_game_onetime = max_game_onetime
self.agents = agents
assert isinstance(self.game, BaseGame)
self.game.reset(dispatch_func=self.dispatch_func)
self.game.reset(seed=seed, dispatch_func=self.dispatch_func)

def close(self):
if self.pbar:
Expand All @@ -67,22 +70,26 @@ def _run_parallel(self):
) as executor:
futures = [
executor.submit(
self.game.run, CloudpickleWrapper(self.env_fn), self.agents
self.game.run,
self.seed + run_index,
CloudpickleWrapper(self.env_fn),
self.agents,
)
for _ in range(self.total_games)
for run_index in range(self.total_games)
]
for future in as_completed(futures):
result = future.result()
self._deal_result(result)
self.pbar.update(1)

def _run_serial(self):
for _ in range(self.total_games):
result = self.game.run(self.env_fn, self.agents)
for run_index in range(self.total_games):
result = self.game.run(self.seed + run_index, self.env_fn, self.agents)
self._deal_result(result)
self.pbar.update(1)

def run(self, parallel: bool = True) -> Dict[str, Any]:
assert self.seed is not None, "Please call reset() to set seed first."
if parallel:
self._run_parallel()
else:
Expand Down
34 changes: 26 additions & 8 deletions openrl/arena/games/base_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,53 @@
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
from gymnasium.utils import seeding

from openrl.arena.agents.base_agent import BaseAgent
from openrl.selfplay.opponents.base_opponent import BaseOpponent


class BaseGame(ABC):
_np_random: Optional[np.random.Generator] = None

def __init__(self):
self.dispatch_func = None
self.seed = None

def reset(self, dispatch_func: Optional[Callable] = None):
if dispatch_func is not None:
self.dispatch_func = dispatch_func
else:
self.dispatch_func = self.default_dispatch_func
def reset(self, seed: int, dispatch_func: Optional[Callable] = None):
self.seed = seed
self._np_random, seed = seeding.np_random(seed)
if self.dispatch_func is None:
if dispatch_func is not None:
self.dispatch_func = dispatch_func
else:
self.dispatch_func = self.default_dispatch_func

def dispatch_agent_to_player(
self, players: List[str], agents: Dict[str, BaseAgent]
) -> Tuple[Dict[str, BaseOpponent], Dict[str, str]]:
assert self._np_random is not None
player2agent = {}
player2agent_name = self.dispatch_func(players, list(agents.keys()))
player2agent_name = self.dispatch_func(
self._np_random, players, list(agents.keys())
)
for player in players:
player2agent[player] = agents[player2agent_name[player]].new_agent()
return player2agent, player2agent_name

@staticmethod
def default_dispatch_func(
players: List[str], agent_names: List[str]
np_random: np.random.Generator,
players: List[str],
agent_names: List[str],
) -> Dict[str, str]:
raise NotImplementedError

def run(self, seed: int, env_fn: Callable, agents: List[BaseAgent]):
self.reset(seed=seed)
return self._run(env_fn, agents)

@abstractmethod
def run(self, env_fn, agents):
def _run(self, env_fn: Callable, agents: List[BaseAgent]):
raise NotImplementedError
16 changes: 11 additions & 5 deletions openrl/arena/games/two_player_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,32 @@

""""""
import random
from typing import Dict, List
from typing import Callable, Dict, List

import numpy as np

from openrl.arena.agents.base_agent import BaseAgent
from openrl.arena.games.base_game import BaseGame


class TwoPlayerGame(BaseGame):
@staticmethod
def default_dispatch_func(
players: List[str], agent_names: List[str]
np_random: np.random.Generator,
players: List[str],
agent_names: List[str],
) -> Dict[str, str]:
assert len(players) == len(
agent_names
), "The number of players must be equal to the number of agents."
assert len(players) == 2, "The number of players must be equal to 2."
random.shuffle(agent_names)
np_random.shuffle(agent_names)
return dict(zip(players, agent_names))

def run(self, env_fn, agents):
def _run(self, env_fn: Callable, agents: List[BaseAgent]):
env = env_fn()
env.reset()
env.reset(seed=self.seed)

player2agent, player2agent_name = self.dispatch_agent_to_player(
env.agents, agents
)
Expand Down
5 changes: 3 additions & 2 deletions openrl/envs/PettingZoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import List, Optional, Union

from openrl.envs.common import build_envs
from openrl.envs.wrappers.pettingzoo_wrappers import SeedEnv


def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs):
Expand All @@ -37,7 +38,7 @@ def make_PettingZoo_env(
**kwargs,
):
env_num = 1
env_wrappers = []
env_wrappers = [SeedEnv]
env_wrappers += copy.copy(kwargs.pop("env_wrappers", []))
env_fns = build_envs(
make=PettingZoo_make,
Expand All @@ -62,7 +63,7 @@ def make_PettingZoo_envs(
Single2MultiAgentWrapper,
)

env_wrappers = copy.copy(kwargs.pop("opponent_wrappers", []))
env_wrappers = copy.copy(kwargs.pop("opponent_wrappers", [SeedEnv]))
env_wrappers += [
Single2MultiAgentWrapper,
RemoveTruncated,
Expand Down
12 changes: 11 additions & 1 deletion openrl/envs/wrappers/pettingzoo_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,22 @@
# limitations under the License.

""""""

from typing import Optional

from pettingzoo.utils.env import ActionType, AECEnv
from pettingzoo.utils.wrappers import BaseWrapper


class SeedEnv(BaseWrapper):
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed, options=options)

for i, space in enumerate(
list(self.action_spaces.values()) + list(self.observation_spaces.values())
):
space.seed(seed + i * 7891)


class RecordWinner(BaseWrapper):
def __init__(self, env: AECEnv):
super().__init__(env)
Expand Down

0 comments on commit 5c78ecf

Please sign in to comment.