diff --git a/examples/atari/README.md b/examples/atari/README.md
index d0154502d..561255b20 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -1,8 +1,20 @@
-# Atari
+# Atari Environment
 
-The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.
+## EnvPool
 
-The Atari env seed cannot be fixed due to the discussion [here](https://github.com/openai/gym/issues/1478), but it is not a big issue since on Atari it will always have the similar results.
+We highly recommend using EnvPool to run the following experiments. To install it on a Linux machine, run:
+
+```bash
+pip install envpool
+```
+
+After that, `atari_wrapper` will automatically switch to EnvPool's Atari env. EnvPool's implementation is much faster than the Python vectorized env implementation (about 2\~3x faster in pure execution speed, and about 1.5x faster for the overall RL training pipeline), and its behavior is consistent with the OpenAI-wrapper approach described below.
+
+For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
+
+## ALE-py
+
+The sample speed is \~3000 env steps per second (\~12000 Atari frames per second in fact, since we use frame_stack=4) under the normal mode (using a CNN policy and a collector, and also storing data into the buffer).
 
 The env wrapper is a crucial thing. Without wrappers, the agent cannot perform well enough on Atari games. Many existing RL codebases use [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine has a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py) but unfortunately it cannot work very well in our codebase.
 
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index dcd1911dc..4d5a91c2b 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 from atari_network import C51
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -19,6 +18,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--scale-obs', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.005)
     parser.add_argument('--eps-train', type=float, default=1.)
     parser.add_argument('--eps-train-final', type=float, default=0.05)
@@ -54,38 +54,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_c51(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
     )
-
-
-def test_c51(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
@@ -198,7 +183,7 @@ def watch():
         save_fn=save_fn,
         logger=logger,
         update_per_step=args.update_per_step,
-        test_in_train=False
+        test_in_train=False,
     )
 
     pprint.pprint(result)
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index 685ebb40b..4bae24baa 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 from atari_network import DQN
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.policy.modelbased.icm import ICMPolicy
 from tianshou.trainer import offpolicy_trainer
@@ -21,6 +20,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--scale-obs', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.005)
     parser.add_argument('--eps-train', type=float, default=1.)
     parser.add_argument('--eps-train-final', type=float, default=0.05)
@@ -78,38 +78,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_dqn(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
     )
-
-
-def test_dqn(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index 99f8957c4..4a8133556 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 from atari_network import DQN
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import FQFPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -20,6 +19,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=3128)
+    parser.add_argument('--scale-obs', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.005)
     parser.add_argument('--eps-train', type=float, default=1.)
     parser.add_argument('--eps-train-final', type=float, default=0.05)
@@ -57,38 +57,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_fqf(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
     )
-
-
-def test_fqf(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     feature_net = DQN(
         *args.state_shape, args.action_shape, args.device, features_only=True
@@ -215,7 +200,7 @@ def watch():
         save_fn=save_fn,
         logger=logger,
         update_per_step=args.update_per_step,
-        test_in_train=False
+        test_in_train=False,
     )
 
     pprint.pprint(result)
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index 532d59482..0ae1d94a7 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 from atari_network import DQN
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import IQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -20,6 +19,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=1234)
+    parser.add_argument('--scale-obs', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.005)
     parser.add_argument('--eps-train', type=float, default=1.)
     parser.add_argument('--eps-train-final', type=float, default=0.05)
@@ -57,38 +57,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_iqn(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
     )
-
-
-def test_iqn(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     feature_net = DQN(
         *args.state_shape, args.action_shape, args.device, features_only=True
@@ -210,7 +195,7 @@ def watch():
         save_fn=save_fn,
         logger=logger,
         update_per_step=args.update_per_step,
-        test_in_train=False
+        test_in_train=False,
     )
 
     pprint.pprint(result)
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index b123f1078..94e2c3f47 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -5,12 +5,11 @@
 import numpy as np
 import torch
 from atari_network import DQN
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import ICMPolicy, PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -87,41 +86,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(
-        args.task, frame_stack=args.frames_stack, scale=args.scale_obs
-    )
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_ppo(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False,
-        scale=args.scale_obs
     )
-
-
-def test_ppo(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     net = DQN(
         *args.state_shape,
@@ -167,7 +148,7 @@ def dist(p):
         value_clip=args.value_clip,
         dual_clip=args.dual_clip,
         advantage_normalization=args.norm_adv,
-        recompute_advantage=args.recompute_adv
+        recompute_advantage=args.recompute_adv,
     ).to(args.device)
     if args.icm_lr_scale > 0:
         feature_net = DQN(
@@ -180,7 +161,7 @@ def dist(p):
             feature_dim,
             action_dim,
             hidden_sizes=args.hidden_sizes,
-            device=args.device
+            device=args.device,
         )
         icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
         policy = ICMPolicy(
@@ -198,7 +179,7 @@ def dist(p):
         buffer_num=len(train_envs),
         ignore_obs_next=True,
         save_only_last_obs=True,
-        stack_num=args.frames_stack
+        stack_num=args.frames_stack,
     )
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
@@ -248,7 +229,7 @@ def watch():
             buffer_num=len(test_envs),
             ignore_obs_next=True,
             save_only_last_obs=True,
-            stack_num=args.frames_stack
+            stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
         result = collector.collect(n_step=args.buffer_size)
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index af5d78e3f..1ae4a8c88 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 from atari_network import QRDQN
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import QRDQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -19,6 +18,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--scale-obs', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.005)
     parser.add_argument('--eps-train', type=float, default=1.)
     parser.add_argument('--eps-train-final', type=float, default=0.05)
@@ -52,38 +52,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_qrdqn(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
     )
-
-
-def test_qrdqn(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
@@ -194,7 +179,7 @@ def watch():
         save_fn=save_fn,
         logger=logger,
         update_per_step=args.update_per_step,
-        test_in_train=False
+        test_in_train=False,
    )
 
     pprint.pprint(result)
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 4e1a78ced..478a6d2b8 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -6,11 +6,10 @@
 import numpy as np
 import torch
 from atari_network import Rainbow
-from atari_wrapper import wrap_deepmind
+from atari_wrapper import make_atari_env
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
 from tianshou.policy import RainbowPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -20,6 +19,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--scale-obs', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.005)
     parser.add_argument('--eps-train', type=float, default=1.)
     parser.add_argument('--eps-train-final', type=float, default=0.05)
@@ -64,38 +64,23 @@ def get_args():
     return parser.parse_args()
 
 
-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_rainbow(args=get_args()):
+    env, train_envs, test_envs = make_atari_env(
         args.task,
+        args.seed,
+        args.training_num,
+        args.test_num,
+        scale=args.scale_obs,
         frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
     )
-
-
-def test_rainbow(args=get_args()):
-    env = make_atari_env(args)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # should be N_FRAMES x H x W
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
-    # make environments
-    train_envs = ShmemVectorEnv(
-        [lambda: make_atari_env(args) for _ in range(args.training_num)]
-    )
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
     # define model
     net = Rainbow(
         *args.state_shape,
@@ -242,7 +227,7 @@ def watch():
         save_fn=save_fn,
         logger=logger,
         update_per_step=args.update_per_step,
-        test_in_train=False
+        test_in_train=False,
     )
 
     pprint.pprint(result)
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index ad58ad231..4aca61218 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -1,12 +1,20 @@
 # Borrow a lot from openai baselines:
 # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+import warnings
 from collections import deque
 
 import cv2
 import gym
 import numpy as np
 
+try:
+    import envpool
+except ImportError:
+    envpool = None
+
+from tianshou.env import ShmemVectorEnv
+
 
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
@@ -245,3 +253,59 @@ def wrap_deepmind(
     if frame_stack:
         env = FrameStack(env, frame_stack)
     return env
+
+
+def make_atari_env(task, seed, training_num, test_num, **kwargs):
+    """Wrapper function for Atari env.
+
+    If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
+
+    :return: a tuple of (single env, training envs, test envs).
+    """
+    if envpool is not None:
+        if kwargs.get("scale", 0):
+            warnings.warn(
+                "EnvPool does not include ScaledFloatFrame wrapper, "
+                "please set `x = x / 255.0` inside the CNN network's forward function."
+            )
+        # parameter conversion
+        train_envs = env = envpool.make_gym(
+            task.replace("NoFrameskip-v4", "-v5"),
+            num_envs=training_num,
+            seed=seed,
+            episodic_life=True,
+            reward_clip=True,
+            stack_num=kwargs.get("frame_stack", 4),
+        )
+        test_envs = envpool.make_gym(
+            task.replace("NoFrameskip-v4", "-v5"),
+            num_envs=test_num,
+            seed=seed,
+            episodic_life=False,
+            reward_clip=False,
+            stack_num=kwargs.get("frame_stack", 4),
+        )
+    else:
+        warnings.warn(
+            "Recommend using envpool (pip install envpool) "
+            "to run Atari games more efficiently."
+        )
+        env = wrap_deepmind(task, **kwargs)
+        train_envs = ShmemVectorEnv(
+            [
+                lambda:
+                wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
+                for _ in range(training_num)
+            ]
+        )
+        test_envs = ShmemVectorEnv(
+            [
+                lambda:
+                wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
+                for _ in range(test_num)
+            ]
+        )
+    env.seed(seed)
+    train_envs.seed(seed)
+    test_envs.seed(seed)
+    return env, train_envs, test_envs
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index 1763380f1..e952294f5 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -19,7 +19,7 @@
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
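
The `make_atari_env` helper added to `atari_wrapper.py` above returns a tuple of (single env, training envs, test envs) and transparently picks EnvPool when it is installed, falling back to `wrap_deepmind` plus `ShmemVectorEnv` otherwise. Below is a minimal usage sketch based on that signature; the task name, seed, and env counts are illustrative values, not ones prescribed by the patch.

```python
# Minimal usage sketch of the make_atari_env helper added above.
# Assumes it is run from examples/atari/ so that atari_wrapper is importable.
from atari_wrapper import make_atari_env

env, train_envs, test_envs = make_atari_env(
    "PongNoFrameskip-v4",  # rewritten to "Pong-v5" when EnvPool is used
    seed=0,
    training_num=10,
    test_num=10,
    scale=0,           # keep frames as uint8; EnvPool has no ScaledFloatFrame
    frame_stack=4,
)
print("Observations shape:", env.observation_space.shape)
print("Number of training envs:", len(train_envs))
```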
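The warning inside `make_atari_env` notes that EnvPool ships no ScaledFloatFrame wrapper, so when scaling is requested the observations stay as raw uint8 frames and the division by 255 has to happen inside the network's forward pass. The sketch below illustrates that pattern with a toy module; it is not the `DQN` class from `atari_network.py`.

```python
# Illustrative sketch of scaling observations inside forward(), as suggested by
# the warning in make_atari_env. The ScaledCNN module here is hypothetical.
import torch
import torch.nn as nn


class ScaledCNN(nn.Module):
    def __init__(self, channels: int, action_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Flatten(),
            nn.LazyLinear(action_dim),
        )

    def forward(self, obs) -> torch.Tensor:
        x = torch.as_tensor(obs, dtype=torch.float32)
        x = x / 255.0  # normalize raw uint8 frames here instead of in an env wrapper
        return self.net(x)
```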