Skip to content

Commit

Permalink
fix atari_bcq (thu-ml#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Apr 20, 2021
1 parent 21764f5 commit 3b1d274
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 47 deletions.
9 changes: 7 additions & 2 deletions examples/atari/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

# BCQ

TODO: after the `done` issue fixed, the result should be re-tuned and place here.

To running BCQ algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
46 changes: 20 additions & 26 deletions examples/atari/atari_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteBCQPolicy
from tianshou.data import Collector, ReplayBuffer
from tianshou.data import Collector, VectorReplayBuffer

from atari_network import DQN
from atari_wrapper import wrap_deepmind
Expand All @@ -25,17 +25,16 @@ def get_args():
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=6.25e-5)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--target-update-freq", type=int, default=8000)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--update-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
Expand All @@ -44,12 +43,10 @@ def get_args():
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5",
)
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_known_args()[0]
return args

Expand Down Expand Up @@ -81,33 +78,32 @@ def test_discrete_bcq(args=get_args()):
# model
feature_net = DQN(*args.state_shape, args.action_shape,
device=args.device, features_only=True).to(args.device)
policy_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
imitation_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
policy_net = Actor(
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
imitation_net = Actor(
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr,
)
lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
args.target_update_freq, args.eps_test,
args.unlikely_action_threshold, args.imitation_logits_penalty,
)
args.unlikely_action_threshold, args.imitation_logits_penalty)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_dqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)
Expand Down Expand Up @@ -146,11 +142,9 @@ def watch():
exit(0)

result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
log_interval=args.log_interval,
)
policy, buffer, test_collector, args.epoch,
args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

pprint.pprint(result)
watch()
Expand Down
25 changes: 21 additions & 4 deletions examples/atari/atari_c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_args():
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()


Expand Down Expand Up @@ -128,13 +129,29 @@ def test_fn(epoch, env_step):

# watch agent's performance
def watch():
print("Testing agent ...")
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
Expand Down
6 changes: 4 additions & 2 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def watch():
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
Expand All @@ -144,7 +145,8 @@ def watch():
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
Expand Down
25 changes: 21 additions & 4 deletions examples/atari/atari_qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_args():
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()


Expand Down Expand Up @@ -126,13 +127,29 @@ def test_fn(epoch, env_step):

# watch agent's performance
def watch():
print("Testing agent ...")
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
Expand Down
6 changes: 4 additions & 2 deletions test/discrete/test_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,11 @@ def test_fn(epoch, env_step):

# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
collector = Collector(policy, test_envs, buf)
collector.collect(n_step=args.buffer_size)
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())


def test_pdqn(args=get_args()):
Expand Down
10 changes: 6 additions & 4 deletions test/discrete/test_il_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ def get_args():
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--gamma", type=float, default=0.9)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.6)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--update-per-epoch", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128])
nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
Expand All @@ -49,6 +49,8 @@ def get_args():
def test_discrete_bcq(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
Expand Down
6 changes: 3 additions & 3 deletions tianshou/policy/imitation/discrete_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

return {
"loss": loss.item(),
"q_loss": q_loss.item(),
"i_loss": i_loss.item(),
"reg_loss": reg_loss.item(),
"loss/q": q_loss.item(),
"loss/i": i_loss.item(),
"loss/reg": reg_loss.item(),
}

0 comments on commit 3b1d274

Please sign in to comment.