Skip to content

Commit

Permalink
ppo benchmark (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenDRAG authored Mar 30, 2021
1 parent 5d580c3 commit 6426a39
Show file tree
Hide file tree
Showing 21 changed files with 424 additions and 177 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [C51](https://arxiv.org/pdf/1707.06887.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
Expand All @@ -39,7 +39,8 @@

Here is Tianshou's other features:

- Elegant framework, using only ~2000 lines of code
- Elegant framework, using only ~3000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/PPO/DDPG/TD3/SAC algorithms
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
Expand Down
3 changes: 2 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
Expand All @@ -30,6 +30,7 @@ Welcome to Tianshou!
Here is Tianshou's other features:

* Elegant framework, using only ~2000 lines of code
* State-of-the-art `MuJoCo benchmark <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_
* Support parallel environment simulation (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling`
* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
* Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
Expand Down
265 changes: 151 additions & 114 deletions examples/mujoco/README.md

Large diffs are not rendered by default.

Binary file added examples/mujoco/benchmark/Ant-v3/ppo/figure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 22 additions & 12 deletions examples/mujoco/mujoco_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
Expand Down Expand Up @@ -36,12 +37,6 @@ def get_args():
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
# a2c special
parser.add_argument('--rew-norm', type=int, default=True)
parser.add_argument('--vf-coef', type=float, default=0.5)
Expand All @@ -50,6 +45,14 @@ def get_args():
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


Expand Down Expand Up @@ -120,6 +123,11 @@ def dist(*logits):
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler, action_space=env.action_space)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
Expand All @@ -138,12 +146,14 @@ def dist(*logits):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
Expand Down
22 changes: 13 additions & 9 deletions examples/mujoco/mujoco_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
Expand Down Expand Up @@ -44,6 +45,8 @@ def get_args():
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


Expand Down Expand Up @@ -87,11 +90,10 @@ def test_ddpg(args=get_args()):
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
estimation_step=args.n_step, action_space=env.action_space)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
Expand All @@ -113,12 +115,14 @@ def test_ddpg(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
if not args.watch:
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
Expand Down
175 changes: 175 additions & 0 deletions examples/mujoco/mujoco_ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#!/usr/bin/env python3

import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--training-num', type=int, default=64)
parser.add_argument('--test-num', type=int, default=10)
# ppo special
parser.add_argument('--rew-norm', type=int, default=True)
# In theory, `vf-coef` will not make any difference if using Adam optimizer.
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=0)
parser.add_argument('--recompute-adv', type=int, default=1)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


def test_ppo(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
norm_obs=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)

# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
unbounded=True, device=args.device).to(args.device)
net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
args.step_per_epoch / args.step_per_collect) * args.epoch

lr_scheduler = LambdaLR(
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits):
return Independent(Normal(*logits), 1)

policy = PPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
reward_normalization=args.rew_norm, action_scaling=True,
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler, action_space=env.action_space,
eps_clip=args.eps_clip, value_clip=args.value_clip,
dual_clip=args.dual_clip, advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ppo'
log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
test_ppo()
32 changes: 21 additions & 11 deletions examples/mujoco/mujoco_reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
Expand Down Expand Up @@ -36,17 +37,19 @@ def get_args():
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=64)
parser.add_argument('--test-num', type=int, default=10)
# reinforce special
parser.add_argument('--rew-norm', type=int, default=True)
# "clip" option also works well.
parser.add_argument('--action-bound-method', type=str, default="tanh")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
# reinforce special
parser.add_argument('--rew-norm', type=int, default=True)
# "clip" option also works well.
parser.add_argument('--action-bound-method', type=str, default="tanh")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


Expand Down Expand Up @@ -110,6 +113,11 @@ def dist(*logits):
action_bound_method=args.action_bound_method,
lr_scheduler=lr_scheduler, action_space=env.action_space)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
Expand All @@ -128,12 +136,14 @@ def dist(*logits):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
Expand Down
Loading

0 comments on commit 6426a39

Please sign in to comment.