add A2C algorithm #226

Merged · 1 commit · Sep 7, 2023
1 change: 1 addition & 0 deletions Gallery.md
@@ -41,6 +41,7 @@ Users are also welcome to contribute their own training examples and demos to th
| [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [GAIL](https://arxiv.org/abs/1606.03476) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/gail/) |
| [Behavior Cloning](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/behavior_cloning/) |
| [A2C](https://arxiv.org/abs/1602.01783) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| Self-Play | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
| [DQN](https://arxiv.org/abs/1312.5602) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![value](https://img.shields.io/badge/-value-orange) ![offpolicy](https://img.shields.io/badge/-offpolicy-blue) | [code](./examples/toy_env) [code](./examples/gridworld/) |
| [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/mpe/) |
1 change: 1 addition & 0 deletions README.md
@@ -98,6 +98,7 @@ Algorithms currently supported by OpenRL (for more details, please refer to [Gal
- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
- [Behavior Cloning (BC)](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)
- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
- Self-Play
- [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
1 change: 1 addition & 0 deletions README_zh.md
@@ -74,6 +74,7 @@ Algorithms currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
- [Behavior Cloning (BC)](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)
- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
- Self-Play
- [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
6 changes: 6 additions & 0 deletions examples/cartpole/README.md
@@ -13,6 +13,12 @@ To train with [Dual-clip PPO](https://arxiv.org/abs/1912.09729):
python train_ppo.py --config dual_clip_ppo.yaml
```

To train with the [A2C](https://arxiv.org/abs/1602.01783) algorithm:

```shell
python train_a2c.py
```
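
Under the hood, `train_a2c.py` (added by this PR; full listing further down in the diff) boils down to a handful of calls against OpenRL's PPO-style API. A condensed sketch, taken from that script:

```python
# Condensed from examples/cartpole/train_a2c.py in this PR.
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import A2CNet as Net
from openrl.runners.common import A2CAgent as Agent

cfg = create_config_parser().parse_args(["--config", "a2c.yaml"])
env = make("CartPole-v1", env_num=9)  # 9 parallel environments
net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
agent.train(total_time_steps=30000)
agent.save("./a2c_agent")
env.close()
```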

If you want to evaluate the agent during training, save the best model, and save checkpoints, train with callbacks:

```shell
3 changes: 3 additions & 0 deletions examples/cartpole/a2c.yaml
@@ -0,0 +1,3 @@
seed: 0
run_dir: ./run_results/
wandb_entity: openrl-lab
70 changes: 70 additions & 0 deletions examples/cartpole/train_a2c.py
@@ -0,0 +1,70 @@
""""""
import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import A2CNet as Net
from openrl.runners.common import A2CAgent as Agent


def train():
    # create the neural network
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])

    # create environment, set environment parallelism to 9
    env = make("CartPole-v1", env_num=9)

    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the trainer
    agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
    # start training, set total number of training steps to 30000
    agent.train(total_time_steps=30000)

    env.close()

    agent.save("./a2c_agent")
    return agent


def evaluation():
    # begin to test
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])

    # Create an environment for testing and set the number of environments to interact with to 9.
    # Rendering is disabled here; set render_mode = "group_human" to visualize the rollout.
    render_mode = "group_human"
    render_mode = None
    env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)

    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the agent
    agent = Agent(
        net,
    )
    agent.load("./a2c_agent")
    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and environmental information.
    obs, info = env.reset()
    done = False

    total_step = 0
    total_reward = 0.0
    while not np.any(done):
        # Based on environmental observation input, predict next action.
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        total_step += 1
        total_reward += np.mean(r)
        if total_step % 50 == 0:
            print(f"{total_step}: reward:{np.mean(r)}")
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)


if __name__ == "__main__":
    train()
    evaluation()
145 changes: 145 additions & 0 deletions openrl/algorithms/a2c.py
@@ -0,0 +1,145 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import Union

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel

from openrl.algorithms.ppo import PPOAlgorithm


class A2CAlgorithm(PPOAlgorithm):
    def __init__(
        self,
        cfg,
        init_module,
        agent_num: int = 1,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        super(A2CAlgorithm, self).__init__(cfg, init_module, agent_num, device)

        # A2C performs a single update over the whole rollout (no PPO-style mini-batching).
        self.num_mini_batch = 1

    def prepare_loss(
        self,
        critic_obs_batch,
        obs_batch,
        rnn_states_batch,
        rnn_states_critic_batch,
        actions_batch,
        masks_batch,
        action_masks_batch,
        old_action_log_probs_batch,
        adv_targ,
        value_preds_batch,
        return_batch,
        active_masks_batch,
        turn_on,
    ):
        if self.use_joint_action_loss:
            critic_obs_batch = self.to_single_np(critic_obs_batch)
            rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch)
            critic_masks_batch = self.to_single_np(masks_batch)
            value_preds_batch = self.to_single_np(value_preds_batch)
            return_batch = self.to_single_np(return_batch)
            adv_targ = adv_targ.reshape(-1, self.agent_num, 1)
            adv_targ = adv_targ[:, 0, :]
        else:
            critic_masks_batch = masks_batch

        (
            values,
            action_log_probs,
            dist_entropy,
            policy_values,
        ) = self.algo_module.evaluate_actions(
            critic_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            action_masks_batch,
            active_masks_batch,
            critic_masks_batch=critic_masks_batch,
        )

        if self.use_joint_action_loss:
            active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
            active_masks_batch = active_masks_batch[:, 0, :]

        # A2C policy loss: plain policy gradient weighted by the detached advantage,
        # without PPO's clipped probability ratio.
        policy_gradient_loss = -adv_targ.detach() * action_log_probs
        if self._use_policy_active_masks:
            policy_action_loss = (
                torch.sum(policy_gradient_loss, dim=-1, keepdim=True)
                * active_masks_batch
            ).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = torch.sum(
                policy_gradient_loss, dim=-1, keepdim=True
            ).mean()

        if self._use_policy_vhead:
            if isinstance(self.algo_module.models["actor"], DistributedDataParallel):
                policy_value_normalizer = self.algo_module.models[
                    "actor"
                ].module.value_normalizer
            else:
                policy_value_normalizer = self.algo_module.models[
                    "actor"
                ].value_normalizer
            policy_value_loss = self.cal_value_loss(
                policy_value_normalizer,
                policy_values,
                value_preds_batch,
                return_batch,
                active_masks_batch,
            )
            policy_loss = (
                policy_action_loss + policy_value_loss * self.policy_value_loss_coef
            )
        else:
            policy_loss = policy_action_loss

        # critic update
        if self._use_share_model:
            value_normalizer = self.algo_module.models["model"].value_normalizer
        elif isinstance(self.algo_module.models["critic"], DistributedDataParallel):
            value_normalizer = self.algo_module.models["critic"].module.value_normalizer
        else:
            value_normalizer = self.algo_module.get_critic_value_normalizer()
        value_loss = self.cal_value_loss(
            value_normalizer,
            values,
            value_preds_batch,
            return_batch,
            active_masks_batch,
        )

        loss_list = self.construct_loss_list(
            policy_loss, dist_entropy, value_loss, turn_on
        )
        # A2C has no importance ratio; return zeros to keep the PPO interface.
        ratio = np.zeros(1)
        return loss_list, value_loss, policy_loss, dist_entropy, ratio

    def train(self, buffer, turn_on: bool = True):
        train_info = super(A2CAlgorithm, self).train(buffer, turn_on)
        # The ratio statistic inherited from PPO is meaningless for A2C, so drop it from the logs.
        train_info.pop("ratio", None)
        return train_info
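
For orientation, the loss assembled in `prepare_loss` above is the standard advantage actor-critic objective: a plain policy-gradient term weighted by the detached advantage, a value-regression term, and an entropy bonus, with no PPO ratio clipping. A minimal self-contained sketch of that objective (illustrative only; the tensor names and coefficients here are hypothetical, not OpenRL's internal API):

```python
import torch


def a2c_loss(
    action_log_probs: torch.Tensor,  # log pi(a_t | s_t) for the taken actions
    advantages: torch.Tensor,        # A_t, treated as constants for the policy term
    values: torch.Tensor,            # V(s_t) predicted by the critic
    returns: torch.Tensor,           # empirical returns R_t
    entropy: torch.Tensor,           # per-step policy entropy
    value_coef: float = 0.5,
    entropy_coef: float = 0.01,
) -> torch.Tensor:
    # Policy term: plain policy gradient, no clipped probability ratio.
    policy_loss = -(advantages.detach() * action_log_probs).mean()
    # Critic term: regress V(s_t) toward the empirical returns.
    value_loss = 0.5 * (returns - values).pow(2).mean()
    # Entropy bonus encourages exploration.
    return policy_loss + value_coef * value_loss - entropy_coef * entropy.mean()
```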
2 changes: 2 additions & 0 deletions openrl/modules/common/__init__.py
@@ -1,3 +1,4 @@
from .a2c_net import A2CNet
from .base_net import BaseNet
from .bc_net import BCNet
from .ddpg_net import DDPGNet
@@ -18,4 +19,5 @@
    "GAILNet",
    "BCNet",
    "SACNet",
    "A2CNet",
]
22 changes: 22 additions & 0 deletions openrl/modules/common/a2c_net.py
@@ -0,0 +1,22 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from openrl.modules.common.ppo_net import PPONet


class A2CNet(PPONet):
    # A2C reuses PPO's actor-critic network unchanged; the algorithmic
    # differences live in A2CAlgorithm (see openrl/algorithms/a2c.py).
    pass
2 changes: 2 additions & 0 deletions openrl/runners/common/__init__.py
@@ -1,3 +1,4 @@
from openrl.runners.common.a2c_agent import A2CAgent
from openrl.runners.common.bc_agent import BCAgent
from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent
from openrl.runners.common.ddpg_agent import DDPGAgent
@@ -19,4 +20,5 @@
    "GAILAgent",
    "BCAgent",
    "SACAgent",
    "A2CAgent",
]
69 changes: 69 additions & 0 deletions openrl/runners/common/a2c_agent.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import Optional, Type, Union

import gym
import torch

from openrl.algorithms.a2c import A2CAlgorithm
from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.drivers.base_driver import BaseDriver
from openrl.drivers.onpolicy_driver import OnPolicyDriver as Driver
from openrl.modules.common import BaseNet
from openrl.runners.common.base_agent import SelfAgent
from openrl.runners.common.ppo_agent import PPOAgent
from openrl.utils.logger import Logger
from openrl.utils.type_aliases import MaybeCallback


class A2CAgent(PPOAgent):
    def __init__(
        self,
        net: Optional[Union[torch.nn.Module, BaseNet]] = None,
        env: Union[gym.Env, str] = None,
        run_dir: Optional[str] = None,
        env_num: Optional[int] = None,
        rank: int = 0,
        world_size: int = 1,
        use_wandb: bool = False,
        use_tensorboard: bool = False,
        project_name: str = "A2CAgent",
    ) -> None:
        super(A2CAgent, self).__init__(
            net,
            env,
            run_dir,
            env_num,
            rank,
            world_size,
            use_wandb,
            use_tensorboard,
            project_name=project_name,
        )

    def train(
        self: SelfAgent,
        total_time_steps: int,
        callback: MaybeCallback = None,
        train_algo_class: Type[BaseAlgorithm] = A2CAlgorithm,
        logger: Optional[Logger] = None,
        driver_class: Type[BaseDriver] = Driver,
    ) -> None:
        # Identical to PPOAgent.train except that A2CAlgorithm is the default algorithm.
        super().train(
            total_time_steps, callback, train_algo_class, logger, driver_class
        )
4 changes: 2 additions & 2 deletions openrl/runners/common/bc_agent.py
@@ -62,12 +62,12 @@ def train(
         callback: MaybeCallback = None,
         train_algo_class: Type[BaseAlgorithm] = BCAlgorithm,
         logger: Optional[Logger] = None,
-        DriverClass: Type[BaseDriver] = Driver,
+        driver_class: Type[BaseDriver] = Driver,
     ) -> None:
         super().train(
             total_time_steps,
             callback,
             train_algo_class,
             logger,
-            DriverClass=DriverClass,
+            driver_class=driver_class,
         )