diff --git a/Gallery.md b/Gallery.md index 0dd06f23..a29b2ad5 100644 --- a/Gallery.md +++ b/Gallery.md @@ -54,18 +54,19 @@ Users are also welcome to contribute their own training examples and demos to th
-| Environment/Demo | Tags | Refs | -|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------:|:-------------------------------:| -| [MuJoCo](https://github.com/deepmind/mujoco)
| ![continuous](https://img.shields.io/badge/-continous-green) | [code](./examples/mujoco/) | -| [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) | -| [MPE: Simple Spread](https://pettingzoo.farama.org/environments/mpe/simple_spread/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) | -| [StarCraft II](https://github.com/oxwhirl/smac)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/smac/) | -| [Chat Bot](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![NLP](https://img.shields.io/badge/-NLP-green) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/nlp/) | -| [Atari Pong](https://gymnasium.farama.org/environments/atari/pong/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/atari/) | -| [PettingZoo: Tic-Tac-Toe](https://pettingzoo.farama.org/environments/classic/tictactoe/)
| ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) | -| [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
| ![continuous](https://img.shields.io/badge/-continous-green) | [code](./examples/dm_control/) | -| [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/isaac/) | -| [GridWorld](./examples/gridworld/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/gridworld/) | -| [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/super_mario/) | -| [Gym Retro](https://github.com/openai/retro)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/retro/) | +| Environment/Demo | Tags | Refs | +|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-------------------------------:| +| [MuJoCo](https://github.com/deepmind/mujoco)
| ![continuous](https://img.shields.io/badge/-continuous-green) | [code](./examples/mujoco/) | +| [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) | +| [MPE: Simple Spread](https://pettingzoo.farama.org/environments/mpe/simple_spread/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) | +| [StarCraft II](https://github.com/oxwhirl/smac)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/smac/) | +| [Chat Bot](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![NLP](https://img.shields.io/badge/-NLP-green) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/nlp/) | +| [Atari Pong](https://gymnasium.farama.org/environments/atari/pong/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/atari/) | +| [PettingZoo: Tic-Tac-Toe](https://pettingzoo.farama.org/environments/classic/tictactoe/)
| ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) | +| [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
| ![continuous](https://img.shields.io/badge/-continuous-green) | [code](./examples/dm_control/) | +| [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/isaac/) | +| [Snake](http://www.jidiai.cn/env_detail?envid=1)
| ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/snake/) | +| [GridWorld](./examples/gridworld/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/gridworld/) | +| [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/super_mario/) | +| [Gym Retro](https://github.com/openai/retro)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/retro/) |
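All of the rows above follow the same OpenRL entry points, so a new demo usually only swaps the environment id and config. A minimal sketch of that shared pattern, shown here for the CartPole row (the hyperparameters are illustrative, not the tuned values from the linked example):

```python
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

env = make("CartPole-v1", env_num=9)  # vectorized training environment
net = Net(env)                        # PPO network with default config
agent = Agent(net)                    # PPO training runner
agent.train(total_time_steps=20000)   # train; the agent can then act or be saved
env.close()
```

The other gallery entries differ mainly in the id passed to `make` and in extra wrappers, as the Snake example added in this PR illustrates.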
\ No newline at end of file
diff --git a/README.md b/README.md
index 2e4c4aaa..af7befcc 100644
--- a/README.md
+++ b/README.md
@@ -104,7 +104,8 @@ Environments currently supported by OpenRL (for more details, please refer to [G
 - [Atari](https://gymnasium.farama.org/environments/atari/)
 - [StarCraft II](https://github.com/oxwhirl/smac)
 - [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
-- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [Snake](http://www.jidiai.cn/env_detail?envid=1)
 - [GridWorld](./examples/gridworld/)
 - [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
 - [Gym Retro](https://github.com/openai/retro)
diff --git a/README_zh.md b/README_zh.md
index cae5cea6..41822950 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -86,7 +86,8 @@ OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
 - [Atari](https://gymnasium.farama.org/environments/atari/)
 - [StarCraft II](https://github.com/oxwhirl/smac)
 - [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
-- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [Snake](http://www.jidiai.cn/env_detail?envid=1)
 - [GridWorld](./examples/gridworld/)
 - [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
 - [Gym Retro](https://github.com/openai/retro)
diff --git a/docs/images/snakes_1v1.gif b/docs/images/snakes_1v1.gif
new file mode 100644
index 00000000..03b40ab2
Binary files /dev/null and b/docs/images/snakes_1v1.gif differ
diff --git a/examples/dm_control/train_ppo.py b/examples/dm_control/train_ppo.py
index 2a3004a1..aa77222f 100644
--- a/examples/dm_control/train_ppo.py
+++ b/examples/dm_control/train_ppo.py
@@ -4,10 +4,9 @@
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import GIFWrapper
+from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
-from openrl.envs.wrappers.extra_wrappers import FrameSkip
 
 env_name = "dm_control/cartpole-balance-v0"
 # env_name = "dm_control/walker-walk-v0"
diff --git a/examples/smac/README.md b/examples/smac/README.md
index 5fb14e76..f9d7d2cb 100644
--- a/examples/smac/README.md
+++ b/examples/smac/README.md
@@ -11,4 +11,7 @@ Installation guide for Linux:
 
 Train SMAC with [MAPPO](https://arxiv.org/abs/2103.01955) algorithm:
 
-`python train_ppo.py --config smac_ppo.yaml`
\ No newline at end of file
+`python train_ppo.py --config smac_ppo.yaml`
+
+## Render replay on Mac
+
diff --git a/examples/snake/README.md b/examples/snake/README.md
new file mode 100644
index 00000000..4adb9cbd
--- /dev/null
+++ b/examples/snake/README.md
@@ -0,0 +1,17 @@
+
+This is an example for the Snake game.
+
+## Usage
+
+```bash
+python train_selfplay.py
+```
+
+
+## Submit to JiDi
+
+Submission site: http://www.jidiai.cn/env_detail?envid=1.
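The JiDi platform imports `my_controller` from `submission.py` as the agent entry point; the random-agent submission included below in this PR implements that contract. A minimal smoke test of the interface (an assumption: `gymnasium.spaces.Discrete` stands in for JiDi's own action-space object, which has the same `n` and `sample()` surface):

```python
# Run next to submission.py. The random agent ignores the observation, so an
# empty dict is enough to exercise the entry point.
from gymnasium.spaces import Discrete

from submission import my_controller

action_space = [Discrete(4)]  # one 4-way move head per controlled snake
joint_action = my_controller(
    observation={}, action_space=action_space, is_act_continuous=False
)
print(joint_action)  # e.g. [[0, 0, 1, 0]] -- a one-hot choice among the 4 directions
```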
+
+Snake scenarios: [here](https://github.com/jidiai/ai_lib/blob/7a6986f0cb543994277103dbf605e9575d59edd6/env/config.json#L94)
+Original Snake environment: [here](https://github.com/jidiai/ai_lib/blob/master/env/snakes.py)
+
diff --git a/examples/snake/selfplay.yaml b/examples/snake/selfplay.yaml
new file mode 100644
index 00000000..74de97a0
--- /dev/null
+++ b/examples/snake/selfplay.yaml
@@ -0,0 +1,3 @@
+seed: 0
+callbacks:
+  - id: "ProgressBarCallback"
diff --git a/examples/snake/submissions/random_agent/submission.py b/examples/snake/submissions/random_agent/submission.py
new file mode 100644
index 00000000..b1f468df
--- /dev/null
+++ b/examples/snake/submissions/random_agent/submission.py
@@ -0,0 +1,29 @@
+# -*- coding:utf-8 -*-
+def sample_single_dim(action_space_list_each, is_act_continuous):
+    if is_act_continuous:
+        each = action_space_list_each.sample()
+    else:
+        if action_space_list_each.__class__.__name__ == "Discrete":
+            each = [0] * action_space_list_each.n
+            idx = action_space_list_each.sample()
+            each[idx] = 1
+        elif action_space_list_each.__class__.__name__ == "MultiDiscreteParticle":
+            each = []
+            nvec = action_space_list_each.high - action_space_list_each.low + 1
+            sample_indexes = action_space_list_each.sample()
+
+            for i in range(len(nvec)):
+                dim = nvec[i]
+                new_action = [0] * dim
+                index = sample_indexes[i]
+                new_action[index] = 1
+                each.extend(new_action)
+    return each
+
+
+def my_controller(observation, action_space, is_act_continuous):
+    joint_action = []
+    for i in range(len(action_space)):
+        player = sample_single_dim(action_space[i], is_act_continuous)
+        joint_action.append(player)
+    return joint_action
diff --git a/examples/snake/test_env.py b/examples/snake/test_env.py
new file mode 100644
index 00000000..b3a6bbee
--- /dev/null
+++ b/examples/snake/test_env.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
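The test file that follows exercises three entry points: the raw `SnakeEatBeans` game, the PettingZoo-style AEC wrapper, and the vectorized `snakes_1v1` env. For orientation, the raw observation each player receives is a dict; a hypothetical instance, with field names taken from `get_dict_observation()` in `openrl/envs/snake/snake.py` later in this diff and illustrative values:

```python
# Keys 2 and 3 are the two snakes' player_ids; segment lists are head-first.
raw_obs = {
    1: [[2, 5], [4, 1]],                  # bean positions as [row, col]
    2: [[3, 1], [3, 2], [3, 3]],          # snake of player_id 2
    3: [[0, 6], [1, 6], [2, 6]],          # snake of player_id 3
    "board_width": 8,
    "board_height": 6,
    "last_direction": ["up", "left"],     # previous move of each snake
    "controlled_snake_index": 2,          # which snake this observation is for
}
```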
+ +"""""" +import time + +import numpy as np +from wrappers import ConvertObs + +from openrl.envs.snake.snake import SnakeEatBeans +from openrl.envs.snake.snake_pettingzoo import SnakeEatBeansAECEnv +from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper + + +def test_raw_env(): + env = SnakeEatBeans() + + obs, info = env.reset() + + done = False + while not np.any(done): + a1 = np.zeros(4) + a1[env.action_space.sample()] = 1 + a2 = np.zeros(4) + a2[env.action_space.sample()] = 1 + obs, reward, done, info = env.step([a1, a2]) + print("obs:", obs) + print("reward:", reward) + print("done:", done) + print("info:", info) + + +def test_aec_env(): + from PIL import Image + + img_list = [] + env = SnakeEatBeansAECEnv(render_mode="rgb_array") + env.reset(seed=0) + # time.sleep(1) + img = env.render() + img_list.append(img) + step = 0 + for player_name in env.agent_iter(): + if step > 20: + break + observation, reward, termination, truncation, info = env.last() + if termination or truncation: + break + action = env.action_space(player_name).sample() + # if player_name == "player_0": + # action = 2 + # elif player_name == "player_1": + # action = 3 + # else: + # raise ValueError("Unknown player name: {}".format(player_name)) + env.step(action) + img = env.render() + if player_name == "player_0": + img_list.append(img) + # time.sleep(1) + + step += 1 + print("Total steps: {}".format(step)) + + save_path = "test.gif" + img_list = [Image.fromarray(img) for img in img_list] + img_list[0].save(save_path, save_all=True, append_images=img_list[1:], duration=500) + + +def test_vec_env(): + from openrl.envs.common import make + + env = make( + "snakes_1v1", + opponent_wrappers=[ + RandomOpponentWrapper, + ], + env_wrappers=[ConvertObs], + render_mode="group_human", + env_num=2, + ) + obs, info = env.reset() + step = 0 + done = False + while not np.any(done): + action = env.random_action() + obs, reward, done, info = env.step(action) + time.sleep(0.3) + step += 1 + print("Total steps: {}".format(step)) + + +if __name__ == "__main__": + test_vec_env() diff --git a/examples/snake/train_selfplay.py b/examples/snake/train_selfplay.py new file mode 100644 index 00000000..d466abbe --- /dev/null +++ b/examples/snake/train_selfplay.py @@ -0,0 +1,87 @@ +import numpy as np +import torch +from wrappers import ConvertObs + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.modules.common import PPONet as Net +from openrl.runners.common import PPOAgent as Agent +from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper + + +def train(): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(["--config", "selfplay.yaml"]) + + # Create environment + env_num = 10 + render_model = None + env = make( + "snakes_1v1", + render_mode=render_model, + env_num=env_num, + asynchronous=True, + opponent_wrappers=[RandomOpponentWrapper], + env_wrappers=[ConvertObs], + cfg=cfg, + ) + # Create neural network + + net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu") + # Create agent + agent = Agent(net) + # Begin training + agent.train(total_time_steps=100000) + env.close() + agent.save("./selfplay_agent/") + return agent + + +def evaluation(): + from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender + + print("Evaluation...") + env_num = 1 + env = make( + "snakes_1v1", + env_num=env_num, + asynchronous=True, + opponent_wrappers=[RandomOpponentWrapper], + 
env_wrappers=[ConvertObs], + auto_reset=False, + ) + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args() + net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu") + + agent = Agent(net) + + agent.load("./selfplay_agent/") + agent.set_env(env) + env.reset(seed=0) + + total_reward = 0.0 + ep_num = 5 + for ep_now in range(ep_num): + obs, info = env.reset() + done = False + step = 0 + + while not np.any(done): + # predict next action based on the observation + action, _ = agent.act(obs, info, deterministic=True) + obs, r, done, info = env.step(action) + step += 1 + + if np.any(done): + total_reward += np.mean(r) > 0 + print(f"{ep_now}/{ep_num}: reward: {np.mean(r)}") + print(f"win rate: {total_reward/ep_num}") + env.close() + print("Evaluation finished.") + + +if __name__ == "__main__": + train() + evaluation() diff --git a/examples/snake/wrappers.py b/examples/snake/wrappers.py new file mode 100644 index 00000000..52f3958d --- /dev/null +++ b/examples/snake/wrappers.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +from openrl.envs.wrappers.base_wrapper import BaseObservationWrapper + + +def raw2vec(raw_obs, n_player=2): + control_index = raw_obs["controlled_snake_index"][0] + + width = raw_obs["board_width"][0] + height = raw_obs["board_height"][0] + beans = raw_obs[1][0] + + ally_pos = raw_obs[control_index][0] + enemy_pos = raw_obs[5 - control_index][0] + + obs = np.zeros(width * height * n_player, dtype=int) + + ally_head_h, ally_head_w = ally_pos[0] + enemy_head_h, enemy_head_w = enemy_pos[0] + obs[ally_head_h * width + ally_head_w] = 2 + obs[height * width + ally_head_h * width + ally_head_w] = 4 + obs[enemy_head_h * width + enemy_head_w] = 4 + obs[height * width + enemy_head_h * width + enemy_head_w] = 2 + + for bean in beans: + h, w = bean + obs[h * width + w] = 1 + obs[height * width + h * width + w] = 1 + + for p in ally_pos[1:]: + h, w = p + obs[h * width + w] = 3 + obs[height * width + h * width + w] = 5 + + for p in enemy_pos[1:]: + h, w = p + obs[h * width + w] = 5 + obs[height * width + h * width + w] = 3 + + obs_ = np.array([]) + for i in obs: + obs_ = np.concatenate([obs_, np.eye(6)[i]]) + obs_ = obs_.reshape(-1, width * height * n_player * 6) + + return obs_ + + +class ConvertObs(BaseObservationWrapper): + def __init__(self, env: gym.Env): + """Flattens the observations of an environment. + + Args: + env: The environment to apply the wrapper + """ + BaseObservationWrapper.__init__(self, env) + + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(576,), dtype=np.float32 + ) + + def observation(self, observation): + """Flattens an observation. 
+ + Args: + observation: The observation to flatten + + Returns: + The flattened observation + """ + + return raw2vec(observation) diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py index ebd8d727..bbca547b 100644 --- a/openrl/algorithms/dqn.py +++ b/openrl/algorithms/dqn.py @@ -167,9 +167,7 @@ def prepare_loss( ) q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch - q_loss = torch.mean( - F.mse_loss(q_values, q_targets.detach()) - ) # 均方误差损失函数 + q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数 loss_list.append(q_loss) diff --git a/openrl/algorithms/vdn.py b/openrl/algorithms/vdn.py index 83bdb5ed..f1215c03 100644 --- a/openrl/algorithms/vdn.py +++ b/openrl/algorithms/vdn.py @@ -211,9 +211,7 @@ def prepare_loss( rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1) rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1) q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch - q_loss = torch.mean( - F.mse_loss(q_values, q_targets.detach()) - ) # 均方误差损失函数 + q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数 loss_list.append(q_loss) return loss_list diff --git a/openrl/envs/PettingZoo/__init__.py b/openrl/envs/PettingZoo/__init__.py index e5111afc..fa9e66ca 100644 --- a/openrl/envs/PettingZoo/__init__.py +++ b/openrl/envs/PettingZoo/__init__.py @@ -63,7 +63,8 @@ def make_PettingZoo_envs( Single2MultiAgentWrapper, ) - env_wrappers = copy.copy(kwargs.pop("opponent_wrappers", [SeedEnv])) + env_wrappers = [SeedEnv] + env_wrappers += copy.copy(kwargs.pop("opponent_wrappers", [])) env_wrappers += [ Single2MultiAgentWrapper, RemoveTruncated, diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index 90f54e82..3a274c2c 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -65,7 +65,14 @@ def make( id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs ) else: - if id.startswith("dm_control/"): + if id.startswith("snakes_"): + from openrl.envs.snake import make_snake_envs + + env_fns = make_snake_envs( + id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs + ) + + elif id.startswith("dm_control/"): from openrl.envs.dmc import make_dmc_envs env_fns = make_dmc_envs( diff --git a/openrl/envs/dmc/__init__.py b/openrl/envs/dmc/__init__.py index ad2b113d..4f6ff39e 100644 --- a/openrl/envs/dmc/__init__.py +++ b/openrl/envs/dmc/__init__.py @@ -13,10 +13,7 @@ def make_dmc_envs( render_mode: Optional[Union[str, List[str]]] = None, **kwargs, ): - from openrl.envs.wrappers import ( - RemoveTruncated, - Single2MultiAgentWrapper, - ) + from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper from openrl.envs.wrappers.extra_wrappers import ConvertEmptyBoxWrapper env_wrappers = copy.copy(kwargs.pop("env_wrappers", [])) diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py index 65ca66b0..ab1a47db 100644 --- a/openrl/envs/mpe/rendering.py +++ b/openrl/envs/mpe/rendering.py @@ -29,10 +29,12 @@ except ImportError: print( "Error occured while running `from pyglet.gl import *`", - "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get" - " install python-opengl'. If you're running on a server, you may need a" - " virtual frame buffer; something like this should work: 'xvfb-run -s" - ' "-screen 0 1400x900x24" python \'', + ( + "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get" + " install python-opengl'. 
If you're running on a server, you may need a" + " virtual frame buffer; something like this should work: 'xvfb-run -s" + ' "-screen 0 1400x900x24" python \'' + ), ) import math diff --git a/openrl/envs/snake/__init__.py b/openrl/envs/snake/__init__.py new file mode 100644 index 00000000..7d049e8f --- /dev/null +++ b/openrl/envs/snake/__init__.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import copy +from typing import List, Optional, Union + +from pettingzoo.utils.wrappers import AssertOutOfBoundsWrapper, OrderEnforcingWrapper + +from openrl.envs.common import build_envs +from openrl.envs.snake.snake_pettingzoo import SnakeEatBeansAECEnv +from openrl.envs.wrappers.pettingzoo_wrappers import SeedEnv + + +def snake_env_make(id, render_mode, disable_env_checker, **kwargs): + if id == "snakes_1v1": + env = SnakeEatBeansAECEnv(render_mode=render_mode) + else: + raise ValueError("Unknown env {}".format(id)) + return env + + +def make_snake_envs( + id: str, + env_num: int = 1, + render_mode: Optional[Union[str, List[str]]] = None, + **kwargs, +): + from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper + + env_wrappers = [AssertOutOfBoundsWrapper, OrderEnforcingWrapper, SeedEnv] + env_wrappers += copy.copy(kwargs.pop("opponent_wrappers", [])) + env_wrappers += [ + Single2MultiAgentWrapper, + RemoveTruncated, + ] + env_wrappers += copy.copy(kwargs.pop("env_wrappers", [])) + + env_fns = build_envs( + make=snake_env_make, + id=id, + env_num=env_num, + render_mode=render_mode, + wrappers=env_wrappers, + **kwargs, + ) + + return env_fns diff --git a/openrl/envs/snake/common.py b/openrl/envs/snake/common.py new file mode 100644 index 00000000..6a67a0a3 --- /dev/null +++ b/openrl/envs/snake/common.py @@ -0,0 +1,227 @@ +import os +import sys + +import numpy as np + + +class HiddenPrints: + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout + + +class Board: + def __init__(self, board_height, board_width, snakes, beans_positions, teams): + # print('create board, beans_position: ', beans_positions) + self.height = board_height + self.width = board_width + self.snakes = snakes + self.snakes_count = len(snakes) + self.beans_positions = beans_positions + self.blank_sign = -self.snakes_count + self.bean_sign = -self.snakes_count + 1 + self.board = np.zeros((board_height, board_width), dtype=int) + self.blank_sign + self.open = dict() + for key, snake in self.snakes.items(): + self.open[key] = [snake.head] # state 0 open list, heads, ready to spread + # see [A* Pathfinding (E01: algorithm explanation)](https://www.youtube.com/watch?v=-L-WgKMFuhE) + for x, y in snake.pos: + self.board[x][y] = key # obstacles, e.g. 
0, 1, 2, 3, 4, 5 + # for x, y in beans_positions: + # self.board[x][y] = self.bean_sign # beans + + self.state = 0 + self.controversy = dict() + self.teams = teams + + # print('initial board') + # print(self.board) + + def step(self): # delay: prevent rear-end collision + new_open = {key: [] for key in self.snakes.keys()} + self.state += 1 # update state + # if self.state > delay: + # for key, snake in self.snakes.items(): # drop tail + # if snake.len >= self.state: + # self.board[snake.pos[-(self.state - delay)][0]][snake.pos[-(self.state - delay)][1]] \ + # = self.blank_sign + for key, snake in self.snakes.items(): + if snake.len >= self.state: + self.board[snake.pos[-self.state][0]][ + snake.pos[-self.state][1] + ] = self.blank_sign # drop tail + for key, value in self.open.items(): # value: e.g. [[8, 3], [6, 3], [7, 4]] + others_tail_pos = [ + ( + self.snakes[_].pos[-self.state] + if self.snakes[_].len >= self.state + else [] + ) + for _ in set(range(self.snakes_count)) - {key} + ] + for x, y in value: + # print('start to spread snake {} on grid ({}, {})'.format(key, x, y)) + for x_, y_ in [ + ((x + 1) % self.height, y), # down + ((x - 1) % self.height, y), # up + (x, (y + 1) % self.width), # right + (x, (y - 1) % self.width), + ]: # left + sign = self.board[x_][y_] + idx = ( + sign % self.snakes_count + ) # which snake, e.g. 0, 1, 2, 3, 4, 5 / number of claims + state = ( + sign // self.snakes_count + ) # manhattan distance to snake who claim the point or its negative + if sign == self.blank_sign: # grid in initial state + if [x_, y_] in others_tail_pos: + # print('do not spread other snakes tail, in case of rear-end collision') + continue # do not spread other snakes' tail, in case of rear-end collision + self.board[x_][y_] = self.state * self.snakes_count + key + self.snakes[key].claimed_count += 1 + new_open[key].append([x_, y_]) + + elif key != idx and self.state == state: + # second claim, init controversy, change grid value from + to - + # print( + # '\tgird ({}, {}) in the same state claimed by different snakes ' + # 'with sign {}, idx {} and state {}'.format( + # x_, y_, sign, idx, state)) + if ( + self.snakes[idx].len > self.snakes[key].len + ): # shorter snake claim the controversial grid + # print('\t\tsnake {} is shorter than snake {}'.format(key, idx)) + self.snakes[idx].claimed_count -= 1 + new_open[idx].remove([x_, y_]) + self.board[x_][y_] = self.state * self.snakes_count + key + self.snakes[key].claimed_count += 1 + new_open[key].append([x_, y_]) + elif ( + self.snakes[idx].len == self.snakes[key].len + ): # controversial claim + # print( + # '\t\tcontroversy! first claimed by snake {}, then claimed by snake {}'.format(idx, key)) + self.controversy[(x_, y_)] = { + "state": self.state, + "length": self.snakes[idx].len, + "indexes": [idx, key], + } + # first claim by snake idx, then claim by snake key + self.board[x_][y_] = -self.state * self.snakes_count + 1 + # if + 2, not enough for all snakes claim one grid!! + self.snakes[ + idx + ].claimed_count -= ( + 1 # controversy, no snake claim this grid!! 
+ ) + new_open[key].append([x_, y_]) + else: # (self.snakes[idx].len < self.snakes[key].len) + pass # longer snake do not claim the controversial grid + + elif ( + (x_, y_) in self.controversy + and key not in self.controversy[(x_, y_)]["indexes"] + and self.state + state == 0 + ): # third claim or more + # print('snake {} meets third or more claim in grid ({}, {})'.format(key, x_, y_)) + controversy = self.controversy[(x_, y_)] + # pprint.pprint(controversy) + if ( + controversy["length"] > self.snakes[key].len + ): # shortest snake claim grid, do 4 things + # print('\t\tsnake {} is shortest'.format(key)) + indexes_count = len(controversy["indexes"]) + for i in controversy["indexes"]: + self.snakes[i].claimed_count -= ( + 1 / indexes_count + ) # update claimed_count ! + new_open[i].remove([x_, y_]) + del self.controversy[(x_, y_)] + self.board[x_][y_] = self.state * self.snakes_count + key + self.snakes[key].claimed_count += 1 + new_open[key].append([x_, y_]) + elif ( + controversy["length"] == self.snakes[key].len + ): # controversial claim + # print('\t\tcontroversy! multi claimed by snake {}'.format(key)) + self.controversy[(x_, y_)]["indexes"].append(key) + self.board[x_][y_] += 1 + new_open[key].append([x_, y_]) + else: # (controversy['length'] < self.snakes[key].len) + pass # longer snake do not claim the controversial grid + else: + pass # do nothing with lower state grids + + self.open = new_open # update open + # update controversial snakes' claimed_count (in fraction) in the end + for _, d in self.controversy.items(): + controversial_snake_count = len( + d["indexes"] + ) # number of controversial snakes + for idx in d["indexes"]: + self.snakes[idx].claimed_count += 1 / controversial_snake_count + + +class SnakePos: + def __init__(self, snake_positions, board_height, board_width, beans_positions): + self.pos = snake_positions # [[2, 9], [2, 8], [2, 7]] + self.len = len(snake_positions) # >= 3 + self.head = snake_positions[0] + self.beans_positions = beans_positions + self.claimed_count = 0 + + displace = [ + (self.head[0] - snake_positions[1][0]) % board_height, + (self.head[1] - snake_positions[1][1]) % board_width, + ] + # print('creat snake, pos: ', self.pos, 'displace:', displace) + if displace == [ + board_height - 1, + 0, + ]: # all action are ordered by left, up, right, relative to the body + self.dir = 0 # up + self.legal_action = [2, 0, 3] + elif displace == [1, 0]: + self.dir = 1 # down + self.legal_action = [3, 1, 2] + elif displace == [0, board_width - 1]: + self.dir = 2 # left + self.legal_action = [1, 2, 0] + elif displace == [0, 1]: + self.dir = 3 # right + self.legal_action = [0, 3, 1] + else: + assert False, "snake positions error" + positions = [ + [(self.head[0] - 1) % board_height, self.head[1]], + [(self.head[0] + 1) % board_height, self.head[1]], + [self.head[0], (self.head[1] - 1) % board_width], + [self.head[0], (self.head[1] + 1) % board_width], + ] + self.legal_position = [positions[_] for _ in self.legal_action] + + def get_action(self, position): + if position not in self.legal_position: + assert False, "the start and end points do not match" + idx = self.legal_position.index(position) + return self.legal_action[idx] # 0, 1, 2, 3: up, down, left, right + + def step(self, legal_input): + if legal_input in self.legal_position: + position = legal_input + elif legal_input in self.legal_action: + idx = self.legal_action.index(legal_input) + position = self.legal_position[idx] + else: + assert False, "illegal snake move" + self.head = position + 
self.pos.insert(0, position) + if position in self.beans_positions: # eat a bean + self.len += 1 + else: # do not eat a bean + self.pos.pop() diff --git a/openrl/envs/snake/discrete.py b/openrl/envs/snake/discrete.py new file mode 100644 index 00000000..7d6d318d --- /dev/null +++ b/openrl/envs/snake/discrete.py @@ -0,0 +1,35 @@ +import numpy as np + +from .space import Space + + +class Discrete(Space): + r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`. + Example:: + >>> Discrete(2) + """ + + def __init__(self, n): + assert n >= 0 + self.n = n + super(Discrete, self).__init__((), np.int64) + + def sample(self): + return self.np_random.randint(self.n) + + def contains(self, x): + if isinstance(x, int): + as_int = x + elif isinstance(x, (np.generic, np.ndarray)) and ( + x.dtype.char in np.typecodes["AllInteger"] and x.shape == () + ): + as_int = int(x) + else: + return False + return as_int >= 0 and as_int < self.n + + def __repr__(self): + return "Discrete(%d)" % self.n + + def __eq__(self, other): + return isinstance(other, Discrete) and self.n == other.n diff --git a/openrl/envs/snake/game.py b/openrl/envs/snake/game.py new file mode 100644 index 00000000..c0e35d39 --- /dev/null +++ b/openrl/envs/snake/game.py @@ -0,0 +1,55 @@ +# -*- coding:utf-8 -*- +# 作者:zruizhi +# 创建时间: 2020/7/10 10:24 上午 +# 描述: +from abc import ABC, abstractmethod + + +class Game(ABC): + def __init__( + self, + n_player, + is_obs_continuous, + is_act_continuous, + game_name, + agent_nums, + obs_type, + ): + self.n_player = n_player + self.current_state = None + self.all_observes = None + self.is_obs_continuous = is_obs_continuous + self.is_act_continuous = is_act_continuous + self.game_name = game_name + self.agent_nums = agent_nums + self.obs_type = obs_type + + def get_config(self, player_id): + raise NotImplementedError + + def get_render_data(self, current_state): + return current_state + + def set_current_state(self, current_state): + raise NotImplementedError + + @abstractmethod + def is_terminal(self): + raise NotImplementedError + + def get_next_state(self, all_action): + raise NotImplementedError + + def get_reward(self, all_action): + raise NotImplementedError + + @abstractmethod + def step(self, all_action): + raise NotImplementedError + + @abstractmethod + def reset(self): + raise NotImplementedError + + def set_action_space(self): + raise NotImplementedError diff --git a/openrl/envs/snake/gridgame.py b/openrl/envs/snake/gridgame.py new file mode 100644 index 00000000..b75aa44d --- /dev/null +++ b/openrl/envs/snake/gridgame.py @@ -0,0 +1,266 @@ +# -*- coding:utf-8 -*- +# 作者:zruizhi +# 创建时间: 2020/7/10 10:24 上午 +# 描述: + +from itertools import count + +import numpy as np +from PIL import Image, ImageDraw + +from .game import Game + +UNIT = 40 +FIX = 8 + + +class GridGame(Game): + def __init__(self, conf, colors=None, unit_size=UNIT, fix=FIX): + super().__init__( + conf["n_player"], + conf["is_obs_continuous"], + conf["is_act_continuous"], + conf["game_name"], + conf["agent_nums"], + conf["obs_type"], + ) + # grid game conf + self.game_name = conf["game_name"] + self.max_step = int(conf["max_step"]) + self.board_width = int(conf["board_width"]) + self.board_height = int(conf["board_height"]) + self.cell_range = ( + conf["cell_range"] + if isinstance(eval(str(conf["cell_range"])), tuple) + else (int(conf["cell_range"]),) + ) + self.cell_dim = len(self.cell_range) + self.cell_size = np.prod(self.cell_range) + + # grid observation conf + self.ob_board_width = ( + conf["ob_board_width"] + if 
conf.get("ob_board_width") is not None + else [self.board_width for _ in range(self.n_player)] + ) + self.ob_board_height = ( + conf["ob_board_height"] + if conf.get("ob_board_height") is not None + else [self.board_height for _ in range(self.n_player)] + ) + self.ob_cell_range = ( + conf["ob_cell_range"] + if conf.get("ob_cell_range") is not None + else [self.cell_range for _ in range(self.n_player)] + ) + + # vector observation conf + self.ob_vector_shape = ( + conf["ob_vector_shape"] + if conf.get("ob_vector_shape") is not None + else [ + self.board_width * self.board_height * self.cell_dim + for _ in range(self.n_player) + ] + ) + self.ob_vector_range = ( + conf["ob_vector_range"] + if conf.get("ob_vector_range") is not None + else [self.cell_range for _ in range(self.n_player)] + ) + + # 每个玩家的 action space list, 可以根据player_id获取对应的single_action_space + self.joint_action_space = self.set_action_space() + + # global state,每个step需维护此项,并根据此项定义render data 及 observation + self.current_state = None + + # 记录对局结果信息 + self.n_return = [0] * self.n_player + self.won = "" + + # render 相关 + self.grid_unit = unit_size + self.grid = GridGame.init_board(self.board_width, self.board_height, unit_size) + self.grid_unit_fix = fix + self.colors = ( + colors + generate_color(self.cell_size - len(colors) + 1) + if colors is not None + else generate_color(self.cell_size) + ) + self.init_info = None + + def get_grid_obs_config(self, player_id): + return ( + self.ob_board_width[player_id], + self.ob_board_height[player_id], + self.ob_cell_range[player_id], + ) + + def get_grid_many_obs_space(self, player_id_list): + all_obs_space = {} + for i in player_id_list: + m, n, r_l = self.get_grid_obs_config(i) + all_obs_space[i] = (m, n, len(r_l)) + return all_obs_space + + def get_vector_obs_config(self, player_id): + return self.ob_vector_shape[player_id], self.ob_vector_range[player_id] + + def get_vector_many_obs_space(self, player_id_list): + all_obs_space = {} + for i in player_id_list: + m = self.ob_vector_shape[i] + all_obs_space[i] = m + return all_obs_space + + def get_single_action_space(self, player_id): + return self.joint_action_space[player_id] + + def set_action_space(self): + raise NotImplementedError + + def check_win(self): + raise NotImplementedError + + def get_render_data(self, current_state): + grid_map = [[0] * self.board_width for _ in range(self.board_height)] + for i in range(self.board_height): + for j in range(self.board_width): + grid_map[i][j] = 0 + for k in range(self.cell_dim): + grid_map[i][j] = ( + grid_map[i][j] * self.cell_range[k] + current_state[i][j][k] + ) + return grid_map + + def set_current_state(self, current_state): + if not current_state: + raise NotImplementedError + + self.current_state = current_state + + def is_not_valid_action(self, joint_action): + raise NotImplementedError + + def is_not_valid_grid_observation(self, obs, player_id): + not_valid = 0 + w, h, cell_range = self.get_grid_obs_config(player_id) + if len(obs) != h or len(obs[0]) != w or len(obs[0][0]) != len(cell_range): + raise Exception("obs 维度不正确!", obs) + + for i in range(h): + for j in range(w): + for k in range(len(cell_range)): + if obs[i][j][k] not in range(cell_range[k]): + raise Exception("obs 单元值不正确!", obs[i][j][k]) + + return not_valid + + def is_not_valid_vector_observation(self, obs, player_id): + not_valid = 0 + shape, vector_range = self.get_vector_obs_config(player_id) + if len(obs) != shape or len(vector_range) != shape: + raise Exception("obs 维度不正确!", obs) + + for i in range(shape): + if 
obs[i] not in range(vector_range[i]): + raise Exception("obs 单元值不正确!", obs[i]) + + return not_valid + + def step(self, joint_action): + info_before = self.step_before_info() + all_observes, info_after = self.get_next_state(joint_action) + done = self.is_terminal() + reward = self.get_reward(joint_action) + return all_observes, reward, done, info_before, info_after + + def step_before_info(self, info=""): + return info + + def init_action_space(self): + joint_action = [] + for i in range(len(self.joint_action_space)): + player = [] + for j in range(len(self.joint_action_space[i])): + each = [0] * self.joint_action_space[i][j].n + player.append(each) + joint_action.append(player) + return joint_action + + def draw_board(self): + cols = [chr(i) for i in range(65, 65 + self.board_width)] + s = ", ".join(cols) + print(" ", s) + for i in range(self.board_height): + print(chr(i + 65), self.current_state[i]) + + def render_board(self): + im_data = np.array( + GridGame._render_board( + self.get_render_data(self.current_state), + self.grid, + self.colors, + self.grid_unit, + self.grid_unit_fix, + ) + ) + return im_data + + @staticmethod + def init_board(width, height, grid_unit, color=(250, 235, 215)): + im = Image.new( + mode="RGB", size=(width * grid_unit, height * grid_unit), color=color + ) + draw = ImageDraw.Draw(im) + for x in range(0, width): + draw.line( + ((x * grid_unit, 0), (x * grid_unit, height * grid_unit)), + fill=(105, 105, 105), + ) + for y in range(0, height): + draw.line( + ((0, y * grid_unit), (width * grid_unit, y * grid_unit)), + fill=(105, 105, 105), + ) + return im + + @staticmethod + def _render_board(state, board, colors, unit, fix, extra_info=None): + """ + 完成基本渲染棋盘操作 + 设置extra_info参数仅为了保持子类方法签名的一致 + """ + im = board.copy() + draw = ImageDraw.Draw(im) + for x, row in zip(count(0), state): + for y, state in zip(count(0), row): + if state == 0: + continue + draw.rectangle( + build_rectangle(y, x, unit, fix), + fill=tuple(colors[state]), + outline=(192, 192, 192), + ) + return im + + @staticmethod + def parse_extra_info(data): + return None + + +def build_rectangle(x, y, unit_size=UNIT, fix=FIX): + return ( + x * unit_size + unit_size // fix, + y * unit_size + unit_size // fix, + (x + 1) * unit_size - unit_size // fix, + (y + 1) * unit_size - unit_size // fix, + ) + + +def generate_color(n): + return [ + tuple(map(lambda n: int(n), np.random.choice(range(256), size=3))) + for _ in range(n) + ] diff --git a/openrl/envs/snake/observation.py b/openrl/envs/snake/observation.py new file mode 100644 index 00000000..6e28b37f --- /dev/null +++ b/openrl/envs/snake/observation.py @@ -0,0 +1,61 @@ +# -*- coding:utf-8 -*- +# 作者:zruizhi +# 创建时间: 2020/11/13 3:51 下午 +# 描述:observation的各种接口类 +obs_type = ["grid", "vector", "dict"] + + +class GridObservation(object): + def get_grid_observation(self, current_state, player_id, info_before): + raise NotImplementedError + + def get_grid_many_observation(self, current_state, player_id_list, info_before=""): + all_obs = [] + for i in player_id_list: + all_obs.append(self.get_grid_observation(current_state, i, info_before)) + return all_obs + + +class VectorObservation(object): + def get_vector_observation(self, current_state, player_id, info_before): + raise NotImplementedError + + def get_vector_many_observation( + self, current_state, player_id_list, info_before="" + ): + all_obs = [] + for i in player_id_list: + all_obs.append(self.get_vector_observation(current_state, i, info_before)) + return all_obs + + +class DictObservation(object): + def 
get_dict_observation(self, current_state, player_id, info_before): + raise NotImplementedError + + def get_dict_many_observation(self, current_state, player_id_list, info_before=""): + all_obs = [] + for i in player_id_list: + all_obs.append(self.get_dict_observation(current_state, i, info_before)) + return all_obs + + +# todo: observation builder +class CustomObservation(object): + def get_custom_observation(self, current_state, player_id): + raise NotImplementedError + + def get_custom_obs_space(self, player_id): + raise NotImplementedError + + def get_custom_many_observation(self, current_state, player_id_list): + all_obs = [] + for i in player_id_list: + all_obs.append(self.get_custom_observation(current_state, i)) + return all_obs + + def get_custom_many_obs_space(self, player_id_list): + all_obs_space = [] + for i in player_id_list: + all_obs_space.append(self.get_custom_obs_space(i)) + return all_obs_space diff --git a/openrl/envs/snake/snake.py b/openrl/envs/snake/snake.py new file mode 100644 index 00000000..84e09f8b --- /dev/null +++ b/openrl/envs/snake/snake.py @@ -0,0 +1,718 @@ +# -*- coding:utf-8 -*- +# 作者:zruizhi +# 创建时间: 2020/7/30 17:24 下午 +# 描述: +import itertools +import random +from itertools import count +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +from gym import Env, spaces +from PIL import Image, ImageDraw, ImageFont + +from .discrete import Discrete +from .gridgame import GridGame, generate_color +from .observation import * + + +def convert_to_onehot(joint_action): + new_joint_action = [] + for action in joint_action: + onehot_action = np.zeros(4) + onehot_action[action] = 1 + new_joint_action.append(onehot_action) + return new_joint_action + + +class SnakeEatBeans(GridGame, GridObservation, DictObservation): + def __init__(self, render_mode: Optional[str] = None): + conf = { + "class_literal": "SnakeEatBeans", + "n_player": 2, + "board_width": 8, + "board_height": 6, + "cell_range": 4, + "n_beans": 5, + "max_step": 50, + "game_name": "snakes", + "is_obs_continuous": False, + "is_act_continuous": False, + "agent_nums": [1, 1], + "obs_type": ["dict", "dict"], + "save_interval": 100, + "save_path": "../../replay_winrate_var/replay_{}.gif", + } + self.terminate_flg = False + colors = conf.get("colors", [(255, 255, 255), (255, 140, 0)]) + super(SnakeEatBeans, self).__init__(conf, colors) + # 0: 没有 1:食物 2-n_player+1:各玩家蛇身 + self.n_cell_type = self.n_player + 2 + self.step_cnt = 1 + self.n_beans = int(conf["n_beans"]) + # 方向[-2,2,-1,1]分别表示[上,下,左,右] + self.actions = [-2, 2, -1, 1] + self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"} + self.snakes_position = {} + self.players = [] + self.cur_bean_num = 0 + self.beans_position = [] + # 1<= init_len <= 3 + self.init_len = 3 + self.current_state = self.init_state() + self.all_observes = self.get_all_observes() + if self.n_player * self.init_len > self.board_height * self.board_width: + raise Exception( + "玩家数量过多:%d,超出board范围:%d,%d" + % (self.n_player, self.board_width, self.board_height) + ) + + self.input_dimension = self.board_width * self.board_height + self.action_dim = self.get_action_dim() + + self.num_agents = conf["agent_nums"][0] + self.num_enemys = conf["agent_nums"][1] + + self.observation_space = [ + spaces.Box(low=-np.inf, high=-np.inf, shape=(288,), dtype=np.float32) + ] + self.share_observation_space = [ + spaces.Box(low=-np.inf, high=+np.inf, shape=(288,), dtype=np.float32) + ] + # self.action_space = [Discrete(4) for _ in range(self.n_player)] + 
self.action_space = Discrete(4) + self.save_internal = conf["save_interval"] + self.save_path = conf["save_path"] + self.episode = 0 + self.fig, self.ax = None, None + if render_mode in ["human", "rgb_array"]: + self.need_render = True + if render_mode == "human": + plt.ion() + self.fig, self.ax = plt.subplots() + else: + self.need_render = False + self.render_mode = render_mode + self.img_list = None + self.render_img = None + + self.init_colors = colors + self.colors = None + + def seed(self, seed=None): + if seed is None: + np.random.seed(0) + random.seed(0) + else: + np.random.seed(seed) + random.seed(seed) + + def check_win(self): + flg = self.won.index(max(self.won)) + 2 + return flg + + def get_grid_observation(self, current_state, player_id, info_before): + return current_state + + def get_dict_observation(self, current_state, player_id, info_before): + key_info = {1: self.beans_position} + for i in range(self.n_player): + snake = self.players[i] + key_info[snake.player_id] = snake.segments + # key_info['state_map'] = current_state + key_info["board_width"] = self.board_width + key_info["board_height"] = self.board_height + key_info["last_direction"] = ( + info_before.get("directions") if isinstance(info_before, dict) else None + ) + key_info["controlled_snake_index"] = player_id + + return key_info + + def set_action_space(self): + action_space = [[Discrete(4)] for _ in range(self.n_player)] + return action_space + + def reset(self): + if self.need_render: + self.img_list = [] + self.render_img = None + if self.colors is None: + self.colors = ( + self.init_colors + + generate_color(self.cell_size - len(self.init_colors) + 1) + if self.init_colors is not None + else generate_color(self.cell_size) + ) + self.step_cnt = 1 + self.snakes_position = ( + {} + ) # 格式类似于{1: [[3, 1], [4, 3], [1, 2], [0, 6], [3, 3]], 2: [[3, 0], [3, 7], [3, 6]], 3: [[2, 7], [1, 7], [0, 7]]} + self.players = [] + self.cur_bean_num = 0 + self.beans_position = [] + self.current_state = self.init_state() + self.all_observes = self.get_all_observes() + self.terminate_flg = False + + self.episode += 1 + + # available actions + left_avail_actions = np.ones([self.num_agents, self.action_dim]) + right_avail_actions = np.ones([self.num_enemys, self.action_dim]) + avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0) + # process obs + + info = {"action_mask": avail_actions} + self.inner_render() + return self.all_observes, info + + def step(self, joint_action): + if np.array(joint_action).shape == (2,): + joint_action = convert_to_onehot(joint_action) + + joint_action = np.expand_dims(joint_action, 1) + all_observes, info_after = self.get_next_state(joint_action) + done = self.is_terminal() + reward = self.get_reward(joint_action) + + left_avail_actions = np.ones([self.num_agents, self.action_dim]) + right_avail_actions = np.ones([self.num_enemys, self.action_dim]) + avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0) + + raw_obs = all_observes[0] + obs = self.raw2vec(raw_obs) + + rewards = np.expand_dims(np.array(reward), axis=1) + + dones = [done] * self.n_player + infos = info_after + + infos.update({"action_mask": avail_actions}) + self.inner_render() + return self.all_observes, rewards, dones, infos + + # obs: 0-空白 1-豆子 2-我方蛇头 3-我方蛇身 4-敌方蛇头 5-敌方蛇身 + + def inner_render(self): + if not self.need_render: + return + img = self.render_board() + self.render_img = img + if self.render_mode == "human": + self.ax.imshow(img, cmap="gray") + plt.draw() + plt.pause(0.1) + + def 
render(self): + return self.render_img + + def raw2vec(self, raw_obs): + control_index = raw_obs["controlled_snake_index"] + width = raw_obs["board_width"] + height = raw_obs["board_height"] + beans = raw_obs[1] + ally_pos = raw_obs[control_index] + enemy_pos = raw_obs[5 - control_index] + + obs = np.zeros(width * height * self.n_player, dtype=int) + ally_head_h, ally_head_w = ally_pos[0] + enemy_head_h, enemy_head_w = enemy_pos[0] + obs[ally_head_h * width + ally_head_w] = 2 + obs[height * width + ally_head_h * width + ally_head_w] = 4 + obs[enemy_head_h * width + enemy_head_w] = 4 + obs[height * width + enemy_head_h * width + enemy_head_w] = 2 + + for bean in beans: + h, w = bean + obs[h * width + w] = 1 + obs[height * width + h * width + w] = 1 + + for p in ally_pos[1:]: + h, w = p + obs[h * width + w] = 3 + obs[height * width + h * width + w] = 5 + + for p in enemy_pos[1:]: + h, w = p + obs[h * width + w] = 5 + obs[height * width + h * width + w] = 3 + + obs_ = np.array([]) + for i in obs: + obs_ = np.concatenate([obs_, np.eye(6)[i]]) + obs_ = obs_.reshape(-1, width * height * 6) + + return obs_ + + def init_state(self): + for i in range(self.n_player): + s = Snake(i + 2, self.board_width, self.board_height, self.init_len) + s_len = 1 + while s_len < self.init_len: + if s_len == 1 and i > 0: + origin_hit = self.is_hit(s.headPos, self.snakes_position) + else: + origin_hit = 0 + cur_head = s.move_and_add(self.snakes_position) + cur_hit = self.is_hit(cur_head, self.snakes_position) or self.is_hit( + cur_head, {i: s.segments[1:]} + ) + if origin_hit or cur_hit: + x = random.randrange(0, self.board_height) + y = random.randrange(0, self.board_width) + s.headPos = [x, y] + s.segments = [s.headPos] + s.direction = random.choice(self.actions) + s_len = 1 + else: + s_len += 1 + self.snakes_position[s.player_id] = s.segments + self.players.append(s) + + self.generate_beans() + self.init_info = { + "snakes_position": [ + list(v) + for k, v in sorted( + self.snakes_position.items(), key=lambda item: item[0] + ) + ], + "beans_position": list(self.beans_position), + } + directs = [] + for i in range(len(self.players)): + s = self.players[i] + directs.append(self.actions_name[s.direction]) + self.init_info["directions"] = directs + + return self.update_state() + + def update_state(self): + next_state = [ + [[0] * self.cell_dim for _ in range(self.board_width)] + for _ in range(self.board_height) + ] + for i in range(self.n_player): + snake = self.players[i] + for pos in snake.segments: + next_state[pos[0]][pos[1]][0] = i + 2 + + for pos in self.beans_position: + next_state[pos[0]][pos[1]][0] = 1 + + return next_state + + def step_before_info(self, info=""): + directs = [] + for i in range(len(self.players)): + s = self.players[i] + directs.append(self.actions_name[s.direction]) + info = {"directions": directs} + + return info + + def is_hit(self, cur_head, snakes_position): + is_hit = False + for k, v in snakes_position.items(): + for pos in v: + if cur_head == pos: + is_hit = True + # print("hit:", cur_head, snakes_position) + break + if is_hit: + break + + return is_hit + + def generate_beans(self): + all_valid_positions = set( + itertools.product(range(0, self.board_height), range(0, self.board_width)) + ) + all_valid_positions = all_valid_positions - set(map(tuple, self.beans_position)) + for positions in self.snakes_position.values(): + all_valid_positions = all_valid_positions - set(map(tuple, positions)) + + left_bean_num = self.n_beans - self.cur_bean_num + all_valid_positions = 
np.array(list(all_valid_positions)) + left_valid_positions = len(all_valid_positions) + + new_bean_num = ( + left_bean_num + if left_valid_positions > left_bean_num + else left_valid_positions + ) + + if left_valid_positions > 0: + new_bean_positions_idx = np.random.choice( + left_valid_positions, size=new_bean_num, replace=False + ) + new_bean_positions = all_valid_positions[new_bean_positions_idx] + else: + new_bean_positions = [] + + for new_bean_pos in new_bean_positions: + self.beans_position.append(list(new_bean_pos)) + self.cur_bean_num += 1 + + def get_all_observes(self, before_info=""): + self.all_observes = [] + for i in range(self.n_player): + each_obs = self.get_dict_observation(self.current_state, i + 2, before_info) + self.all_observes.append(each_obs) + + return self.all_observes + + def get_next_state(self, all_action): + before_info = self.step_before_info() + not_valid = self.is_not_valid_action(all_action) + if not not_valid: + # 各玩家行动 + # print("current_state", self.current_state) + eat_snakes = [0] * self.n_player + others_reward = [ + 0 + ] * self.n_player # 记录对方获得的奖励,因为是零和博弈,所以敌人获得了多少奖励,我方就要减去多少奖励 + for i in range(self.n_player): # 判断是否吃到了豆子 + snake = self.players[i] + act = self.actions[np.argmax(all_action[i][0])] + # print(snake.player_id, "此轮的动作为:", self.actions_name[act]) + snake.change_direction(act) + snake.move_and_add(self.snakes_position) # 更新snake.segment + if self.be_eaten(snake.headPos): # @yanxue + snake.snake_reward = 1 + eat_snakes[i] = 1 + else: + snake.snake_reward = 0 + snake.pop() + # print(snake.player_id, snake.segments) # @yanxue + snake_position = [[-1] * self.board_width for _ in range(self.board_height)] + re_generatelist = [0] * self.n_player + for i in range(self.n_player): # 判断是否相撞 + snake = self.players[i] + segment = snake.segments + for j in range(len(segment)): + x = segment[j][0] + y = segment[j][1] + if snake_position[x][y] != -1: + if j == 0: # 撞头 + re_generatelist[i] = 1 + compare_snake = self.players[snake_position[x][y]] + if [x, y] == compare_snake.segments[0]: # 两头相撞won + re_generatelist[snake_position[x][y]] = 1 + else: + snake_position[x][y] = i + for i in range(self.n_player): + snake = self.players[i] + if re_generatelist[i] == 1: + if eat_snakes[i] == 1: + snake.snake_reward = ( + self.init_len - len(snake.segments) + 1 + ) # 身体越长,惩罚越大 + else: + snake.snake_reward = self.init_len - len(snake.segments) + snake.segments = [] + for i in range(self.num_agents): + others_reward[self.num_agents :] = [ + others_reward[j + self.num_agents] + self.players[i].snake_reward + for j in range(self.num_enemys) + ] + others_reward[self.num_agents :] = [ + others_reward[j + self.num_agents] // self.num_agents + for j in range(self.num_enemys) + ] + for i in range(self.num_enemys): + others_reward[: self.num_agents] = [ + others_reward[j] + self.players[i + self.num_agents].snake_reward + for j in range(self.num_agents) + ] + others_reward[: self.num_agents] = [ + others_reward[j] // self.num_enemys for j in range(self.num_agents) + ] + for i in range(self.n_player): + self.players[i].snake_reward -= others_reward[i] + for i in range(self.n_player): + snake = self.players[i] + if re_generatelist[i] == 1: + snake = self.clear_or_regenerate(snake) + self.snakes_position[snake.player_id] = snake.segments + snake.score = snake.get_score() + # yanxue add + # 更新状态 + self.generate_beans() + + next_state = self.update_state() + self.current_state = next_state + self.step_cnt += 1 + + self.won = [0] * self.n_player + + for i in 
+            for i in range(self.n_player):
+                s = self.players[i]
+                self.won[i] = s.score
+            info_after = {}
+            info_after["snakes_position"] = [
+                list(v)
+                for k, v in sorted(
+                    self.snakes_position.items(), key=lambda item: item[0]
+                )
+            ]
+            info_after["beans_position"] = list(self.beans_position)
+            info_after["hit"] = re_generatelist
+            info_after["score"] = self.won
+            self.all_observes = self.get_all_observes(before_info)
+
+        return self.all_observes, info_after
+
+    def clear_or_regenerate(self, snake):
+        direct_x = [0, 1, -1, 0]
+        direct_y = [1, 0, 0, -1]
+        snake.segments = []
+        snake.score = 0
+        grid = self.get_render_data(self.update_state())
+
+        def can_regenerate():
+            for x in range(self.board_height):
+                for y in range(self.board_width):
+                    if grid[x][y] == 0:
+                        # flood-fill the empty region around (x, y); the
+                        # neighbourhood wraps around, matching the move rules
+                        q = [[x, y]]
+                        seg = []
+                        while q:
+                            cur = q.pop(0)
+                            if cur not in seg:
+                                seg.append(cur)
+                            for i in range(4):
+                                nx = (direct_x[i] + cur[0]) % self.board_height
+                                ny = (direct_y[i] + cur[1]) % self.board_width
+                                if grid[nx][ny] == 0 and [nx, ny] not in q:
+                                    grid[nx][ny] = 1
+                                    q.append([nx, ny])
+                        if len(seg) == self.init_len:
+                            # print("regenerate")
+                            if len(seg) < 3:
+                                snake.direction = random.choice(self.actions)
+                            elif len(seg) == 3:
+                                mid = (
+                                    [seg[1][0], seg[2][1]],
+                                    [seg[2][0], seg[1][1]],
+                                )
+                                if seg[0] in mid:
+                                    seg[0], seg[1] = seg[1], seg[0]
+                            snake.segments = seg
+                            snake.headPos = seg[0]
+                            if seg[0][0] == seg[1][0]:
+                                if seg[0][1] > seg[1][1]:
+                                    snake.direction = 1  # right
+                                else:
+                                    snake.direction = -1  # left
+                            elif seg[0][1] == seg[1][1]:
+                                if seg[0][0] > seg[1][0]:
+                                    snake.direction = 2  # down
+                                else:
+                                    snake.direction = -2  # up
+                            # print("re head", snake.headPos)
+                            # print("re snakes segments", snake.segments)
+                            return True
+            # print("clear")
+            return False
+
+        flg = can_regenerate()
+        if not flg:
+            self.terminate_flg = True
+        return snake
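+    # Note on clear_or_regenerate above: it flood-fills empty cells (with
+    # wrap-around, matching how snakes move) looking for a connected free
+    # region of exactly init_len cells to respawn into; if no such region
+    # exists, terminate_flg ends the episode.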
+    def is_not_valid_action(self, all_action):
+        not_valid = 0
+        if len(all_action) != self.n_player:
+            raise Exception("The joint action has the wrong dimension!", len(all_action))
+
+        for i in range(self.n_player):
+            if len(all_action[i][0]) != 4:
+                raise Exception(
+                    "Player %d's joint action has the wrong dimension!" % i,
+                    all_action[i],
+                )
+        return not_valid
+
+    def get_reward(self, all_action):
+        r = [0] * self.n_player
+        for i in range(self.n_player):
+            r[i] = self.players[i].snake_reward
+            self.n_return[i] += r[i]
+        # print("score:", self.won)
+        return r
+
+    def is_terminal(self):
+        all_member = self.n_beans
+        # all_member = len(self.beans_position)
+        for s in self.players:
+            all_member += len(s.segments)
+        is_done = (
+            self.step_cnt > self.max_step
+            or all_member > self.board_height * self.board_width
+        )
+
+        return is_done or self.terminate_flg
+
+    def encode(self, actions):
+        joint_action = self.init_action_space()
+        if len(actions) != self.n_player:
+            raise Exception("The action input has the wrong dimension!", len(actions))
+        for i in range(self.n_player):
+            joint_action[i][0][int(actions[i])] = 1
+        return joint_action
+
+    def get_terminal_actions(self):
+        print(
+            "Please enter the action direction [0-3] (up/down/left/right) for"
+            " %d players, separated by spaces:" % self.n_player
+        )
+        cur = input()
+        actions = cur.split(" ")
+        return self.encode(actions)
+
+    def be_eaten(self, snake_pos):
+        for bean in self.beans_position:
+            if snake_pos[0] == bean[0] and snake_pos[1] == bean[1]:
+                self.beans_position.remove(bean)
+                self.cur_bean_num -= 1
+                return True
+        return False
+
+    def get_action_dim(self):
+        action_dim = 1
+        for i in range(len(self.joint_action_space[0])):
+            action_dim *= self.joint_action_space[0][i].n
+
+        return action_dim
+
+    def draw_board(self):
+        cols = [chr(i) for i in range(65, 65 + self.board_width)]
+        s = ", ".join(cols)
+        print(" ", s)
+        for i in range(self.board_height):
+            print(chr(i + 65), self.current_state[i])
+
+    @staticmethod
+    def _render_board(state, board, colors, unit, fix, extra_info):
+        im = GridGame._render_board(state, board, colors, unit, fix)
+        draw = ImageDraw.Draw(im)
+        fnt = ImageFont.load_default()
+        for i, pos in zip(count(1), extra_info):
+            x, y = pos
+            draw.text(
+                ((y + 1 / 4) * unit, (x + 1 / 4) * unit),
+                "#{}".format(i),
+                font=fnt,
+                fill=(0, 0, 0),
+            )
+
+        return im
+
+    def render_board(self):
+        extra_info = [tuple(x.headPos) for x in self.players]
+        im_data = np.array(
+            SnakeEatBeans._render_board(
+                self.get_render_data(self.current_state),
+                self.grid,
+                self.colors,
+                self.grid_unit,
+                self.grid_unit_fix,
+                extra_info,
+            )
+        )
+        return im_data
+
+    @staticmethod
+    def parse_extra_info(data):
+        # note: eval is only ever fed replay data produced by this env
+        if isinstance(data, str):
+            d = eval(data)["snakes_position"]
+        else:
+            d = data["snakes_position"]
+
+        return [i[0] for i in d]
+
+    def close(self):
+        if self.render_mode == "human":
+            plt.close(self.fig)
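+
+# A small standalone sketch (hypothetical helper, not used by the classes in
+# this file): the action encoding pairs up/down as -2/2 and left/right as
+# -1/1, so an action reverses the current direction exactly when the two
+# values sum to zero -- the same test Snake.change_direction applies.
+def _is_reverse(action: int, direction: int) -> bool:
+    """Return True if `action` would turn the snake back onto itself."""
+    return action + direction == 0
+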
+
+
+class Snake:
+    def __init__(self, player_id, board_width, board_height, init_len):
+        self.actions = [-2, 2, -1, 1]
+        self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
+        # directions [-2, 2, -1, 1] stand for [up, down, left, right]
+        self.direction = random.choice(self.actions)
+        self.board_width = board_width
+        self.board_height = board_height
+        x = random.randrange(0, board_height)
+        y = random.randrange(0, board_width)
+        self.segments = [[x, y]]
+        self.headPos = self.segments[0]
+        self.player_id = player_id
+        self.score = 0
+        self.snake_reward = 0
+        self.init_len = init_len
+
+    def get_score(self):
+        return len(self.segments) - self.init_len
+
+    def change_direction(self, act):
+        if act + self.direction != 0:
+            self.direction = act
+        else:
+            # the action would reverse the snake onto itself; resample a
+            # legal direction instead
+            n_direct = random.choice(self.actions)
+            while n_direct + self.direction == 0:
+                n_direct = random.choice(self.actions)
+            self.direction = n_direct
+
+    # the board wraps around at the edges
+    def update_position(self, position):
+        position[0] %= self.board_height
+        position[1] %= self.board_width
+        return position
+
+    def move_and_add(self, snakes_position):
+        cur_head = list(self.headPos)
+        # move the head coordinate one cell along the current direction
+        if self.direction == 1:  # right
+            cur_head[1] += 1
+        if self.direction == -1:  # left
+            cur_head[1] -= 1
+        if self.direction == -2:  # up
+            cur_head[0] -= 1
+        if self.direction == 2:  # down
+            cur_head[0] += 1
+
+        cur_head = self.update_position(cur_head)
+
+        self.segments.insert(0, cur_head)
+        self.headPos = self.segments[0]
+        return cur_head
+
+    def pop(self):
+        self.segments.pop()  # drop one cell from the tail
diff --git a/openrl/envs/snake/snake_3v3.py b/openrl/envs/snake/snake_3v3.py
new file mode 100644
index 00000000..78d787ef
--- /dev/null
+++ b/openrl/envs/snake/snake_3v3.py
@@ -0,0 +1,854 @@
+# -*- coding:utf-8 -*-
+# Author: zruizhi
+# Created: 2020/7/30 17:24
+# Description:
+import copy
+import itertools
+import random
+import time
+from itertools import count
+
+import numpy as np
+from gym import Env, spaces
+from PIL import Image, ImageDraw, ImageFont
+
+from .common import Board, HiddenPrints, SnakePos  # TODO: resolve the Snake class name clash
+from .discrete import Discrete
+from .gridgame import GridGame
+from .observation import *
+
+
+class SnakeEatBeans(GridGame, GridObservation, DictObservation):
+    def __init__(self, all_args, env_id):
+        self.all_args = all_args
+        conf = {
+            "class_literal": "SnakeEatBeans",
+            "n_player": 6,
+            "board_width": 20,
+            "board_height": 10,
+            "channels": 15,
+            "cell_range": 8,
+            "n_beans": 5,
+            "max_step": 200,
+            "game_name": "snakes",
+            "is_obs_continuous": False,
+            "is_act_continuous": False,
+            "agent_nums": [3, 3],
+            "obs_type": ["dict", "dict"],
+            "save_interval": 100,
+            "save_path": "../../replay/snake_3v3/replay_{}.gif",
+        }
+        self.terminate_flg = False
+        colors = conf.get("colors", [(255, 255, 255), (255, 140, 0)])
+        super(SnakeEatBeans, self).__init__(conf, colors)
+        # 0: empty, 1: bean, 2..n_player+1: the players' snakes
+        self.n_cell_type = self.n_player + 2
+        self.step_cnt = 1
+        self.n_beans = int(conf["n_beans"])
+        # directions [-2, 2, -1, 1] stand for [up, down, left, right]
+        self.actions = [-2, 2, -1, 1]
+        self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
+        self.snakes_position = {}
+        self.players = []
+        self.cur_bean_num = 0
+        self.beans_position = []
+        # 1 <= init_len <= 3
+        self.init_len = 3
+        self.current_state = self.init_state()
+        self.all_observes = self.get_all_observes()
+        if self.n_player * self.init_len > self.board_height * self.board_width:
+            raise Exception(
+                "Too many players (%d) for a %d x %d board"
+                % (self.n_player, self.board_width, self.board_height)
+            )
+
+        self.input_dimension = self.board_width * self.board_height
+        self.action_dim = self.get_action_dim()
+        self.channels = conf["channels"]
+
+        self.num_agents = conf["agent_nums"][0]
+        self.num_enemys = conf["agent_nums"][1]
+
+        self.observation_space = [
+            spaces.Box(
+                low=-np.inf,
+                high=np.inf,  # fix: was -np.inf, which makes the box empty
+                shape=(self.channels, self.board_width, self.board_height),
+                dtype=np.float32,
+            )
+        ]
+        self.share_observation_space = [
+            spaces.Box(
+                low=-np.inf,
+                high=+np.inf,
+                shape=(self.channels, self.board_width, self.board_height),
+                dtype=np.float32,
+            )
+        ]
+        self.action_space = [Discrete(4) for _ in range(self.n_player)]
+        self.save_interval = conf["save_interval"]
+        self.save_path = conf["save_path"]
+        self.episode 
= 0 + self.render = all_args.save_replay + self.img_list = [] + self.env_id = env_id + + def seed(self, seed=None): + if seed is None: + np.random.seed(1) + else: + np.random.seed(seed) + + def check_win(self): + flg = self.won.index(max(self.won)) + 2 + return flg + + def get_grid_observation(self, current_state, player_id, info_before): + return current_state + + def get_dict_observation(self, current_state, player_id, info_before): + key_info = {1: self.beans_position} + for i in range(self.n_player): + snake = self.players[i] + key_info[snake.player_id] = snake.segments + # key_info['state_map'] = current_state + key_info["board_width"] = self.board_width + key_info["board_height"] = self.board_height + key_info["last_direction"] = ( + info_before.get("directions") if isinstance(info_before, dict) else None + ) + key_info["controlled_snake_index"] = player_id + + return key_info + + def set_action_space(self): + action_space = [[Discrete(4)] for _ in range(self.n_player)] + return action_space + + def reset(self): + self.step_cnt = 1 + self.snakes_position = ( + {} + ) # 格式类似于{1: [[3, 1], [4, 3], [1, 2], [0, 6], [3, 3]], 2: [[3, 0], [3, 7], [3, 6]], 3: [[2, 7], [1, 7], [0, 7]]} + self.players = [] + self.cur_bean_num = 0 + self.beans_position = [] + self.current_state = self.init_state() + self.all_observes = self.get_all_observes() + self.terminate_flg = False + self.img_list = [] + self.episode += 1 + + # available actions + left_avail_actions = np.ones([self.num_agents, self.action_dim]) + right_avail_actions = np.ones([self.num_enemys, self.action_dim]) + avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0) + # process obs + board = [] + for i in range(self.n_player): + board.append([self.get_board(self.all_observes[i])]) + + board_ = np.concatenate(board) + obs = [] + for raw_obs in self.all_observes: + obs.append([self.raw2vec(raw_obs)]) + obs_ = np.concatenate(obs) + obs_ = np.concatenate((obs_, board_), axis=1) + + share_obs = np.repeat(np.expand_dims(obs_[0], axis=0), 6, 0) + + return obs_, share_obs, avail_actions # obs:(n_player, 288) + + # return self.all_observes + + def step(self, joint_action): + info_before = self.step_before_info() + joint_action = np.expand_dims(joint_action, 1) + all_observes, info_after = self.get_next_state(joint_action) + done = self.is_terminal() + reward = self.get_reward(joint_action) + left_avail_actions = np.ones([self.num_agents, self.action_dim]) + right_avail_actions = np.ones([self.num_enemys, self.action_dim]) + avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0) + + board = [] + for i in range(self.n_player): + board.append([self.get_board(all_observes[i])]) + + board_ = np.concatenate(board) + + obs = [] + + for raw_obs in all_observes: + obs.append([self.raw2vec(raw_obs)]) # obs:[[(14, 20, 10)], [], ..., []] + + obs_ = np.concatenate(obs) # (n_player, channels, width, height) + obs_ = np.concatenate((obs_, board_), axis=1) + + share_obs = np.repeat(np.expand_dims(obs_[0], axis=0), 6, 0) + + if done: + reward = self.get_final_reward(reward) + + rewards = np.expand_dims(np.array(reward), axis=1) + + dones = [done] * self.n_player + infos = [info_after] * self.n_player + + if self.render and self.episode % self.save_interval == 0 and self.env_id == 0: + img = self.render_board() + img_pil = Image.fromarray(img) + self.img_list.append(img_pil) + + if done: + self.img_list[0].save( + self.save_path.format(self.episode), + save_all=True, + append_images=self.img_list[1:], + duration=400, + ) 
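+                # Replay GIFs are only assembled on worker env_id == 0 and
+                # only every `save_interval`-th episode, so parallel rollout
+                # workers do not all write the same file.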
+ print("save replay gif to" + self.save_path.format(self.episode)) + + return obs_, share_obs, rewards, dones, infos, avail_actions + # return all_observes, reward, done, info_before, info_after + + # obs: 0 空白 1 豆子 2 我方蛇头 3 我方蛇身 4-5 友方蛇头 6-7 友方蛇身 8-10 敌方蛇头 11-13 敌方蛇身 + def raw2vec(self, raw_obs): + control_index = raw_obs["controlled_snake_index"] + width = raw_obs["board_width"] + height = raw_obs["board_height"] + beans = raw_obs[1] + pos = raw_obs[control_index] + + obs = np.zeros(width * height, dtype=int) + head_h, head_w = pos[0] + obs[head_h * width + head_w] = 2 + + for bean in beans: + h, w = bean + obs[h * width + w] = 1 + + for p in pos[1:]: + h, w = p + obs[h * width + w] = 3 + + if control_index == 2: + h1, w1 = raw_obs[3][0] + h2, w2 = raw_obs[4][0] + obs[h1 * width + w1] = 4 + obs[h2 * width + w2] = 5 + for p in raw_obs[3][1:]: + h, w = p + obs[h * width + w] = 6 + for p in raw_obs[4][1:]: + h, w = p + obs[h * width + w] = 7 + for i in range(self.num_agents + 2, self.n_player + 2): + h, w = raw_obs[i][0] + obs[h * width + w] = i + 3 + for p in raw_obs[i][1:]: + h, w = p + obs[h * width + w] = i + 6 + elif control_index == 3: + h1, w1 = raw_obs[2][0] + h2, w2 = raw_obs[4][0] + obs[h1 * width + w1] = 4 + obs[h2 * width + w2] = 5 + for p in raw_obs[2][1:]: + h, w = p + obs[h * width + w] = 6 + for p in raw_obs[4][1:]: + h, w = p + obs[h * width + w] = 7 + for i in range(self.num_agents + 2, self.n_player + 2): + h, w = raw_obs[i][0] + obs[h * width + w] = i + 3 + for p in raw_obs[i][1:]: + h, w = p + obs[h * width + w] = i + 6 + elif control_index == 4: + h1, w1 = raw_obs[2][0] + h2, w2 = raw_obs[3][0] + obs[h1 * width + w1] = 4 + obs[h2 * width + w2] = 5 + for p in raw_obs[2][1:]: + h, w = p + obs[h * width + w] = 6 + for p in raw_obs[3][1:]: + h, w = p + obs[h * width + w] = 7 + for i in range(self.num_agents + 2, self.n_player + 2): + h, w = raw_obs[i][0] + obs[h * width + w] = i + 3 + for p in raw_obs[i][1:]: + h, w = p + obs[h * width + w] = i + 6 + elif control_index == 5: + h1, w1 = raw_obs[6][0] + h2, w2 = raw_obs[7][0] + obs[h1 * width + w1] = 4 + obs[h2 * width + w2] = 5 + for p in raw_obs[6][1:]: + h, w = p + obs[h * width + w] = 6 + for p in raw_obs[7][1:]: + h, w = p + obs[h * width + w] = 7 + for i in range(2, self.num_agents + 2): + h, w = raw_obs[i][0] + obs[h * width + w] = i + 6 + for p in raw_obs[i][1:]: + h, w = p + obs[h * width + w] = i + 9 + elif control_index == 6: + h1, w1 = raw_obs[5][0] + h2, w2 = raw_obs[7][0] + obs[h1 * width + w1] = 4 + obs[h2 * width + w2] = 5 + for p in raw_obs[5][1:]: + h, w = p + obs[h * width + w] = 6 + for p in raw_obs[7][1:]: + h, w = p + obs[h * width + w] = 7 + for i in range(2, self.num_agents + 2): + h, w = raw_obs[i][0] + obs[h * width + w] = i + 6 + for p in raw_obs[i][1:]: + h, w = p + obs[h * width + w] = i + 9 + else: + h1, w1 = raw_obs[5][0] + h2, w2 = raw_obs[6][0] + obs[h1 * width + w1] = 4 + obs[h2 * width + w2] = 5 + for p in raw_obs[5][1:]: + h, w = p + obs[h * width + w] = 6 + for p in raw_obs[6][1:]: + h, w = p + obs[h * width + w] = 7 + for i in range(2, self.num_agents + 2): + h, w = raw_obs[i][0] + obs[h * width + w] = i + 6 + for p in raw_obs[i][1:]: + h, w = p + obs[h * width + w] = i + 9 + + obs_ = np.zeros(width * height * (self.channels - 1), dtype=int) + for i in range(width * height): + obs_[i * (self.channels - 1) + obs[i]] = ( + 1 # channels的最后一维是territory matrix, 此处不生成, 要减去 + ) + obs_ = obs_.reshape( + height, width, (self.channels - 1) + ) # (height, width, channels-1 ) + obs_ = 
obs_.transpose((2, 1, 0)) + + return obs_ + + def get_board(self, observation_list): + observation_len = len(observation_list.keys()) + teams = None + teams = [[0, 1, 2], [3, 4, 5]] # 3v3 + teams_count = len(teams) + snakes_count = sum([len(_) for _ in teams]) + + # read observation + obs = observation_list.copy() + board_height = obs["board_height"] # 10 + board_width = obs["board_width"] # 20 + # print("obs['controlled_snake_index'] is ", obs['controlled_snake_index']) + ctrl_agent_index = obs["controlled_snake_index"] - 2 # 0, 1, 2, 3, 4, 5 + # last_directions = obs['last_direction'] # ['up', 'left', 'down', 'left', 'left', 'left'] + beans_positions = obs[1] # e.g.[[7, 15], [4, 14], [5, 12], [4, 12], [5, 7]] + snakes = { + key - 2: SnakePos(obs[key], board_height, board_width, beans_positions) + for key in obs.keys() & {_ + 2 for _ in range(snakes_count)} + } # &: intersection + team_indexes = [_ for _ in teams if ctrl_agent_index in _][0] + + init_board = Board(board_height, board_width, snakes, beans_positions, teams) + bd = copy.deepcopy(init_board) + + with HiddenPrints(): + while not all( + _ == [] for _ in bd.open.values() + ): # loop until all values in open are empty list + bd.step() + + board = np.array(bd.board).transpose() + board = np.expand_dims(board, axis=0) + return board + + def init_state(self): + for i in range(self.n_player): + s = Snake(i + 2, self.board_width, self.board_height, self.init_len) + s_len = 1 + while s_len < self.init_len: + if s_len == 1 and i > 0: + origin_hit = self.is_hit(s.headPos, self.snakes_position) + else: + origin_hit = 0 + cur_head = s.move_and_add(self.snakes_position) + cur_hit = self.is_hit(cur_head, self.snakes_position) or self.is_hit( + cur_head, {i: s.segments[1:]} + ) + if origin_hit or cur_hit: + x = random.randrange(0, self.board_height) + y = random.randrange(0, self.board_width) + s.headPos = [x, y] + s.segments = [s.headPos] + s.direction = random.choice(self.actions) + s_len = 1 + else: + s_len += 1 + self.snakes_position[s.player_id] = s.segments + self.players.append(s) + + self.generate_beans() + self.init_info = { + "snakes_position": [ + list(v) + for k, v in sorted( + self.snakes_position.items(), key=lambda item: item[0] + ) + ], + "beans_position": list(self.beans_position), + } + directs = [] + for i in range(len(self.players)): + s = self.players[i] + directs.append(self.actions_name[s.direction]) + self.init_info["directions"] = directs + + return self.update_state() + + def update_state(self): + next_state = [ + [[0] * self.cell_dim for _ in range(self.board_width)] + for _ in range(self.board_height) + ] + for i in range(self.n_player): + snake = self.players[i] + for pos in snake.segments: + next_state[pos[0]][pos[1]][0] = i + 2 + + for pos in self.beans_position: + next_state[pos[0]][pos[1]][0] = 1 + + return next_state + + def step_before_info(self, info=""): + directs = [] + for i in range(len(self.players)): + s = self.players[i] + directs.append(self.actions_name[s.direction]) + info = {"directions": directs} + + return info + + def is_hit(self, cur_head, snakes_position): + is_hit = False + for k, v in snakes_position.items(): + for pos in v: + if cur_head == pos: + is_hit = True + # print("hit:", cur_head, snakes_position) + break + if is_hit: + break + + return is_hit + + def generate_beans(self): + all_valid_positions = set( + itertools.product(range(0, self.board_height), range(0, self.board_width)) + ) + all_valid_positions = all_valid_positions - set(map(tuple, self.beans_position)) + for positions in 
self.snakes_position.values():
+            all_valid_positions = all_valid_positions - set(map(tuple, positions))
+
+        left_bean_num = self.n_beans - self.cur_bean_num
+        all_valid_positions = np.array(list(all_valid_positions))
+        left_valid_positions = len(all_valid_positions)
+
+        new_bean_num = (
+            left_bean_num
+            if left_valid_positions > left_bean_num
+            else left_valid_positions
+        )
+
+        if left_valid_positions > 0:
+            new_bean_positions_idx = np.random.choice(
+                left_valid_positions, size=new_bean_num, replace=False
+            )
+            new_bean_positions = all_valid_positions[new_bean_positions_idx]
+        else:
+            new_bean_positions = []
+
+        for new_bean_pos in new_bean_positions:
+            self.beans_position.append(list(new_bean_pos))
+            self.cur_bean_num += 1
+
+    def get_all_observes(self, before_info=""):
+        self.all_observes = []
+        for i in range(self.n_player):
+            each_obs = self.get_dict_observation(self.current_state, i + 2, before_info)
+            self.all_observes.append(each_obs)
+
+        return self.all_observes
+
+    def get_next_state(self, all_action):
+        before_info = self.step_before_info()
+        not_valid = self.is_not_valid_action(all_action)
+        if not not_valid:
+            # each player moves
+            eat_snakes = [0] * self.n_player
+            ally_reward = 0
+            enemy_reward = 0
+            for i in range(self.n_player):  # check whether a bean was eaten
+                snake = self.players[i]
+                act = self.actions[np.argmax(all_action[i][0])]
+                snake.change_direction(act)
+                snake.move_and_add(self.snakes_position)  # updates snake.segments
+                if self.be_eaten(snake.headPos):  # @yanxue
+                    snake.snake_reward = 1
+                    eat_snakes[i] = 1
+                else:
+                    snake.snake_reward = 0
+                    snake.pop()
+            snake_position = [[-1] * self.board_width for _ in range(self.board_height)]
+            re_generatelist = [0] * self.n_player
+            for i in range(self.n_player):  # check for collisions
+                snake = self.players[i]
+                segment = snake.segments
+                for j in range(len(segment)):
+                    x = segment[j][0]
+                    y = segment[j][1]
+                    if snake_position[x][y] != -1:
+                        if j == 0:  # head-on collision
+                            re_generatelist[i] = 1
+                            compare_snake = self.players[snake_position[x][y]]
+                            if [x, y] == compare_snake.segments[0]:  # both heads collide
+                                re_generatelist[snake_position[x][y]] = 1
+                    else:
+                        snake_position[x][y] = i
+            for i in range(self.n_player):
+                snake = self.players[i]
+                if re_generatelist[i] == 1:
+                    if eat_snakes[i] == 1:
+                        # the longer the body, the larger the penalty
+                        snake.snake_reward = self.init_len - len(snake.segments) + 1
+                    else:
+                        snake.snake_reward = self.init_len - len(snake.segments)
+                    snake.segments = []
+
+            for i in range(self.num_agents):
+                ally_reward += self.players[i].snake_reward
+            for i in range(self.num_enemys):
+                enemy_reward += self.players[i + self.num_agents].snake_reward
+            alpha = 0.8
+            for i in range(self.num_agents):
+                self.players[i].snake_reward = (
+                    self.players[i].snake_reward - enemy_reward / 3
+                ) * alpha + ally_reward / 3 * (1 - alpha)
+            for i in range(self.num_agents, self.n_player):
+                self.players[i].snake_reward = (
+                    self.players[i].snake_reward - ally_reward / 3
+                ) * alpha + enemy_reward / 3 * (1 - alpha)
+
+            for i in range(self.n_player):
+                snake = self.players[i]
+                if re_generatelist[i] == 1:
+                    snake = self.clear_or_regenerate(snake)
+                self.snakes_position[snake.player_id] = snake.segments
+                snake.score = snake.get_score()
+            # yanxue add
+            # refresh the board state
+            self.generate_beans()
+
+            next_state = self.update_state()
+            self.current_state = next_state
+            self.step_cnt += 1
+
+            self.won = [0] * self.n_player
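+            # Reading of the alpha mixing above (illustrative): with
+            # alpha = 0.8 each snake keeps 80% of (own reward minus the mean
+            # enemy reward) plus 20% of its own team's mean reward, which
+            # encourages cooperative play while staying roughly zero-sum.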
+            for i in range(self.n_player):
+                s = self.players[i]
+                self.won[i] = s.score
+            info_after = {}
+            info_after["snakes_position"] = [
+                list(v)
+                for k, v in sorted(
+                    self.snakes_position.items(), key=lambda item: item[0]
+                )
+            ]
+            info_after["beans_position"] = list(self.beans_position)
+            info_after["hit"] = re_generatelist
+            info_after["score"] = self.won
+            self.all_observes = self.get_all_observes(before_info)
+
+        return self.all_observes, info_after
+
+    def clear_or_regenerate(self, snake):
+        direct_x = [0, 1, -1, 0]
+        direct_y = [1, 0, 0, -1]
+        snake.segments = []
+        snake.score = 0
+        grid = self.get_render_data(self.update_state())
+
+        def can_regenerate():
+            for x in range(self.board_height):
+                for y in range(self.board_width):
+                    if grid[x][y] == 0:
+                        # flood-fill the empty region around (x, y); the
+                        # neighbourhood wraps around, matching the move rules
+                        q = [[x, y]]
+                        seg = []
+                        while q:
+                            cur = q.pop(0)
+                            if cur not in seg:
+                                seg.append(cur)
+                            for i in range(4):
+                                nx = (direct_x[i] + cur[0]) % self.board_height
+                                ny = (direct_y[i] + cur[1]) % self.board_width
+                                if grid[nx][ny] == 0 and [nx, ny] not in q:
+                                    grid[nx][ny] = 1
+                                    q.append([nx, ny])
+                        if len(seg) == self.init_len:
+                            # print("regenerate")
+                            if len(seg) < 3:
+                                snake.direction = random.choice(self.actions)
+                            elif len(seg) == 3:
+                                mid = (
+                                    [seg[1][0], seg[2][1]],
+                                    [seg[2][0], seg[1][1]],
+                                )
+                                if seg[0] in mid:
+                                    seg[0], seg[1] = seg[1], seg[0]
+                            snake.segments = seg
+                            snake.headPos = seg[0]
+                            if seg[0][0] == seg[1][0]:
+                                if seg[0][1] > seg[1][1]:
+                                    snake.direction = 1  # right
+                                else:
+                                    snake.direction = -1  # left
+                            elif seg[0][1] == seg[1][1]:
+                                if seg[0][0] > seg[1][0]:
+                                    snake.direction = 2  # down
+                                else:
+                                    snake.direction = -2  # up
+                            return True
+            # print("clear")
+            return False
+
+        flg = can_regenerate()
+        if not flg:
+            self.terminate_flg = True
+        return snake
+
+    def is_not_valid_action(self, all_action):
+        not_valid = 0
+        if len(all_action) != self.n_player:
+            raise Exception("The joint action has the wrong dimension!", len(all_action))
+
+        for i in range(self.n_player):
+            if len(all_action[i][0]) != 4:
+                raise Exception("Player %d's joint action has the wrong dimension!" 
% i, all_action[i]) + return not_valid + + def get_reward(self, all_action): + r = [0] * self.n_player + for i in range(self.n_player): + r[i] = self.players[i].snake_reward + self.n_return[i] += r[i] + # print("score:", self.won) + return r + + def get_final_reward(self, reward): + ally_reward = reward[0] + reward[1] + reward[2] + enemy_reward = reward[3] + reward[4] + reward[5] + if ally_reward > enemy_reward: + reward[0] += 10 + reward[1] += 10 + reward[2] += 10 + reward[3] -= 10 + reward[4] -= 10 + reward[5] -= 10 + elif ally_reward < enemy_reward: + reward[3] += 10 + reward[4] += 10 + reward[5] += 10 + reward[0] -= 10 + reward[1] -= 10 + reward[2] -= 10 + return reward + + def is_terminal(self): + all_member = self.n_beans + # all_member = len(self.beans_position) + for s in self.players: + all_member += len(s.segments) + is_done = ( + self.step_cnt > self.max_step + or all_member > self.board_height * self.board_width + ) + + return is_done or self.terminate_flg + + def encode(self, actions): + joint_action = self.init_action_space() + if len(actions) != self.n_player: + raise Exception("action输入维度不正确!", len(actions)) + for i in range(self.n_player): + joint_action[i][0][int(actions[i])] = 1 + return joint_action + + def get_terminal_actions(self): + print("请输入%d个玩家的动作方向[0-3](上下左右),空格隔开:" % self.n_player) + cur = input() + actions = cur.split(" ") + return self.encode(actions) + + def be_eaten(self, snake_pos): + for bean in self.beans_position: + if snake_pos[0] == bean[0] and snake_pos[1] == bean[1]: + self.beans_position.remove(bean) + self.cur_bean_num -= 1 + return True + return False + + def get_action_dim(self): + action_dim = 1 + for i in range(len(self.joint_action_space[0])): + action_dim *= self.joint_action_space[0][i].n + + return action_dim + + def draw_board(self): + cols = [chr(i) for i in range(65, 65 + self.board_width)] + s = ", ".join(cols) + print(" ", s) + for i in range(self.board_height): + # print(i) + print(chr(i + 65), self.current_state[i]) + + @staticmethod + def _render_board(state, board, colors, unit, fix, extra_info): + im = GridGame._render_board(state, board, colors, unit, fix) + draw = ImageDraw.Draw(im) + # fnt = ImageFont.truetype("Courier.dfont", 16) + fnt = ImageFont.load_default() + for i, pos in zip(count(1), extra_info): + x, y = pos + draw.text( + ((y + 1 / 4) * unit, (x + 1 / 4) * unit), + "#{}".format(i), + font=fnt, + fill=(0, 0, 0), + ) + + return im + + def render_board(self): + extra_info = [tuple(x.headPos) for x in self.players] + im_data = np.array( + SnakeEatBeans._render_board( + self.get_render_data(self.current_state), + self.grid, + self.colors, + self.grid_unit, + self.grid_unit_fix, + extra_info, + ) + ) + return im_data + + @staticmethod + def parse_extra_info(data): + # return eval(re.search(r'({.*})', data['info_after']).group(1)).values() + # d = (eval(eval(data)['snakes_position']).values()) + if isinstance(data, str): + d = eval(data)["snakes_position"] + else: + d = data["snakes_position"] + + return [i[0] for i in d] + + +class Snake: + def __init__(self, player_id, board_width, board_height, init_len): + self.actions = [-2, 2, -1, 1] + self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"} + self.direction = random.choice(self.actions) # 方向[-2,2,-1,1]分别表示[上,下,左,右] + self.board_width = board_width + self.board_height = board_height + x = random.randrange(0, board_height) + y = random.randrange(0, board_width) + self.segments = [[x, y]] + self.headPos = self.segments[0] + self.player_id = player_id + 
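+        # score is the net length gained (len(segments) - init_len), while
+        # snake_reward holds the shaped per-step reward assigned in
+        # get_next_state; the two are tracked separately.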
self.score = 0 + self.snake_reward = 0 + self.init_len = init_len + + def get_score(self): + return len(self.segments) - self.init_len + + def change_direction(self, act): + if act + self.direction != 0: + self.direction = act + else: + n_direct = random.choice(self.actions) + while n_direct + self.direction == 0: + n_direct = random.choice(self.actions) + self.direction = n_direct + # print("方向不合法,重新生成") + # print("direction", self.actions_name[self.direction]) + + # 超过边界,可以穿越 + def update_position(self, position): + position[0] %= self.board_height + position[1] %= self.board_width + return position + + def move_and_add(self, snakes_position): + cur_head = list(self.headPos) + # 根据方向移动蛇头的坐标 + # 右 + if self.direction == 1: + cur_head[1] += 1 + # 左 + if self.direction == -1: + cur_head[1] -= 1 + # 上 + if self.direction == -2: + cur_head[0] -= 1 + # 下 + if self.direction == 2: + cur_head[0] += 1 + + cur_head = self.update_position(cur_head) + # print("cur head", cur_head) + # print("cur snakes positions", snakes_position) + + self.segments.insert(0, cur_head) + self.headPos = self.segments[0] + return cur_head + + def pop(self): + self.segments.pop() # 在蛇尾减去一格 diff --git a/openrl/envs/snake/snake_pettingzoo.py b/openrl/envs/snake/snake_pettingzoo.py new file mode 100644 index 00000000..a9c18c76 --- /dev/null +++ b/openrl/envs/snake/snake_pettingzoo.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
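+
+# Usage sketch (illustrative, based on the standard PettingZoo AEC loop; the
+# agent_iter/last helpers come from the AECEnv base class, and a real loop
+# would stop stepping an agent once it terminates):
+#
+#     env = SnakeEatBeansAECEnv(render_mode="human")
+#     env.reset(seed=0)
+#     for agent in env.agent_iter():
+#         obs, reward, termination, truncation, info = env.last()
+#         action = env.action_space(agent).sample()
+#         env.step(action)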
+ +"""""" +import functools +from copy import deepcopy +from typing import Optional + +import numpy as np +from gymnasium import spaces +from pettingzoo import AECEnv +from pettingzoo.utils import agent_selector + +from openrl.envs.snake.snake import SnakeEatBeans + +NONE = 4 + + +class SnakeEatBeansAECEnv(AECEnv): + metadata = {"render.modes": ["human"], "name": "SnakeEatBeans"} + + def __init__(self, render_mode: Optional[str] = None): + self.env = SnakeEatBeans(render_mode) + + self.agent_name_mapping = dict( + zip(self.possible_agents, list(range(len(self.possible_agents)))) + ) + self._action_spaces = { + agent: spaces.Discrete(4) for agent in self.possible_agents + } + self._observation_spaces = { + agent: spaces.Box(low=-np.inf, high=np.inf, shape=(288,), dtype=np.float32) + for agent in self.possible_agents + } + + self.agents = self.possible_agents[:] + + self.observations = {agent: NONE for agent in self.agents} + self.raw_obs, self.raw_reward, self.raw_done, self.raw_info = ( + None, + None, + None, + None, + ) + + @functools.lru_cache(maxsize=None) + def observation_space(self, agent): + return deepcopy(self._observation_spaces[agent]) + + @functools.lru_cache(maxsize=None) + def action_space(self, agent): + return deepcopy(self._action_spaces[agent]) + + def observe(self, agent): + return self.raw_obs[self.agent_name_mapping[agent]] + + def reset( + self, + seed: Optional[int] = None, + options: Optional[dict] = None, + ): + if seed is not None: + self.env.seed(seed) + self.agents = self.possible_agents[:] + self.rewards = {agent: 0 for agent in self.agents} + self._cumulative_rewards = {agent: 0 for agent in self.agents} + self.terminations = {agent: False for agent in self.agents} + self.truncations = {agent: False for agent in self.agents} + self.infos = {agent: {} for agent in self.agents} + self.state = {agent: NONE for agent in self.agents} + self.observations = {agent: NONE for agent in self.agents} + + self.raw_obs, self.raw_info = self.env.reset() + + self._agent_selector = agent_selector(self.agents) + self.agent_selection = self._agent_selector.next() + + def step(self, action): + agent = self.agent_selection + self._cumulative_rewards[agent] = 0 + self.state[self.agent_selection] = action + if self._agent_selector.is_last(): + joint_action = [self.state[agent] for agent in self.agents] + self.raw_obs, self.raw_reward, self.raw_done, self.raw_info = self.env.step( + joint_action + ) + + self.rewards = { + agent: self.raw_reward[i] for i, agent in enumerate(self.agents) + } + + if np.any(self.raw_done): + for key in self.terminations: + self.terminations[key] = True + else: + self.state[self.agents[1 - self.agent_name_mapping[agent]]] = NONE + self._clear_rewards() + + # selects the next agent. + self.agent_selection = self._agent_selector.next() + self._accumulate_rewards() + + def render(self): + img = self.env.render() + return img + + def close(self): + self.env.close() + + @property + def possible_agents(self): + return ["player_" + str(i) for i in range(self.env.n_player)] + + @property + def num_agents(self): + return len(self.possible_agents) diff --git a/openrl/envs/snake/space.py b/openrl/envs/snake/space.py new file mode 100644 index 00000000..672e2367 --- /dev/null +++ b/openrl/envs/snake/space.py @@ -0,0 +1,63 @@ +from gym.utils import seeding + + +class Space(object): + """Defines the observation and action spaces, so you can write generic + code that applies to any Env. For example, you can choose a random + action. 
+ WARNING - Custom observation & action spaces can inherit from the `Space` + class. However, most use-cases should be covered by the existing space + classes (e.g. `Box`, `Discrete`, etc...), and container classes (`Tuple` & + `Dict`). Note that parametrized probability distributions (through the + `sample()` method), and batching functions (in `gym.vector.VectorEnv`), are + only well-defined for instances of spaces provided in gym by default. + Moreover, some implementations of Reinforcement Learning algorithms might + not handle custom spaces properly. Use custom spaces with care. + """ + + def __init__(self, shape=None, dtype=None): + import numpy as np # takes about 300-400ms to import, so we load lazily + + self.shape = None if shape is None else tuple(shape) + self.dtype = None if dtype is None else np.dtype(dtype) + self._np_random = None + + @property + def np_random(self): + """Lazily seed the rng since this is expensive and only needed if + sampling from this space. + """ + if self._np_random is None: + self.seed() + + return self._np_random + + def sample(self): + """Randomly sample an element of this space. Can be + uniform or non-uniform sampling based on boundedness of space.""" + raise NotImplementedError + + def seed(self, seed=None): + """Seed the PRNG of this space.""" + self._np_random, seed = seeding.np_random(seed) + return [seed] + + def contains(self, x): + """ + Return boolean specifying if x is a valid + member of this space + """ + raise NotImplementedError + + def __contains__(self, x): + return self.contains(x) + + def to_jsonable(self, sample_n): + """Convert a batch of samples from this space to a JSONable data type.""" + # By default, assume identity is JSONable + return sample_n + + def from_jsonable(self, sample_n): + """Convert a JSONable data type to a batch of samples from this space.""" + # By default, assume identity is JSONable + return diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py index 1ca95674..7c620aee 100644 --- a/openrl/envs/vec_env/async_venv.py +++ b/openrl/envs/vec_env/async_venv.py @@ -233,8 +233,10 @@ def reset_send( if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `reset_send` while waiting for a pending call to" - f" `{self._state.value}` to complete", + ( + "Calling `reset_send` while waiting for a pending call to" + f" `{self._state.value}` to complete" + ), self._state.value, ) @@ -326,8 +328,10 @@ def step_send(self, actions: np.ndarray): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `step_send` while waiting for a pending call to" - f" `{self._state.value}` to complete.", + ( + "Calling `step_send` while waiting for a pending call to" + f" `{self._state.value}` to complete." + ), self._state.value, ) @@ -571,8 +575,10 @@ def call_send(self, name: str, *args, **kwargs): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `call_send` while waiting " - f"for a pending call to `{self._state.value}` to complete.", + ( + "Calling `call_send` while waiting " + f"for a pending call to `{self._state.value}` to complete." 
+ ), str(self._state.value), ) @@ -629,8 +635,10 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `exec_func_send` while waiting " - f"for a pending call to `{self._state.value}` to complete.", + ( + "Calling `exec_func_send` while waiting " + f"for a pending call to `{self._state.value}` to complete." + ), str(self._state.value), ) @@ -707,8 +715,10 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]): if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `set_attr` while waiting " - f"for a pending call to `{self._state.value}` to complete.", + ( + "Calling `set_attr` while waiting " + f"for a pending call to `{self._state.value}` to complete." + ), str(self._state.value), ) diff --git a/openrl/envs/wrappers/pettingzoo_wrappers.py b/openrl/envs/wrappers/pettingzoo_wrappers.py index 687384a4..647c13be 100644 --- a/openrl/envs/wrappers/pettingzoo_wrappers.py +++ b/openrl/envs/wrappers/pettingzoo_wrappers.py @@ -24,11 +24,12 @@ class SeedEnv(BaseWrapper): def reset(self, seed: Optional[int] = None, options: Optional[dict] = None): super().reset(seed=seed, options=options) - - for i, space in enumerate( - list(self.action_spaces.values()) + list(self.observation_spaces.values()) - ): - space.seed(seed + i * 7891) + if seed is not None: + for i, space in enumerate( + list(self.action_spaces.values()) + + list(self.observation_spaces.values()) + ): + space.seed(seed + i * 7891) class RecordWinner(BaseWrapper): diff --git a/openrl/envs/wrappers/util.py b/openrl/envs/wrappers/util.py index a7bf6379..a0a97576 100644 --- a/openrl/envs/wrappers/util.py +++ b/openrl/envs/wrappers/util.py @@ -38,6 +38,8 @@ def nest_expand_dim(input: Any) -> Any: return [input] elif isinstance(input, np.int64): return [input] + elif input is None: + return [input] else: raise NotImplementedError("Not support type: {}".format(type(input))) diff --git a/openrl/modules/networks/utils/nlp/hf_generation_utils.py b/openrl/modules/networks/utils/nlp/hf_generation_utils.py index 8a44d8c7..37d80875 100644 --- a/openrl/modules/networks/utils/nlp/hf_generation_utils.py +++ b/openrl/modules/networks/utils/nlp/hf_generation_utils.py @@ -1359,9 +1359,11 @@ def generate( elif max_length is not None and max_new_tokens is not None: # Both are set, this is odd, raise a warning warnings.warn( - "Both `max_length` and `max_new_tokens` have been set " - f"but they serve the same purpose. `max_length` {max_length} " - f"will take priority over `max_new_tokens` {max_new_tokens}.", + ( + "Both `max_length` and `max_new_tokens` have been set " + f"but they serve the same purpose. `max_length` {max_length} " + f"will take priority over `max_new_tokens` {max_new_tokens}." + ), UserWarning, ) # default to config if still None @@ -1847,9 +1849,11 @@ def greedy_search( ) if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`" - " instead.", + ( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`" + " instead." 
+ ), UserWarning, ) stopping_criteria = validate_stopping_criteria( @@ -2147,9 +2151,11 @@ def sample( ) if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead.", + ( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" + " instead." + ), UserWarning, ) stopping_criteria = validate_stopping_criteria( @@ -2453,9 +2459,11 @@ def beam_search( ) if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead.", + ( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" + " instead." + ), UserWarning, ) stopping_criteria = validate_stopping_criteria( @@ -2463,8 +2471,10 @@ def beam_search( ) if len(stopping_criteria) == 0: warnings.warn( - "You don't have defined any stopping_criteria, this will likely" - " loop forever", + ( + "You don't have defined any stopping_criteria, this will likely" + " loop forever" + ), UserWarning, ) pad_token_id = ( @@ -2857,9 +2867,11 @@ def beam_sample( ) if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead.", + ( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" + " instead." + ), UserWarning, ) stopping_criteria = validate_stopping_criteria( @@ -3240,9 +3252,11 @@ def group_beam_search( ) if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead.", + ( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" + " instead." + ), UserWarning, ) stopping_criteria = validate_stopping_criteria( @@ -3686,9 +3700,11 @@ def constrained_beam_search( ) if max_length is not None: warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead.", + ( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" + " instead." 
+ ), UserWarning, ) stopping_criteria = validate_stopping_criteria( @@ -3696,8 +3712,10 @@ def constrained_beam_search( ) if len(stopping_criteria) == 0: warnings.warn( - "You don't have defined any stopping_criteria, this will likely" - " loop forever", + ( + "You don't have defined any stopping_criteria, this will likely" + " loop forever" + ), UserWarning, ) pad_token_id = ( diff --git a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py index cc23c1ff..38050cc7 100644 --- a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py +++ b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py @@ -58,7 +58,7 @@ def action_space( if self.self_player is None: self.env.reset() self.self_player = self.np_random.choice(self.env.agents) - return self.env.action_spaces[self.self_player] + return self.env.action_space(self.self_player) return self._action_space @property @@ -70,7 +70,9 @@ def observation_space( if self.self_player is None: self.env.reset() self.self_player = self.np_random.choice(self.env.agents) + return self.env.observation_spaces[self.self_player] + return self._observation_space @abstractmethod diff --git a/openrl/selfplay/wrappers/random_opponent_wrapper.py b/openrl/selfplay/wrappers/random_opponent_wrapper.py index 96d562e0..e429b605 100644 --- a/openrl/selfplay/wrappers/random_opponent_wrapper.py +++ b/openrl/selfplay/wrappers/random_opponent_wrapper.py @@ -29,6 +29,8 @@ class RandomOpponentWrapper(BaseMultiPlayerWrapper): def get_opponent_action( self, player_name, observation, reward, termination, truncation, info ): - mask = observation["action_mask"] + mask = None + if "action_mask" in observation: + mask = observation["action_mask"] action = self.env.action_space(player_name).sample(mask) return action diff --git a/openrl/utils/callbacks/checkpoint_callback.py b/openrl/utils/callbacks/checkpoint_callback.py index 56bf31b8..a4b3f5b6 100644 --- a/openrl/utils/callbacks/checkpoint_callback.py +++ b/openrl/utils/callbacks/checkpoint_callback.py @@ -72,7 +72,9 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st """ return os.path.join( self.save_path, - f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}", + ( + f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}" + ), ) def _on_step(self) -> bool: diff --git a/openrl/utils/evaluation.py b/openrl/utils/evaluation.py index 391ba10f..d603daa5 100644 --- a/openrl/utils/evaluation.py +++ b/openrl/utils/evaluation.py @@ -68,10 +68,12 @@ def evaluate_policy( if not is_monitor_wrapped and warn: warnings.warn( - "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This" - " may result in reporting modified episode lengths and rewards, if" - " other wrappers happen to modify these. Consider wrapping environment" - " first with ``Monitor`` wrapper.", + ( + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This" + " may result in reporting modified episode lengths and rewards, if" + " other wrappers happen to modify these. Consider wrapping environment" + " first with ``Monitor`` wrapper." + ), UserWarning, )