TRPO benchmark release (thu-ml#340)
ChenDRAG authored Apr 19, 2021
1 parent 13f946d commit 21764f5
Showing 15 changed files with 213 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -26,7 +26,7 @@
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Trust Region Policy Optimization](https://arxiv.org/pdf/1502.05477.pdf)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
@@ -41,7 +41,7 @@
Here are Tianshou's other features:

- Elegant framework, using only ~3000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/PPO/DDPG/TD3/SAC algorithms
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
45 changes: 35 additions & 10 deletions examples/mujoco/README.md

Large diffs are not rendered by default.

Binary file added examples/mujoco/benchmark/Ant-v3/trpo/figure.png
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_a2c.py
@@ -33,7 +33,7 @@ def get_args():
    parser.add_argument('--step-per-epoch', type=int, default=30000)
    parser.add_argument('--step-per-collect', type=int, default=80)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    # batch-size >> step-per-collect means caculating all data in one singe forward.
    # batch-size >> step-per-collect means calculating all data in one single forward.
    parser.add_argument('--batch-size', type=int, default=99999)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
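As a side note on the comment fixed above: setting `--batch-size` far above `--step-per-collect` makes the on-policy update consume everything collected as a single minibatch, i.e. one forward pass per repeat. A minimal sketch of that effect (illustrative only; the transition count and array shapes are made up):

```python
import numpy as np

from tianshou.data import Batch

# 80 transitions collected in one step-per-collect round (shapes are illustrative).
collected = Batch(obs=np.zeros((80, 17), dtype=np.float32),
                  act=np.zeros((80, 6), dtype=np.float32))

# batch-size=99999 >> 80, so split() yields exactly one chunk:
chunks = list(collected.split(99999, shuffle=False, merge_last=True))
print(len(chunks))     # 1
print(len(chunks[0]))  # 80 -> the whole collected data goes through a single forward
```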
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_reinforce.py
@@ -33,7 +33,7 @@ def get_args():
    parser.add_argument('--step-per-epoch', type=int, default=30000)
    parser.add_argument('--step-per-collect', type=int, default=2048)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    # batch-size >> step-per-collect means caculating all data in one singe forward.
    # batch-size >> step-per-collect means calculating all data in one single forward.
    parser.add_argument('--batch-size', type=int, default=99999)
    parser.add_argument('--training-num', type=int, default=64)
    parser.add_argument('--test-num', type=int, default=10)
173 changes: 173 additions & 0 deletions examples/mujoco/mujoco_trpo.py
@@ -0,0 +1,173 @@
#!/usr/bin/env python3

import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import TRPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='HalfCheetah-v3')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=4096)
    parser.add_argument('--hidden-sizes', type=int, nargs='*',
                        default=[64, 64])  # baselines [32, 32]
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=30000)
    parser.add_argument('--step-per-collect', type=int, default=1024)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    # batch-size >> step-per-collect means calculating all data in one single forward.
    parser.add_argument('--batch-size', type=int, default=99999)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    # trpo special
    parser.add_argument('--rew-norm', type=int, default=True)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    # TODO tanh support
    parser.add_argument('--bound-action-method', type=str, default="clip")
    parser.add_argument('--lr-decay', type=int, default=True)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--optim-critic-iters', type=int, default=20)
    parser.add_argument('--max-kl', type=float, default=0.01)
    parser.add_argument('--backtrack-coeff', type=float, default=0.8)
    parser.add_argument('--max-backtracks', type=int, default=10)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def test_trpo(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        norm_obs=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh, device=args.device)
    actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
                      unbounded=True, device=args.device).to(args.device)
    net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh, device=args.device)
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performance,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = TRPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
                        gae_lambda=args.gae_lambda,
                        reward_normalization=args.rew_norm, action_scaling=True,
                        action_bound_method=args.bound_action_method,
                        lr_scheduler=lr_scheduler, action_space=env.action_space,
                        advantage_normalization=args.norm_adv,
                        optim_critic_iters=args.optim_critic_iters,
                        max_kl=args.max_kl,
                        backtrack_coeff=args.backtrack_coeff,
                        max_backtracks=args.max_backtracks)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_trpo'
    log_path = os.path.join(args.logdir, args.task, 'trpo', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=100, train_interval=100)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(
            policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
            args.repeat_per_collect, args.test_num, args.batch_size,
            step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
            test_in_train=False)
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_trpo()
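Assuming the same conventions as the other MuJoCo example scripts in this directory, a benchmark run would be launched with something like `python mujoco_trpo.py --task Ant-v3 --seed 0`; TensorBoard logs and the saved `policy.pth` then land under `log/<task>/trpo/`, and adding `--watch --resume-path <log_path>/policy.pth` skips training and only evaluates the saved policy.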
3 changes: 1 addition & 2 deletions tianshou/policy/modelfree/trpo.py
@@ -178,8 +178,7 @@ def MVP(v: torch.Tensor) -> torch.Tensor: # matrix vector product

if kl < self._delta and new_actor_loss < actor_loss:
    if i > 0:
        warnings.warn(f"Backtracking to step {i}. "
                      "Hyperparamters aren't good enough.")
        warnings.warn(f"Backtracking to step {i}.")
    break
elif i < self._max_backtracks - 1:
    step_size = step_size * self._backtrack_coeff
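For readers skimming this hunk out of context, here is a minimal standalone sketch of the backtracking line search it belongs to. This is not Tianshou's actual implementation; the helpers `set_flat_params`, `surrogate_loss`, and `mean_kl` are hypothetical stand-ins passed in by the caller.

```python
def backtracking_line_search(flat_params, search_direction, step_size,
                             set_flat_params, surrogate_loss, mean_kl,
                             old_loss, max_kl=0.01, backtrack_coeff=0.8,
                             max_backtracks=10):
    """Shrink the step until the KL constraint holds and the surrogate loss improves."""
    for i in range(max_backtracks):
        candidate = flat_params + step_size * search_direction
        set_flat_params(candidate)          # try the candidate parameters
        if mean_kl() < max_kl and surrogate_loss() < old_loss:
            return candidate, i             # accept; i > 0 means we had to backtrack
        if i < max_backtracks - 1:
            step_size = step_size * backtrack_coeff  # shrink the step and retry
        else:
            set_flat_params(flat_params)    # line search failed: restore old parameters
    return flat_params, max_backtracks
```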
