Make trainer resumable #350

Merged · 17 commits · May 6, 2021
Changes from 4 commits
25 changes: 25 additions & 0 deletions docs/tutorials/cheatsheet.rst
@@ -30,6 +30,31 @@ Customize Training Process
See :ref:`customized_trainer`.


.. _resume_training:

Resume Training Process
-----------------------

This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

To be able to resume from an existing checkpoint, you need to do the following during the training process:

1. Write a ``save_train_fn`` that saves everything needed to resume the training process (i.e., policy, optim, buffer), and pass it to the trainer;
2. Use ``BasicLogger``, which writes the training log to TensorBoard;

And to successfully resume from a checkpoint:

1. Load everything needed in the training process (i.e., policy, optim, buffer) **before trainer initialization**;
2. Pass ``resume_from_log=True`` to the trainer; a minimal sketch of both pieces follows.

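The following is a minimal sketch of both pieces for an off-policy algorithm. It assumes that ``policy``, ``optim``, the collectors, ``logger``, ``log_path``, the usual hyper-parameters and a ``resume`` flag have already been set up as in the test scripts below; the variable names are illustrative, and only ``save_train_fn`` and ``resume_from_log`` are the new trainer arguments.

.. code-block:: python

    import os
    import pickle

    import torch

    from tianshou.trainer import offpolicy_trainer

    def save_train_fn(epoch, env_step, gradient_step):
        # save everything needed to resume: policy, optimizer and replay buffer
        torch.save({
            'model': policy.state_dict(),
            'optim': optim.state_dict(),
        }, os.path.join(log_path, 'checkpoint.pth'))
        with open(os.path.join(log_path, 'train_buffer.pkl'), 'wb') as f:
            pickle.dump(train_collector.buffer, f)

    if resume:
        # restore everything *before* the trainer is constructed
        checkpoint = torch.load(os.path.join(log_path, 'checkpoint.pth'))
        policy.load_state_dict(checkpoint['model'])
        optim.load_state_dict(checkpoint['optim'])
        with open(os.path.join(log_path, 'train_buffer.pkl'), 'rb') as f:
            train_collector.buffer = pickle.load(f)

    result = offpolicy_trainer(
        policy, train_collector, test_collector, max_epoch, step_per_epoch,
        step_per_collect, episode_per_test, batch_size,
        save_train_fn=save_train_fn, resume_from_log=resume, logger=logger)
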
We provide full examples showing how these steps work: check out `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_, e.g. by running:

.. code-block:: console

$ python3 test/discrete/test_c51.py # train for some epochs
$ python3 test/discrete/test_c51.py --resume # restore from the existing log and continue training

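With ``resume_from_log=True`` the trainer also recovers its own bookkeeping (best reward so far, start epoch and the ``env_step``/``gradient_step`` counters) from the existing TensorBoard log, so the epoch loop continues where it stopped instead of restarting at epoch 1. Simplified from ``tianshou/trainer/offpolicy.py``:

.. code-block:: python

    if resume_from_log:
        best_epoch, best_reward, best_reward_std, start_epoch, env_step, \
            gradient_step, last_rew, last_len = logger.restore_data()

    # the epoch loop then picks up after the last finished epoch
    for epoch in range(1 + start_epoch, 1 + max_epoch):
        ...  # collect, update and test as usual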

.. _parallel_sampling:

Parallel Sampling
28 changes: 27 additions & 1 deletion test/continuous/test_ppo.py
@@ -47,6 +47,7 @@ def get_args():
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
args = parser.parse_known_args()[0]
return args

@@ -122,13 +123,33 @@ def save_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def save_train_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
logger=logger, resume_from_log=args.resume, save_train_fn=save_train_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -140,5 +161,10 @@ def stop_fn(mean_rewards):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_ppo_resume(args=get_args()):
args.resume = True
test_ppo(args)


if __name__ == '__main__':
test_ppo()
38 changes: 37 additions & 1 deletion test/discrete/test_c51.py
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@@ -43,6 +44,7 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -112,14 +114,43 @@ def train_fn(epoch, env_step):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

def save_train_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_train_fn=save_train_fn)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -132,6 +163,11 @@ def test_fn(epoch, env_step):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_c51_resume(args=get_args()):
args.resume = True
test_c51(args)


def test_pc51(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
28 changes: 27 additions & 1 deletion test/discrete/test_il_bcq.py
@@ -42,6 +42,7 @@ def get_args():
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
parser.add_argument("--resume", action="store_true")
args = parser.parse_known_args()[0]
return args

@@ -93,10 +94,30 @@ def save_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def save_train_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
# optim.load_state_dict(checkpoint['optim']) # don't know why
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")

result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_train_fn=save_train_fn)

assert stop_fn(result['best_reward'])

@@ -112,5 +133,10 @@ def stop_fn(mean_rewards):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_discrete_bcq_resume(args=get_args()):
args.resume = True
test_discrete_bcq(args)


if __name__ == "__main__":
test_discrete_bcq(get_args())
40 changes: 28 additions & 12 deletions tianshou/trainer/offline.py
@@ -21,6 +21,9 @@ def offline_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_train_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
epoch_per_save: int = 1,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
Expand All @@ -44,6 +47,14 @@ def offline_trainer(
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_train_fn: a function to save the training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want. Because offline RL has no env_step, ``env_step`` is
always 0 here.
:param int epoch_per_save: save the training process by calling ``save_train_fn``
every ``epoch_per_save`` epochs. Default to 1.
:param bool resume_from_log: resume ``gradient_step`` and other metadata from the
existing TensorBoard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
@@ -59,15 +70,20 @@

:return: See :func:`~tianshou.trainer.gather_info`.
"""
gradient_step = 0
start_epoch, gradient_step = 0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
test_collector.reset_stat()
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
logger, gradient_step, reward_metric)
best_epoch = 0
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
if resume_from_log:
best_epoch, best_reward, best_reward_std, start_epoch, _, \
gradient_step, _, _ = logger.restore_data()
else:
test_result = test_episode(policy, test_collector, test_fn, 0,
episode_per_test, logger, 0, reward_metric)
best_epoch = 0
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

for epoch in range(1 + start_epoch, 1 + max_epoch):
policy.train()
with tqdm.trange(
update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
@@ -87,15 +103,15 @@
policy, test_collector, test_fn, epoch, episode_per_test,
logger, gradient_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if epoch_per_save > 0 and epoch % epoch_per_save == 0 and save_train_fn:
save_train_fn(epoch, 0, gradient_step)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
45 changes: 31 additions & 14 deletions tianshou/trainer/offpolicy.py
@@ -6,8 +6,8 @@

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger


def offpolicy_trainer(
Expand All @@ -24,6 +24,9 @@ def offpolicy_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_train_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
epoch_per_save: int = 1,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
@@ -57,8 +60,15 @@
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy:BasePolicy) ->
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_train_fn: a function to save the training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want.
:param int epoch_per_save: save the training process by calling ``save_train_fn``
every ``epoch_per_save`` epochs. Default to 1.
:param bool resume_from_log: resume ``env_step``/``gradient_step`` and other
metadata from the existing TensorBoard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
@@ -75,18 +85,23 @@

:return: See :func:`~tianshou.trainer.gather_info`.
"""
env_step, gradient_step = 0, 0
start_epoch, env_step, gradient_step = 0, 0, 0
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
logger, env_step, reward_metric)
best_epoch = 0
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
if resume_from_log:
best_epoch, best_reward, best_reward_std, start_epoch, env_step, \
gradient_step, last_rew, last_len = logger.restore_data()
else:
test_result = test_episode(policy, test_collector, test_fn, 0,
episode_per_test, logger, 0, reward_metric)
best_epoch = 0
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

for epoch in range(1 + start_epoch, 1 + max_epoch):
# train
policy.train()
with tqdm.tqdm(
@@ -118,6 +133,8 @@
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
if save_train_fn:
save_train_fn(epoch, env_step, gradient_step)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
@@ -139,15 +156,15 @@
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if epoch_per_save > 0 and epoch % epoch_per_save == 0 and save_train_fn:
save_train_fn(epoch, env_step, gradient_step)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,