Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add offline trainer and discrete BCQ algorithm #263

Merged
merged 29 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1cdc203
work
zhujl1991 Dec 15, 2020
ae8c12e
removing_comments
zhujl1991 Dec 15, 2020
b5eba6c
removing_comments
zhujl1991 Dec 15, 2020
5e6873e
format
zhujl1991 Dec 15, 2020
8035e8b
cleaning
zhujl1991 Dec 15, 2020
5c52a18
almost
zhujl1991 Dec 15, 2020
9bf02be
feedback
zhujl1991 Jan 5, 2021
288c154
lint
zhujl1991 Jan 5, 2021
6bc7ac7
Merge branch 'master' into work
Trinkle23897 Jan 6, 2021
7919fc6
Merge branch 'master' into work
Trinkle23897 Jan 6, 2021
48f0012
.
zhujl1991 Jan 6, 2021
9f76425
Merge branch 'work' of github.com:zhujl1991/tianshou into work
zhujl1991 Jan 6, 2021
f4b9aa6
resolve #269, #270
Trinkle23897 Jan 12, 2021
36a137c
update BCQPolicy and BCQN
Trinkle23897 Jan 12, 2021
2e22daf
runnable
Trinkle23897 Jan 13, 2021
e63cb50
polish
Trinkle23897 Jan 13, 2021
21705b7
fix unacessary relu layer in network
Trinkle23897 Jan 14, 2021
d8be9ed
final
Trinkle23897 Jan 14, 2021
8489d17
fix
Trinkle23897 Jan 14, 2021
c2cf972
add atari_bcq, still need check
Trinkle23897 Jan 15, 2021
a3b51a2
update examples
Trinkle23897 Jan 16, 2021
1d34109
Merge branch 'master' into work
Trinkle23897 Jan 16, 2021
667b2f8
tune eps code
Trinkle23897 Jan 16, 2021
04b1379
Merge branch 'work' of github.com:zhujl1991/tianshou into work
Trinkle23897 Jan 16, 2021
151fd0b
fix eps mask
Trinkle23897 Jan 16, 2021
705a919
Merge branch 'master' into work
Trinkle23897 Jan 20, 2021
5ef0c4c
fix test
Trinkle23897 Jan 20, 2021
a37a542
update readme
Trinkle23897 Jan 20, 2021
0b291de
trailing comma
Trinkle23897 Jan 20, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,5 @@ MUJOCO_LOG.TXT
*.zip
*.pstats
*.swp
*.pkl
*.hdf5
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
Expand Down
3 changes: 2 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ Trainer

Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.

Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.
Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.


.. _pseudocode:
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/dqn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ In each step, the collector will let the policy perform (at least) a specified n
Train Policy with a Trainer
---------------------------

Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :func:`~tianshou.trainer.offpolicy_trainer` as follows:
::

result = ts.trainer.offpolicy_trainer(
Expand All @@ -133,7 +133,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
writer=None)
print(f'Finished training! Use {result["duration"]}')

The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):

* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
* ``step_per_epoch``: The number of step for updating policy network in one epoch;
Expand Down
13 changes: 12 additions & 1 deletion examples/atari/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |

Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.

# BCQ

TODO: after the `done` issue fixed, the result should be re-tuned and place here.

To running BCQ algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.

153 changes: 153 additions & 0 deletions examples/atari/atari_bcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteBCQPolicy
from tianshou.data import Collector, ReplayBuffer

from atari_network import DQN
from atari_wrapper import wrap_deepmind


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=6.25e-5)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=8000)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=1000)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5",
)
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
return args


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)


def test_discrete_bcq(args=get_args()):
# envs
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
feature_net = DQN(*args.state_shape, args.action_shape,
device=args.device, features_only=True).to(args.device)
policy_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
imitation_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr,
)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
args.target_update_freq, args.eps_test,
args.unlikely_action_threshold, args.imitation_logits_penalty,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_dqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)

# collector
test_collector = Collector(policy, test_envs)

log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return False

# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)

if args.watch:
watch()
exit(0)

result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
log_interval=args.log_interval,
)

pprint.pprint(result)
watch()


if __name__ == "__main__":
test_discrete_bcq(get_args())
21 changes: 17 additions & 4 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_args():
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()


Expand Down Expand Up @@ -120,13 +121,25 @@ def test_fn(epoch, env_step):

# watch agent's performance
def watch():
print("Testing agent ...")
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = ReplayBuffer(
args.buffer_size, ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)

if args.watch:
Expand Down
3 changes: 1 addition & 2 deletions examples/atari/atari_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def forward(
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(
x, device=self.device, dtype=torch.float32) # type: ignore
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
return self.net(x), state


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_version() -> str:
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=3.1.0"
"h5py>=2.10.0", # to match tensorflow's minimal requirements
],
extras_require={
"dev": [
Expand Down
11 changes: 11 additions & 0 deletions test/discrete/test_dqn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
Expand Down Expand Up @@ -38,6 +39,9 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str,
default="./expert_DQN_CartPole-v0.pkl")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
Expand Down Expand Up @@ -114,6 +118,7 @@ def test_fn(epoch, env_step):
stop_fn=stop_fn, save_fn=save_fn, writer=writer)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand All @@ -124,6 +129,12 @@ def test_fn(epoch, env_step):
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')

# save buffer in pickle format, for imitation learning unittest
buf = ReplayBuffer(args.buffer_size)
collector = Collector(policy, test_envs, buf)
collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))


def test_pdqn(args=get_args()):
args.prioritized_replay = True
Expand Down
Loading