Skip to content

Commit

Permalink
update examples/atari with envpool
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 committed Feb 24, 2022
1 parent 6bc2055 commit be5589f
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 177 deletions.
18 changes: 15 additions & 3 deletions examples/atari/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
# Atari
# Atari Environment

The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.
## EnvPool

The Atari env seed cannot be fixed due to the discussion [here](https://github.com/openai/gym/issues/1478), but it is not a big issue since on Atari it will always have the similar results.
We highly recommend using envpool to run the following experiments. To install, in a linux machine, type:

```bash
pip install envpool
```

After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for overall RL training pipeline) than python vectorized env implementation, and it's behavior is consistent to that approach (OpenAI wrapper), which will describe below.

For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).

## ALE-py

The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer).

The env wrapper is a crucial thing. Without wrappers, the agent cannot perform well enough on Atari games. Many existing RL codebases use [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine has a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py) but unfortunately it cannot work very well in our codebase.

Expand Down
33 changes: 9 additions & 24 deletions examples/atari/atari_c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import numpy as np
import torch
from atari_network import C51
from atari_wrapper import wrap_deepmind
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
Expand All @@ -19,6 +18,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
Expand Down Expand Up @@ -54,38 +54,23 @@ def get_args():
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(
def test_c51(args=get_args()):
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
args.training_num,
args.test_num,
scale=args.scale_obs,
frame_stack=args.frames_stack,
episode_life=False,
clip_rewards=False
)


def test_c51(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
Expand Down Expand Up @@ -198,7 +183,7 @@ def watch():
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
test_in_train=False,
)

pprint.pprint(result)
Expand Down
31 changes: 8 additions & 23 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import offpolicy_trainer
Expand All @@ -21,6 +20,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
Expand Down Expand Up @@ -78,38 +78,23 @@ def get_args():
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(
def test_dqn(args=get_args()):
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
args.training_num,
args.test_num,
scale=args.scale_obs,
frame_stack=args.frames_stack,
episode_life=False,
clip_rewards=False
)


def test_dqn(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
Expand Down
33 changes: 9 additions & 24 deletions examples/atari/atari_fqf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import FQFPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
Expand All @@ -20,6 +19,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=3128)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
Expand Down Expand Up @@ -57,38 +57,23 @@ def get_args():
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(
def test_fqf(args=get_args()):
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
args.training_num,
args.test_num,
scale=args.scale_obs,
frame_stack=args.frames_stack,
episode_life=False,
clip_rewards=False
)


def test_fqf(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
feature_net = DQN(
*args.state_shape, args.action_shape, args.device, features_only=True
Expand Down Expand Up @@ -215,7 +200,7 @@ def watch():
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
test_in_train=False,
)

pprint.pprint(result)
Expand Down
33 changes: 9 additions & 24 deletions examples/atari/atari_iqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import IQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
Expand All @@ -20,6 +19,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
Expand Down Expand Up @@ -57,38 +57,23 @@ def get_args():
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(
def test_iqn(args=get_args()):
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
args.training_num,
args.test_num,
scale=args.scale_obs,
frame_stack=args.frames_stack,
episode_life=False,
clip_rewards=False
)


def test_iqn(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
feature_net = DQN(
*args.state_shape, args.action_shape, args.device, features_only=True
Expand Down Expand Up @@ -210,7 +195,7 @@ def watch():
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
test_in_train=False,
)

pprint.pprint(result)
Expand Down
Loading

0 comments on commit be5589f

Please sign in to comment.