Make trainer resumable #350

Merged 17 commits on May 6, 2021
25 changes: 25 additions & 0 deletions docs/tutorials/cheatsheet.rst
@@ -30,6 +30,31 @@ Customize Training Process
See :ref:`customized_trainer`.


.. _resume_training:

Resume Training Process
-----------------------

This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

To resume training from an existing checkpoint, you need to do the following during the training process:

1. Write a ``save_checkpoint_fn`` that saves everything needed to resume training, i.e., policy, optim, and buffer, and pass it to the trainer;
2. Use ``BasicLogger``, which contains a TensorBoard writer;

Then, to successfully resume from a checkpoint:

1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, and buffer;
2. Set ``resume_from_log=True`` when calling the trainer.

We provide examples showing how these steps work: check out `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_, or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running

.. code-block:: console

$ python3 test/discrete/test_c51.py # train for some epochs
$ python3 test/discrete/test_c51.py --resume # restore from the existing log and continue training

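Condensed, the checkpoint callback and the restore logic from the steps above look roughly like the following sketch. It is not a complete script: it assumes ``policy``, ``optim``, ``train_collector``, ``log_path``, ``logger`` (a ``BasicLogger``), and an ``args.resume`` flag are already defined, as in the test files above.

.. code-block:: python

    import os
    import pickle
    import torch

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # called by the trainer; persist policy, optimizer, and replay buffer
        torch.save({
            'model': policy.state_dict(),
            'optim': optim.state_dict(),
        }, os.path.join(log_path, 'checkpoint.pth'))
        with open(os.path.join(log_path, 'train_buffer.pkl'), 'wb') as f:
            pickle.dump(train_collector.buffer, f)

    if args.resume:
        # restore everything *before* the trainer is created
        checkpoint = torch.load(os.path.join(log_path, 'checkpoint.pth'))
        policy.load_state_dict(checkpoint['model'])
        optim.load_state_dict(checkpoint['optim'])
        with open(os.path.join(log_path, 'train_buffer.pkl'), 'rb') as f:
            train_collector.buffer = pickle.load(f)

    # then hand both pieces to the trainer, e.g.
    # result = offpolicy_trainer(..., logger=logger, resume_from_log=args.resume,
    #                            save_checkpoint_fn=save_checkpoint_fn)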

.. _parallel_sampling:

Parallel Sampling
3 changes: 1 addition & 2 deletions examples/atari/atari_bcq.py
@@ -85,8 +85,7 @@ def test_discrete_bcq(args=get_args()):
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr)
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_a2c.py
@@ -101,7 +101,7 @@ def test_a2c(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()),
lr=args.lr, eps=1e-5, alpha=0.99)

lr_scheduler = None
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ppo.py
@@ -106,8 +106,8 @@ def test_ppo(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
1 change: 1 addition & 0 deletions test/continuous/test_ddpg.py
@@ -105,6 +105,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
3 changes: 1 addition & 2 deletions test/continuous/test_npg.py
@@ -80,8 +80,7 @@ def test_npg(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)

# replace DiagGaussian with Independent(Normal), which is equivalent
# pass *logits to be consistent with policy.forward
37 changes: 32 additions & 5 deletions test/continuous/test_ppo.py
@@ -47,6 +47,7 @@ def get_args():
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
args = parser.parse_known_args()[0]
return args

@@ -83,8 +84,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)

# replace DiagGaussian with Independent(Normal), which is equivalent
# pass *logits to be consistent with policy.forward
@@ -122,13 +123,34 @@ def save_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
logger=logger, resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -140,5 +162,10 @@ def stop_fn(mean_rewards):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_ppo_resume(args=get_args()):
args.resume = True
test_ppo(args)


if __name__ == '__main__':
test_ppo()
1 change: 1 addition & 0 deletions test/continuous/test_sac_with_il.py
@@ -124,6 +124,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
1 change: 1 addition & 0 deletions test/continuous/test_td3.py
@@ -119,6 +119,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
6 changes: 3 additions & 3 deletions test/continuous/test_trpo.py
@@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--repeat-per-collect', type=int,
default=2) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
@@ -82,8 +83,7 @@ def test_trpo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)

# replace DiagGaussian with Independent(Normal), which is equivalent
# pass *logits to be consistent with policy.forward
6 changes: 4 additions & 2 deletions test/discrete/test_a2c_with_il.py
@@ -74,8 +74,8 @@ def test_a2c_with_il(args=get_args()):
device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist,
@@ -106,6 +106,7 @@ def stop_fn(mean_rewards):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -135,6 +136,7 @@ def stop_fn(mean_rewards):
args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
39 changes: 37 additions & 2 deletions test/discrete/test_c51.py
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@@ -43,6 +44,7 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -112,14 +114,42 @@ def train_fn(epoch, env_step):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -132,6 +162,11 @@ def test_fn(epoch, env_step):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_c51_resume(args=get_args()):
args.resume = True
test_c51(args)


def test_pc51(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
1 change: 0 additions & 1 deletion test/discrete/test_dqn.py
@@ -120,7 +120,6 @@ def test_fn(epoch, env_step):
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/discrete/test_drqn.py
@@ -99,8 +99,8 @@ def test_fn(epoch, env_step):
args.batch_size, update_per_step=args.update_per_step,
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
31 changes: 28 additions & 3 deletions test/discrete/test_il_bcq.py
@@ -42,6 +42,7 @@ def get_args():
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
parser.add_argument("--resume", action="store_true")
args = parser.parse_known_args()[0]
return args

@@ -67,7 +68,7 @@ def test_discrete_bcq(args=get_args()):
args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
list(policy_net.parameters()) + list(imitation_net.parameters()),
lr=args.lr)

policy = DiscreteBCQPolicy(
@@ -93,11 +94,30 @@ def save_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")

result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
@@ -112,5 +132,10 @@ def stop_fn(mean_rewards):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_discrete_bcq_resume(args=get_args()):
args.resume = True
test_discrete_bcq(args)


if __name__ == "__main__":
test_discrete_bcq(get_args())
1 change: 1 addition & 0 deletions test/discrete/test_pg.py
@@ -93,6 +93,7 @@ def stop_fn(mean_rewards):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
5 changes: 3 additions & 2 deletions test/discrete/test_ppo.py
@@ -75,8 +75,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist,
@@ -113,6 +113,7 @@ def stop_fn(mean_rewards):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
2 changes: 1 addition & 1 deletion test/discrete/test_qrdqn.py
@@ -117,8 +117,8 @@ def test_fn(epoch, env_step):
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!