Skip to content

Commit

Permalink
Make trainer resumable (#350)
Browse files Browse the repository at this point in the history
- specify tensorboard >= 2.5.0
- add `save_checkpoint_fn` and `resume_from_log` in trainer

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
  • Loading branch information
StephenArk30 and Trinkle23897 authored May 6, 2021
1 parent f4e05d5 commit 84f5863
Show file tree
Hide file tree
Showing 24 changed files with 308 additions and 77 deletions.
28 changes: 28 additions & 0 deletions docs/tutorials/cheatsheet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,34 @@ Customize Training Process
See :ref:`customized_trainer`.


.. _resume_training:

Resume Training Process
-----------------------

This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

To resume training process from an existing checkpoint, you need to do the following things in the training process:

1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer;
2. Use ``BasicLogger`` which contains a tensorboard;
3. To adjust the save frequency, specify ``save_interval`` when initializing BasicLogger.

And to successfully resume from a checkpoint:

1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
2. Set ``resume_from_log=True`` with trainer;

We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running

.. code-block:: console
$ python3 test/discrete/test_c51.py # train some epoch
$ python3 test/discrete/test_c51.py --resume # restore from existing log and continuing training
To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here <https://github.com/thu-ml/tianshou/pull/350#issuecomment-829123378>`_ for the reason). Otherwise, it may cause overlapping issue that you need to manually handle with.

.. _parallel_sampling:

Parallel Sampling
Expand Down
3 changes: 1 addition & 2 deletions examples/atari/atari_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def test_discrete_bcq(args=get_args()):
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr)
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
Expand Down
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_a2c(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()),
lr=args.lr, eps=1e-5, alpha=0.99)

lr_scheduler = None
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def test_ppo(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_version() -> str:
"gym>=0.15.4",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"tensorboard>=2.5.0",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
Expand Down
1 change: 1 addition & 0 deletions test/continuous/test_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand Down
3 changes: 1 addition & 2 deletions test/continuous/test_npg.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def test_npg(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)

# replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
Expand Down
40 changes: 34 additions & 6 deletions test/continuous/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def get_args():
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args

Expand Down Expand Up @@ -83,8 +85,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)

# replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
Expand Down Expand Up @@ -114,21 +116,42 @@ def dist(*logits):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
logger=logger, resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand All @@ -140,5 +163,10 @@ def stop_fn(mean_rewards):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_ppo_resume(args=get_args()):
args.resume = True
test_ppo(args)


if __name__ == '__main__':
test_ppo()
1 change: 1 addition & 0 deletions test/continuous/test_sac_with_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand Down
1 change: 1 addition & 0 deletions test/continuous/test_td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand Down
6 changes: 3 additions & 3 deletions test/continuous/test_trpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--repeat-per-collect', type=int,
default=2) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
Expand Down Expand Up @@ -82,8 +83,7 @@ def test_trpo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)

# replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
Expand Down
6 changes: 4 additions & 2 deletions test/discrete/test_a2c_with_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def test_a2c_with_il(args=get_args()):
device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist,
Expand Down Expand Up @@ -106,6 +106,7 @@ def stop_fn(mean_rewards):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand Down Expand Up @@ -135,6 +136,7 @@ def stop_fn(mean_rewards):
args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand Down
42 changes: 39 additions & 3 deletions test/discrete/test_c51.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
Expand Down Expand Up @@ -43,9 +44,11 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args

Expand Down Expand Up @@ -90,7 +93,7 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
Expand All @@ -112,14 +115,42 @@ def train_fn(epoch, env_step):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand All @@ -132,6 +163,11 @@ def test_fn(epoch, env_step):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_c51_resume(args=get_args()):
args.resume = True
test_c51(args)


def test_pc51(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
Expand Down
1 change: 0 additions & 1 deletion test/discrete/test_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def test_fn(epoch, env_step):
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion test/discrete/test_drqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def test_fn(epoch, env_step):
args.batch_size, update_per_step=args.update_per_step,
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
Expand Down
Loading

0 comments on commit 84f5863

Please sign in to comment.