diff --git a/examples/atari/README.md b/examples/atari/README.md
index 51b1af931..d89d9f217 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -95,3 +95,17 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 | MsPacmanNoFrameskip-v4      | 3101        | ![](results/rainbow/MsPacman_rew.png)      | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"`      |
 | SeaquestNoFrameskip-v4      | 2126        | ![](results/rainbow/Seaquest_rew.png)      | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"`      |
 | SpaceInvadersNoFrameskip-v4 | 1794.5      | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
+
+# PPO (single run)
+
+One epoch here is equal to 100,000 env steps; 100 epochs stand for 10M env steps.
+
+| task                        | best reward | reward curve                           | parameters                                                  |
+| --------------------------- | ----------- | -------------------------------------- | ----------------------------------------------------------- |
+| PongNoFrameskip-v4          | 20          | ![](results/ppo/Pong_rew.png)          | `python3 atari_ppo.py --task "PongNoFrameskip-v4"`          |
+| BreakoutNoFrameskip-v4      | 442.1       | ![](results/ppo/Breakout_rew.png)      | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"`      |
+| EnduroNoFrameskip-v4        | 1386.4      | ![](results/ppo/Enduro_rew.png)        | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"`        |
+| QbertNoFrameskip-v4         | 19585       | ![](results/ppo/Qbert_rew.png)         | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"`         |
+| MsPacmanNoFrameskip-v4      | 2319        | ![](results/ppo/MsPacman_rew.png)      | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"`      |
+| SeaquestNoFrameskip-v4      | 1764        | ![](results/ppo/Seaquest_rew.png)      | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4"`      |
+| SpaceInvadersNoFrameskip-v4 | 1184        | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index e12dee00c..ec705e7de 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -22,6 +22,7 @@ def __init__(
         action_shape: Sequence[int],
         device: Union[str, int, torch.device] = "cpu",
         features_only: bool = False,
+        output_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.device = device
@@ -39,6 +40,12 @@ def __init__(
                 nn.Linear(512, np.prod(action_shape))
             )
             self.output_dim = np.prod(action_shape)
+        elif output_dim is not None:
+            self.net = nn.Sequential(
+                self.net, nn.Linear(self.output_dim, output_dim),
+                nn.ReLU(inplace=True)
+            )
+            self.output_dim = output_dim
 
     def forward(
         self,
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
new file mode 100644
index 000000000..668d036fa
--- /dev/null
+++ b/examples/atari/atari_ppo.py
@@ -0,0 +1,297 @@
+import argparse
+import os
+import pprint
+
+import numpy as np
+import torch
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import ShmemVectorEnv
+from tianshou.policy import ICMPolicy, PPOPolicy
+from tianshou.trainer import onpolicy_trainer
+from tianshou.utils import TensorboardLogger, WandbLogger
+from tianshou.utils.net.common import ActorCritic
+from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
+    parser.add_argument('--seed', type=int, default=4213)
+    parser.add_argument('--scale-obs', type=int, default=0)
+    parser.add_argument('--buffer-size', type=int, default=100000)
+    parser.add_argument('--lr', type=float, default=1e-4)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--step-per-collect', type=int, default=1000)
+    parser.add_argument('--repeat-per-collect', type=int, default=4)
+    parser.add_argument('--batch-size', type=int, default=256)
+    parser.add_argument('--hidden-size', type=int, default=512)
+    parser.add_argument('--training-num', type=int, default=10)
+    parser.add_argument('--test-num', type=int, default=10)
+    parser.add_argument('--rew-norm', type=int, default=False)
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--lr-decay', type=int, default=True)
+    parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--eps-clip', type=float, default=0.2)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=int, default=0)
+    parser.add_argument('--norm-adv', type=int, default=1)
+    parser.add_argument('--recompute-adv', type=int, default=0)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+    )
+    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument('--resume-path', type=str, default=None)
+    parser.add_argument('--resume-id', type=str, default=None)
+    parser.add_argument(
+        '--logger',
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+    )
+    parser.add_argument(
+        '--watch',
+        default=False,
+        action='store_true',
+        help='watch the play of pre-trained policy only'
+    )
+    parser.add_argument('--save-buffer-name', type=str, default=None)
+    parser.add_argument(
+        '--icm-lr-scale',
+        type=float,
+        default=0.,
+        help='use intrinsic curiosity module with this lr scale'
+    )
+    parser.add_argument(
+        '--icm-reward-scale',
+        type=float,
+        default=0.01,
+        help='scaling factor for intrinsic curiosity reward'
+    )
+    parser.add_argument(
+        '--icm-forward-loss-weight',
+        type=float,
+        default=0.2,
+        help='weight for the forward model loss in ICM'
+    )
+    return parser.parse_args()
+
+
+def make_atari_env(args):
+    return wrap_deepmind(
+        args.task, frame_stack=args.frames_stack, scale=args.scale_obs
+    )
+
+
+def make_atari_env_watch(args):
+    return wrap_deepmind(
+        args.task,
+        frame_stack=args.frames_stack,
+        episode_life=False,
+        clip_rewards=False,
+        scale=args.scale_obs
+    )
+
+
+def test_ppo(args=get_args()):
+    env = make_atari_env(args)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # should be N_FRAMES x H x W
+    print("Observations shape:", args.state_shape)
+    print("Actions shape:", args.action_shape)
+    # make environments
+    train_envs = ShmemVectorEnv(
+        [lambda: make_atari_env(args) for _ in range(args.training_num)]
+    )
+    test_envs = ShmemVectorEnv(
+        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+    )
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # define model
+    net = DQN(
+        *args.state_shape,
+        args.action_shape,
+        device=args.device,
+        features_only=True,
+        output_dim=args.hidden_size
+    )
+    actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
+    critic = Critic(net, device=args.device)
+    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+
+    lr_scheduler = None
+    if args.lr_decay:
+        # decay learning rate to 0 linearly
+        max_update_num = np.ceil(
+            args.step_per_epoch / args.step_per_collect
+        ) * args.epoch
+
+        lr_scheduler = LambdaLR(
+            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+        )
+
+    # define policy
+    def dist(p):
+        return torch.distributions.Categorical(logits=p)
+
+    policy = PPOPolicy(
+        actor,
+        critic,
+        optim,
+        dist,
+        discount_factor=args.gamma,
+        gae_lambda=args.gae_lambda,
+        max_grad_norm=args.max_grad_norm,
+        vf_coef=args.vf_coef,
+        ent_coef=args.ent_coef,
+        reward_normalization=args.rew_norm,
+        action_scaling=False,
+        lr_scheduler=lr_scheduler,
+        action_space=env.action_space,
+        eps_clip=args.eps_clip,
+        value_clip=args.value_clip,
+        dual_clip=args.dual_clip,
+        advantage_normalization=args.norm_adv,
+        recompute_advantage=args.recompute_adv
+    ).to(args.device)
+    if args.icm_lr_scale > 0:
+        feature_net = DQN(
+            *args.state_shape, args.action_shape, args.device, features_only=True
+        )
+        action_dim = np.prod(args.action_shape)
+        feature_dim = feature_net.output_dim
+        icm_net = IntrinsicCuriosityModule(
+            feature_net.net,
+            feature_dim,
+            action_dim,
+            hidden_sizes=[args.hidden_size],
+            device=args.device
+        )
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
+        policy = ICMPolicy(
+            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
+            args.icm_forward_loss_weight
+        ).to(args.device)
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
+        print("Loaded agent from: ", args.resume_path)
+    # replay buffer: `save_last_obs` and `stack_num` can be removed together
+    # when you have enough RAM
+    buffer = VectorReplayBuffer(
+        args.buffer_size,
+        buffer_num=len(train_envs),
+        ignore_obs_next=True,
+        save_only_last_obs=True,
+        stack_num=args.frames_stack
+    )
+    # collector
+    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+    test_collector = Collector(policy, test_envs, exploration_noise=True)
+    # log
+    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
+    log_path = os.path.join(args.logdir, args.task, log_name)
+    if args.logger == "tensorboard":
+        writer = SummaryWriter(log_path)
+        writer.add_text("args", str(args))
+        logger = TensorboardLogger(writer)
+    else:
+        logger = WandbLogger(
+            save_interval=1,
+            project=args.task,
+            name=log_name,
+            run_id=args.resume_id,
+            config=args,
+        )
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        if env.spec.reward_threshold:
+            return mean_rewards >= env.spec.reward_threshold
+        elif 'Pong' in args.task:
+            return mean_rewards >= 20
+        else:
+            return False
+
+    def save_checkpoint_fn(epoch, env_step, gradient_step):
+        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        torch.save({'model': policy.state_dict()}, ckpt_path)
+        return ckpt_path
+
+    # watch agent's performance
+    def watch():
+        print("Setup test envs ...")
+        policy.eval()
+        test_envs.seed(args.seed)
+        if args.save_buffer_name:
+            print(f"Generate buffer with size {args.buffer_size}")
+            buffer = VectorReplayBuffer(
+                args.buffer_size,
+                buffer_num=len(test_envs),
+                ignore_obs_next=True,
+                save_only_last_obs=True,
+                stack_num=args.frames_stack
+            )
+            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
+            result = collector.collect(n_step=args.buffer_size)
+            print(f"Save buffer into {args.save_buffer_name}")
+            # Unfortunately, pickle will cause oom with 1M buffer size
+            buffer.save_hdf5(args.save_buffer_name)
+        else:
+            print("Testing agent ...")
+            test_collector.reset()
+            result = test_collector.collect(
+                n_episode=args.test_num, render=args.render
+            )
+        rew = result["rews"].mean()
+        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
+
+    if args.watch:
+        watch()
+        exit(0)
+
+    # test train_collector and start filling replay buffer
+    train_collector.collect(n_step=args.batch_size * args.training_num)
+    # trainer
+    result = onpolicy_trainer(
+        policy,
+        train_collector,
+        test_collector,
+        args.epoch,
+        args.step_per_epoch,
+        args.repeat_per_collect,
+        args.test_num,
+        args.batch_size,
+        step_per_collect=args.step_per_collect,
+        stop_fn=stop_fn,
+        save_fn=save_fn,
+        logger=logger,
+        test_in_train=False,
+        resume_from_log=args.resume_id is not None,
+        save_checkpoint_fn=save_checkpoint_fn,
+    )
+
+    pprint.pprint(result)
+    watch()
+
+
+if __name__ == '__main__':
+    test_ppo(get_args())
diff --git a/examples/atari/results/ppo/Breakout_rew.png b/examples/atari/results/ppo/Breakout_rew.png
new file mode 100644
index 000000000..8625787d7
Binary files /dev/null and b/examples/atari/results/ppo/Breakout_rew.png differ
diff --git a/examples/atari/results/ppo/Enduro_rew.png b/examples/atari/results/ppo/Enduro_rew.png
new file mode 100644
index 000000000..50a23fa76
Binary files /dev/null and b/examples/atari/results/ppo/Enduro_rew.png differ
diff --git a/examples/atari/results/ppo/MsPacman_rew.png b/examples/atari/results/ppo/MsPacman_rew.png
new file mode 100644
index 000000000..34836550b
Binary files /dev/null and b/examples/atari/results/ppo/MsPacman_rew.png differ
diff --git a/examples/atari/results/ppo/Pong_rew.png b/examples/atari/results/ppo/Pong_rew.png
new file mode 100644
index 000000000..c52fdc202
Binary files /dev/null and b/examples/atari/results/ppo/Pong_rew.png differ
diff --git a/examples/atari/results/ppo/Qbert_rew.png b/examples/atari/results/ppo/Qbert_rew.png
new file mode 100644
index 000000000..03c83ddac
Binary files /dev/null and b/examples/atari/results/ppo/Qbert_rew.png differ
diff --git a/examples/atari/results/ppo/Seaquest_rew.png b/examples/atari/results/ppo/Seaquest_rew.png
new file mode 100644
index 000000000..675013356
Binary files /dev/null and b/examples/atari/results/ppo/Seaquest_rew.png differ
diff --git a/examples/atari/results/ppo/SpaceInvaders_rew.png b/examples/atari/results/ppo/SpaceInvaders_rew.png
new file mode 100644
index 000000000..4c090a906
Binary files /dev/null and b/examples/atari/results/ppo/SpaceInvaders_rew.png differ
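The `output_dim` argument added to `DQN` in this patch is what lets `atari_ppo.py` share a single convolutional feature extractor between the PPO actor and critic. Below is a minimal sketch of that wiring in isolation (separate from the patch itself), assuming tianshou is installed and the snippet is run from `examples/atari/`; the shapes and sizes (`state_shape`, `action_shape`, `hidden_size`) are illustrative values, not taken from a specific run.

    # Sketch: shared DQN feature torso with separate actor/critic heads.
    import torch
    from atari_network import DQN
    from tianshou.utils.net.common import ActorCritic
    from tianshou.utils.net.discrete import Actor, Critic

    state_shape = (4, 84, 84)   # N_FRAMES x H x W after frame stacking
    action_shape = 6            # e.g. PongNoFrameskip-v4 has 6 discrete actions
    hidden_size = 512           # matches the --hidden-size default above

    # features_only=True drops the Q-value head; output_dim appends a
    # Linear + ReLU projection so the feature vector has a known size.
    net = DQN(
        *state_shape, action_shape, device="cpu",
        features_only=True, output_dim=hidden_size
    )
    actor = Actor(net, action_shape, device="cpu", softmax_output=False)
    critic = Critic(net, device="cpu")
    # One optimizer over both heads plus the shared torso, as in atari_ppo.py.
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=1e-4)

    obs = torch.zeros(1, *state_shape)
    logits, _ = actor(obs)   # shape (1, 6): unnormalized action logits
    value = critic(obs)      # shape (1, 1): state value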