diff --git a/Gallery.md b/Gallery.md
index 8d18008d..8009d459 100644
--- a/Gallery.md
+++ b/Gallery.md
@@ -41,6 +41,7 @@ Users are also welcome to contribute their own training examples and demos to th
 | [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
 | [GAIL](https://arxiv.org/abs/1606.03476) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/gail/) |
 | [Behavior Cloning](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/behavior_cloning/) |
+| [A2C](https://arxiv.org/abs/1602.01783) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
 | Self-Play | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
 | [DQN](https://arxiv.org/abs/1312.5602) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![value](https://img.shields.io/badge/-value-orange) ![offpolicy](https://img.shields.io/badge/-offpolicy-blue) | [code](./examples/toy_env) [code](./examples/gridworld/) |
 | [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/mpe/) |
diff --git a/README.md b/README.md
index 4ab403eb..3cb780b4 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,7 @@ Algorithms currently supported by OpenRL (for more details, please refer to [Gal
 - [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
 - [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
 - [Behavior Cloning (BC)](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)
+- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
 - Self-Play
 - [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
diff --git a/README_zh.md b/README_zh.md
index 8b381027..4af4d091 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -74,6 +74,7 @@ OpenRL目前支持的算法(更多详情请参考 [Gallery](Gallery.md)):
 - [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
 - [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
 - [Behavior Cloning (BC)](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)
+- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
 - Self-Play
 - [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
diff --git a/examples/cartpole/README.md b/examples/cartpole/README.md
index 5d939de0..ffc2b5d9 100644
--- a/examples/cartpole/README.md
+++ b/examples/cartpole/README.md
@@ -13,6 +13,12 @@ To train with [Dual-clip PPO](https://arxiv.org/abs/1912.09729):
 python train_ppo.py --config dual_clip_ppo.yaml
 ```
 
+To train with the [A2C](https://arxiv.org/abs/1602.01783) algorithm:
+
+```shell
+python train_a2c.py
+```
+
 If you want to evaluate the agent during training and save the best model and save checkpoints, try to train with callbacks:
 
 ```shell
diff --git a/examples/cartpole/a2c.yaml b/examples/cartpole/a2c.yaml
new file mode 100644
index 00000000..3471be76
--- /dev/null
+++ b/examples/cartpole/a2c.yaml
@@ -0,0 +1,3 @@
+seed: 0
+run_dir: ./run_results/
+wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/cartpole/train_a2c.py b/examples/cartpole/train_a2c.py
new file mode 100644
index 00000000..3d200ec6
--- /dev/null
+++ b/examples/cartpole/train_a2c.py
@@ -0,0 +1,71 @@
+""""""
+import numpy as np
+import torch
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.modules.common import A2CNet as Net
+from openrl.runners.common import A2CAgent as Agent
+
+
+def train():
+    # load the training configuration
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])
+
+    # create environment, set environment parallelism to 9
+    env = make("CartPole-v1", env_num=9)
+
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
+    # initialize the agent
+    agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
+    # start training, set the total number of training steps to 30000
+    agent.train(total_time_steps=30000)
+
+    env.close()
+
+    agent.save("./a2c_agent")
+    return agent
+
+
+def evaluation():
+    # begin to test
+
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])
+
+    # Create an evaluation environment with 9 parallel environments.
+    # Set render_mode to "group_human" to watch the agent; None disables rendering.
+    render_mode = None
+    env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)
+
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
+    # initialize the agent
+    agent = Agent(
+        net,
+    )
+    agent.load("./a2c_agent")
+    # The trained agent sets up the interactive environment it needs.
+    agent.set_env(env)
+    # Initialize the environment and get initial observations and environmental information.
+    obs, info = env.reset()
+    done = False
+
+    total_step = 0
+    total_reward = 0.0
+    while not np.any(done):
+        # Based on the environment observation, predict the next action.
+        action, _ = agent.act(obs, deterministic=True)
+        obs, r, done, info = env.step(action)
+        total_step += 1
+        total_reward += np.mean(r)
+        if total_step % 50 == 0:
+            print(f"{total_step}: reward:{np.mean(r)}")
+    env.close()
+    print("total step:", total_step)
+    print("total reward:", total_reward)
+
+
+if __name__ == "__main__":
+    train()
+    evaluation()
diff --git a/openrl/algorithms/a2c.py b/openrl/algorithms/a2c.py
new file mode 100644
index 00000000..478675a0
--- /dev/null
+++ b/openrl/algorithms/a2c.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from typing import Union
+
+import numpy as np
+import torch
+from torch.nn.parallel import DistributedDataParallel
+
+from openrl.algorithms.ppo import PPOAlgorithm
+
+
+class A2CAlgorithm(PPOAlgorithm):
+    def __init__(
+        self,
+        cfg,
+        init_module,
+        agent_num: int = 1,
+        device: Union[str, torch.device] = "cpu",
+    ) -> None:
+        super(A2CAlgorithm, self).__init__(cfg, init_module, agent_num, device)
+
+        self.num_mini_batch = 1  # A2C uses the full rollout in a single update
+
+    def prepare_loss(
+        self,
+        critic_obs_batch,
+        obs_batch,
+        rnn_states_batch,
+        rnn_states_critic_batch,
+        actions_batch,
+        masks_batch,
+        action_masks_batch,
+        old_action_log_probs_batch,
+        adv_targ,
+        value_preds_batch,
+        return_batch,
+        active_masks_batch,
+        turn_on,
+    ):
+        if self.use_joint_action_loss:
+            critic_obs_batch = self.to_single_np(critic_obs_batch)
+            rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch)
+            critic_masks_batch = self.to_single_np(masks_batch)
+            value_preds_batch = self.to_single_np(value_preds_batch)
+            return_batch = self.to_single_np(return_batch)
+            adv_targ = adv_targ.reshape(-1, self.agent_num, 1)
+            adv_targ = adv_targ[:, 0, :]
+
+        else:
+            critic_masks_batch = masks_batch
+
+        (
+            values,
+            action_log_probs,
+            dist_entropy,
+            policy_values,
+        ) = self.algo_module.evaluate_actions(
+            critic_obs_batch,
+            obs_batch,
+            rnn_states_batch,
+            rnn_states_critic_batch,
+            actions_batch,
+            masks_batch,
+            action_masks_batch,
+            active_masks_batch,
+            critic_masks_batch=critic_masks_batch,
+        )
+
+        if self.use_joint_action_loss:
+            active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
+            active_masks_batch = active_masks_batch[:, 0, :]
+
+        policy_gradient_loss = -adv_targ.detach() * action_log_probs  # no PPO ratio clipping
+        if self._use_policy_active_masks:
+            policy_action_loss = (
+                torch.sum(policy_gradient_loss, dim=-1, keepdim=True)
+                * active_masks_batch
+            ).sum() / active_masks_batch.sum()
+        else:
+            policy_action_loss = torch.sum(
+                policy_gradient_loss, dim=-1, keepdim=True
+            ).mean()
+
+        if self._use_policy_vhead:
+            if isinstance(self.algo_module.models["actor"], DistributedDataParallel):
+                policy_value_normalizer = self.algo_module.models[
+                    "actor"
+                ].module.value_normalizer
+            else:
+                policy_value_normalizer = self.algo_module.models[
+                    "actor"
+                ].value_normalizer
+            policy_value_loss = self.cal_value_loss(
+                policy_value_normalizer,
+                policy_values,
+                value_preds_batch,
+                return_batch,
+                active_masks_batch,
+            )
+            policy_loss = (
+                policy_action_loss + policy_value_loss * self.policy_value_loss_coef
+            )
+        else:
+            policy_loss = policy_action_loss
+
+        # critic update
+        if self._use_share_model:
+            value_normalizer = self.algo_module.models["model"].value_normalizer
+        elif isinstance(self.algo_module.models["critic"], DistributedDataParallel):
+            value_normalizer = self.algo_module.models["critic"].module.value_normalizer
+        else:
+            value_normalizer = self.algo_module.get_critic_value_normalizer()
+        value_loss = self.cal_value_loss(
+            value_normalizer,
+            values,
+            value_preds_batch,
+            return_batch,
+            active_masks_batch,
+        )
+
+        loss_list = self.construct_loss_list(
+            policy_loss, dist_entropy, value_loss, turn_on
+        )
+        ratio = np.zeros(1)  # placeholder; A2C has no importance ratio
+        return loss_list, value_loss, policy_loss, dist_entropy, ratio
+
+    def train(self, buffer, turn_on: bool = True):
+        train_info = super(A2CAlgorithm, self).train(buffer, turn_on)
+        train_info.pop("ratio", None)
+        return train_info
diff --git a/openrl/modules/common/__init__.py b/openrl/modules/common/__init__.py
index 3a41122f..568a4050 100644
--- a/openrl/modules/common/__init__.py
+++ b/openrl/modules/common/__init__.py
@@ -1,3 +1,4 @@
+from .a2c_net import A2CNet
 from .base_net import BaseNet
 from .bc_net import BCNet
 from .ddpg_net import DDPGNet
@@ -18,4 +19,5 @@
     "GAILNet",
     "BCNet",
     "SACNet",
+    "A2CNet",
 ]
diff --git a/openrl/modules/common/a2c_net.py b/openrl/modules/common/a2c_net.py
new file mode 100644
index 00000000..e7dbacc5
--- /dev/null
+++ b/openrl/modules/common/a2c_net.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from openrl.modules.common.ppo_net import PPONet
+
+
+class A2CNet(PPONet):
+    pass
diff --git a/openrl/runners/common/__init__.py b/openrl/runners/common/__init__.py
index d315ecf3..523e4849 100644
--- a/openrl/runners/common/__init__.py
+++ b/openrl/runners/common/__init__.py
@@ -1,3 +1,4 @@
+from openrl.runners.common.a2c_agent import A2CAgent
 from openrl.runners.common.bc_agent import BCAgent
 from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent
 from openrl.runners.common.ddpg_agent import DDPGAgent
@@ -19,4 +20,5 @@
     "GAILAgent",
     "BCAgent",
     "SACAgent",
+    "A2CAgent",
 ]
diff --git a/openrl/runners/common/a2c_agent.py b/openrl/runners/common/a2c_agent.py
new file mode 100644
index 00000000..5b7e016d
--- /dev/null
+++ b/openrl/runners/common/a2c_agent.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from typing import Optional, Type, Union
+
+import gym
+import torch
+
+from openrl.algorithms.a2c import A2CAlgorithm
+from openrl.algorithms.base_algorithm import BaseAlgorithm
+from openrl.drivers.base_driver import BaseDriver
+from openrl.drivers.onpolicy_driver import OnPolicyDriver as Driver
+from openrl.modules.common import BaseNet
+from openrl.runners.common.base_agent import SelfAgent
+from openrl.runners.common.ppo_agent import PPOAgent
+from openrl.utils.logger import Logger
+from openrl.utils.type_aliases import MaybeCallback
+
+
+class A2CAgent(PPOAgent):
+    def __init__(
+        self,
+        net: Optional[Union[torch.nn.Module, BaseNet]] = None,
+        env: Union[gym.Env, str] = None,
+        run_dir: Optional[str] = None,
+        env_num: Optional[int] = None,
+        rank: int = 0,
+        world_size: int = 1,
+        use_wandb: bool = False,
+        use_tensorboard: bool = False,
+        project_name: str = "A2CAgent",
+    ) -> None:
+        super(A2CAgent, self).__init__(
+            net,
+            env,
+            run_dir,
+            env_num,
+            rank,
+            world_size,
+            use_wandb,
+            use_tensorboard,
+            project_name=project_name,
+        )
+
+    def train(
+        self: SelfAgent,
+        total_time_steps: int,
+        callback: MaybeCallback = None,
+        train_algo_class: Type[BaseAlgorithm] = A2CAlgorithm,
+        logger: Optional[Logger] = None,
+        driver_class: Type[BaseDriver] = Driver,
+    ) -> None:
+        super().train(
+            total_time_steps, callback, train_algo_class, logger, driver_class
+        )
diff --git a/openrl/runners/common/bc_agent.py b/openrl/runners/common/bc_agent.py
index dbdc1380..14877043 100644
--- a/openrl/runners/common/bc_agent.py
+++ b/openrl/runners/common/bc_agent.py
@@ -62,12 +62,12 @@ def train(
         callback: MaybeCallback = None,
         train_algo_class: Type[BaseAlgorithm] = BCAlgorithm,
         logger: Optional[Logger] = None,
-        DriverClass: Type[BaseDriver] = Driver,
+        driver_class: Type[BaseDriver] = Driver,
     ) -> None:
         super().train(
             total_time_steps,
             callback,
             train_algo_class,
             logger,
-            DriverClass=DriverClass,
+            driver_class=driver_class,
         )
diff --git a/openrl/runners/common/ppo_agent.py b/openrl/runners/common/ppo_agent.py
index b60c2118..ad7d0a84 100644
--- a/openrl/runners/common/ppo_agent.py
+++ b/openrl/runners/common/ppo_agent.py
@@ -67,7 +67,7 @@ def train(
         callback: MaybeCallback = None,
         train_algo_class: Type[BaseAlgorithm] = PPOAlgorithm,
         logger: Optional[Logger] = None,
-        DriverClass: Type[BaseDriver] = Driver,
+        driver_class: Type[BaseDriver] = Driver,
     ) -> None:
         self._cfg.num_env_steps = total_time_steps
@@ -113,7 +113,7 @@ def train(
             progress_bar=False,
         )
 
-        driver = DriverClass(
+        driver = driver_class(
            config=self.config,
            trainer=trainer,
            buffer=buffer,
diff --git a/openrl/runners/common/sac_agent.py b/openrl/runners/common/sac_agent.py
index d0fc3790..61903c7b 100644
--- a/openrl/runners/common/sac_agent.py
+++ b/openrl/runners/common/sac_agent.py
@@ -64,7 +64,7 @@ def train(
         callback: MaybeCallback = None,
         train_algo_class: Type[BaseAlgorithm] = TrainAlgo,
         logger: Optional[Logger] = None,
-        DriverClass: Type[BaseDriver] = Driver,
+        driver_class: Type[BaseDriver] = Driver,
     ) -> None:
         self._cfg.num_env_steps = total_time_steps
@@ -112,7 +112,7 @@ def train(
             progress_bar=False,
         )
 
-        driver = DriverClass(
+        driver = driver_class(
            config=self.config,
            trainer=trainer,
            buffer=buffer,