diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index e1488688c..c5c07bf3f 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -7,6 +7,6 @@ - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python - import tianshou, torch, numpy, sys - print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) + import tianshou, gym, torch, numpy, sys + print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) ``` diff --git a/Makefile b/Makefile index 616c4c983..b4c4a623e 100644 --- a/Makefile +++ b/Makefile @@ -22,10 +22,8 @@ lint: flake8 ${LINT_PATHS} --count --show-source --statistics format: - # sort imports $(call check_install, isort) isort ${LINT_PATHS} - # reformat using yapf $(call check_install, yapf) yapf -ir ${LINT_PATHS} @@ -57,6 +55,6 @@ doc-clean: clean: doc-clean -commit-checks: format lint mypy check-docstyle spelling +commit-checks: lint check-codestyle mypy check-docstyle spelling .PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks diff --git a/docs/api/tianshou.trainer.rst b/docs/api/tianshou.trainer.rst index 9deed5053..13c6d66c9 100644 --- a/docs/api/tianshou.trainer.rst +++ b/docs/api/tianshou.trainer.rst @@ -1,7 +1,49 @@ tianshou.trainer ================ -.. automodule:: tianshou.trainer + +On-policy +--------- + +.. autoclass:: tianshou.trainer.OnpolicyTrainer + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: tianshou.trainer.onpolicy_trainer + +.. autoclass:: tianshou.trainer.onpolicy_trainer_iter + + +Off-policy +---------- + +.. autoclass:: tianshou.trainer.OffpolicyTrainer + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: tianshou.trainer.offpolicy_trainer + +.. autoclass:: tianshou.trainer.offpolicy_trainer_iter + + +Offline +------- + +.. autoclass:: tianshou.trainer.OfflineTrainer :members: :undoc-members: :show-inheritance: + +.. autofunction:: tianshou.trainer.offline_trainer + +.. autoclass:: tianshou.trainer.offline_trainer_iter + + +utils +----- + +.. autofunction:: tianshou.trainer.test_episode + +.. autofunction:: tianshou.trainer.gather_info diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 86bdb281e..820fb8867 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -24,12 +24,15 @@ fqf iqn qrdqn rl +offpolicy +onpolicy quantile quantiles dqn param async subprocess +deque nn equ cql diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index d500787ee..cb6d616fe 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training metho Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage. 
+We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic: +:: + + trainer = OnpolicyTrainer(...) + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + do_something_with_policy() + query_something_about_policy() + make_a_plot_with(epoch_stat) + display(info) + + # or even iterate on several trainers at the same time + + trainer1 = OnpolicyTrainer(...) + trainer2 = OnpolicyTrainer(...) + for result1, result2, ... in zip(trainer1, trainer2, ...): + compare_results(result1, result2, ...) + .. _pseudocode: diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 9f4af9bb0..f187b0f68 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy -from tianshou.trainer import onpolicy_trainer +from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print("Fail to restore policy and optim.") # trainer - result = onpolicy_trainer( + trainer = OnpolicyTrainer( policy, train_collector, test_collector, @@ -173,10 +173,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn ) - assert stop_fn(result['best_reward']) + + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + assert stop_fn(info["best_reward"]) if __name__ == '__main__': - pprint.pprint(result) + pprint.pprint(info) # Let's watch its performance! 
env = gym.make(args.task) policy.eval() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b2287a2c5..8930b0839 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -24,7 +24,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--reward-threshold', type=float, default=None) - parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 64df9c418..d7ee186fc 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy -from tianshou.trainer import offpolicy_trainer +from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -135,8 +135,8 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold - # trainer - result = offpolicy_trainer( + # Iterator trainer + trainer = OffpolicyTrainer( policy, train_collector, test_collector, @@ -148,12 +148,17 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) - assert stop_fn(result['best_reward']) + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) - if __name__ == '__main__': - pprint.pprint(result) + assert stop_fn(info["best_reward"]) + + if __name__ == "__main__": + pprint.pprint(info) # Let's watch its performance! env = gym.make(args.task) policy.eval() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 0969cb951..91a2784df 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import CQLPolicy -from tianshou.trainer import offline_trainer +from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -195,7 +195,7 @@ def watch(): collector.collect(n_episode=1, render=1 / 35) # trainer - result = offline_trainer( + trainer = OfflineTrainer( policy, buffer, test_collector, @@ -207,11 +207,17 @@ def watch(): stop_fn=stop_fn, logger=logger, ) - assert stop_fn(result['best_reward']) + + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + assert stop_fn(info["best_reward"]) # Let's watch its performance! 
- if __name__ == '__main__': - pprint.pprint(result) + if __name__ == "__main__": + pprint.pprint(info) env = gym.make(args.task) policy.eval() collector = Collector(policy, env) diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 11b3a95ef..8f1361bec 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,16 +1,34 @@ """Trainer package.""" -# isort:skip_file - -from tianshou.trainer.utils import test_episode, gather_info -from tianshou.trainer.onpolicy import onpolicy_trainer -from tianshou.trainer.offpolicy import offpolicy_trainer -from tianshou.trainer.offline import offline_trainer +from tianshou.trainer.base import BaseTrainer +from tianshou.trainer.offline import ( + OfflineTrainer, + offline_trainer, + offline_trainer_iter, +) +from tianshou.trainer.offpolicy import ( + OffpolicyTrainer, + offpolicy_trainer, + offpolicy_trainer_iter, +) +from tianshou.trainer.onpolicy import ( + OnpolicyTrainer, + onpolicy_trainer, + onpolicy_trainer_iter, +) +from tianshou.trainer.utils import gather_info, test_episode __all__ = [ + "BaseTrainer", "offpolicy_trainer", + "offpolicy_trainer_iter", + "OffpolicyTrainer", "onpolicy_trainer", + "onpolicy_trainer_iter", + "OnpolicyTrainer", "offline_trainer", + "offline_trainer_iter", + "OfflineTrainer", "test_episode", "gather_info", ] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py new file mode 100644 index 000000000..fca1036ff --- /dev/null +++ b/tianshou/trainer/base.py @@ -0,0 +1,419 @@ +import time +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union + +import numpy as np +import tqdm + +from tianshou.data import Collector, ReplayBuffer +from tianshou.policy import BasePolicy +from tianshou.trainer.utils import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config + + +class BaseTrainer(ABC): + """An iterator base class for trainer procedures. + + Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + :param str learning_type: type of learning iterator, available choices are + "offpolicy", "onpolicy" and "offline". + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` + is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat times for policy learning, + for example, setting it to 2 means the policy needs to learn each given batch + of data twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. 
+ :param int episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with + the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; + you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray + with shape (num_episode,)``, used in multi-agent RL. We need to return a + single scalar for each episode's result to monitor training in the + multi-agent RL setting. This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. + Default to True. + """ + + @staticmethod + def gen_doc(learning_type: str) -> str: + """Document string for subclass trainer.""" + step_means = f'The "step" in {learning_type} trainer means ' + if learning_type != "offline": + step_means += "an environment step (a.k.a. transition)." + else: # offline + step_means += "a gradient step." + + trainer_name = learning_type.capitalize() + "Trainer" + + return f"""An iterator class for {learning_type} trainer procedure. + + Returns an iterator that yields a 3-tuple (epoch, stats, info) of + train results on every epoch. + + {step_means} + + Example usage: + + :: + + trainer = {trainer_name}(...) + for epoch, epoch_stat, info in trainer: + print("Epoch:", epoch) + print(epoch_stat) + print(info) + do_something_with_policy() + query_something_about_policy() + make_a_plot_with(epoch_stat) + display(info) + + - epoch int: the epoch number + - epoch_stat dict: a large collection of metrics of the current epoch + - info dict: result returned from :func:`~tianshou.trainer.gather_info` + + You can even iterate on several trainers at the same time: + + :: + + trainer1 = {trainer_name}(...) + trainer2 = {trainer_name}(...) + for result1, result2, ... in zip(trainer1, trainer2, ...): + compare_results(result1, result2, ...) 
+ """ + + def __init__( + self, + learning_type: str, + policy: BasePolicy, + max_epoch: int, + batch_size: int, + train_collector: Optional[Collector] = None, + test_collector: Optional[Collector] = None, + buffer: Optional[ReplayBuffer] = None, + step_per_epoch: Optional[int] = None, + repeat_per_collect: Optional[int] = None, + episode_per_test: Optional[int] = None, + update_per_step: Union[int, float] = 1, + update_per_epoch: Optional[int] = None, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, + ): + self.policy = policy + self.buffer = buffer + + self.train_collector = train_collector + self.test_collector = test_collector + + self.logger = logger + self.start_time = time.time() + self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg) + self.best_reward = 0.0 + self.best_reward_std = 0.0 + self.start_epoch = 0 + self.gradient_step = 0 + self.env_step = 0 + self.max_epoch = max_epoch + self.step_per_epoch = step_per_epoch + + # either on of these two + self.step_per_collect = step_per_collect + self.episode_per_collect = episode_per_collect + + self.update_per_step = update_per_step + self.repeat_per_collect = repeat_per_collect + + self.episode_per_test = episode_per_test + + self.batch_size = batch_size + + self.train_fn = train_fn + self.test_fn = test_fn + self.stop_fn = stop_fn + self.save_fn = save_fn + self.save_checkpoint_fn = save_checkpoint_fn + + self.reward_metric = reward_metric + self.verbose = verbose + self.test_in_train = test_in_train + self.resume_from_log = resume_from_log + + self.is_run = False + self.last_rew, self.last_len = 0.0, 0 + + self.epoch = self.start_epoch + self.best_epoch = self.start_epoch + self.stop_fn_flag = False + self.iter_num = 0 + + def reset(self) -> None: + """Initialize or reset the instance to yield a new iterator from zero.""" + self.is_run = False + self.env_step = 0 + if self.resume_from_log: + self.start_epoch, self.env_step, self.gradient_step = \ + self.logger.restore_data() + + self.last_rew, self.last_len = 0.0, 0 + self.start_time = time.time() + if self.train_collector is not None: + self.train_collector.reset_stat() + + if self.train_collector.policy != self.policy: + self.test_in_train = False + elif self.test_collector is None: + self.test_in_train = False + + if self.test_collector is not None: + assert self.episode_per_test is not None + self.test_collector.reset_stat() + test_result = test_episode( + self.policy, self.test_collector, self.test_fn, self.start_epoch, + self.episode_per_test, self.logger, self.env_step, self.reward_metric + ) + self.best_epoch = self.start_epoch + self.best_reward, self.best_reward_std = \ + test_result["rew"], test_result["rew_std"] + if self.save_fn: + self.save_fn(self.policy) + + self.epoch = self.start_epoch + self.stop_fn_flag = False + self.iter_num = 0 + + def __iter__(self): # type: ignore + self.reset() + return self + + def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: + """Perform one epoch (both train and eval).""" + 
self.epoch += 1 + self.iter_num += 1 + + if self.iter_num > 1: + + # iterator exhaustion check + if self.epoch >= self.max_epoch: + if self.test_collector is None and self.save_fn: + self.save_fn(self.policy) + raise StopIteration + + # exit flag 1, when stop_fn succeeds in train_step or test_step + if self.stop_fn_flag: + raise StopIteration + + # set policy in train mode + self.policy.train() + + epoch_stat: Dict[str, Any] = dict() + # perform n step_per_epoch + with tqdm.tqdm( + total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config + ) as t: + while t.n < t.total and not self.stop_fn_flag: + data: Dict[str, Any] = dict() + result: Dict[str, Any] = dict() + if self.train_collector is not None: + data, result, self.stop_fn_flag = self.train_step() + t.update(result["n/st"]) + if self.stop_fn_flag: + t.set_postfix(**data) + break + else: + assert self.buffer, "No train_collector or buffer specified" + result["n/ep"] = len(self.buffer) + result["n/st"] = int(self.gradient_step) + t.update() + + self.policy_update_fn(data, result) + t.set_postfix(**data) + + if t.n <= t.total and not self.stop_fn_flag: + t.update() + + if not self.stop_fn_flag: + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn + ) + # test + if self.test_collector is not None: + test_stat, self.stop_fn_flag = self.test_step() + if not self.is_run: + epoch_stat.update(test_stat) + + if not self.is_run: + epoch_stat.update({k: v.get() for k, v in self.stat.items()}) + epoch_stat["gradient_step"] = self.gradient_step + epoch_stat.update( + { + "env_step": self.env_step, + "rew": self.last_rew, + "len": int(self.last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) + info = gather_info( + self.start_time, self.train_collector, self.test_collector, + self.best_reward, self.best_reward_std + ) + return self.epoch, epoch_stat, info + else: + return None + + def test_step(self) -> Tuple[Dict[str, Any], bool]: + """Perform one testing step.""" + assert self.episode_per_test is not None + assert self.test_collector is not None + stop_fn_flag = False + test_result = test_episode( + self.policy, self.test_collector, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.env_step, self.reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if self.best_epoch < 0 or self.best_reward < rew: + self.best_epoch = self.epoch + self.best_reward = float(rew) + self.best_reward_std = rew_std + if self.save_fn: + self.save_fn(self.policy) + if self.verbose: + print( + f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," + f" best_reward: {self.best_reward:.6f} ± " + f"{self.best_reward_std:.6f} in #{self.best_epoch}" + ) + if not self.is_run: + test_stat = { + "test_reward": rew, + "test_reward_std": rew_std, + "best_reward": self.best_reward, + "best_reward_std": self.best_reward_std, + "best_epoch": self.best_epoch + } + else: + test_stat = {} + if self.stop_fn and self.stop_fn(self.best_reward): + stop_fn_flag = True + + return test_stat, stop_fn_flag + + def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: + """Perform one training step.""" + assert self.episode_per_test is not None + assert self.train_collector is not None + stop_fn_flag = False + if self.train_fn: + self.train_fn(self.epoch, self.env_step) + result = self.train_collector.collect( + n_step=self.step_per_collect, n_episode=self.episode_per_collect + ) + if result["n/ep"] > 0 and self.reward_metric: + rew = 
self.reward_metric(result["rews"]) + result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + self.env_step += int(result["n/st"]) + self.logger.log_train_data(result, self.env_step) + self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew + self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len + data = { + "env_step": str(self.env_step), + "rew": f"{self.last_rew:.2f}", + "len": str(int(self.last_len)), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + } + if result["n/ep"] > 0: + if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]): + assert self.test_collector is not None + test_result = test_episode( + self.policy, self.test_collector, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.env_step + ) + if self.stop_fn(test_result["rew"]): + stop_fn_flag = True + self.best_reward = test_result["rew"] + self.best_reward_std = test_result["rew_std"] + else: + self.policy.train() + + return data, result, stop_fn_flag + + def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None: + """Log losses to current logger.""" + for k in losses.keys(): + self.stat[k].add(losses[k]) + losses[k] = self.stat[k].get() + data[k] = f"{losses[k]:.3f}" + self.logger.log_update_data(losses, self.gradient_step) + + @abstractmethod + def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: + """Policy update function for different trainer implementation. + + :param data: information in progress bar. + :param result: collector's return value. + """ + + def run(self) -> Dict[str, Union[float, str]]: + """Consume iterator. + + See itertools - recipes. Use functions that consume iterators at C speed + (feed the entire iterator into a zero-length deque). + """ + try: + self.is_run = True + deque(self, maxlen=0) # feed the entire iterator into a zero-length deque + info = gather_info( + self.start_time, self.train_collector, self.test_collector, + self.best_reward, self.best_reward_std + ) + finally: + self.is_run = False + + return info diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index d2f85bc2a..890429a8b 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,131 +1,115 @@ -import time -from collections import defaultdict -from typing import Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np -import tqdm from tianshou.data import Collector, ReplayBuffer from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode -from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config +from tianshou.trainer.base import BaseTrainer +from tianshou.utils import BaseLogger, LazyLogger -def offline_trainer( - policy: BasePolicy, - buffer: ReplayBuffer, - test_collector: Optional[Collector], - max_epoch: int, - update_per_epoch: int, - episode_per_test: int, - batch_size: int, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, -) -> Dict[str, Union[float, str]]: - """A wrapper for offline trainer procedure. - - The "step" in offline trainer means a gradient step. 
+class OfflineTrainer(BaseTrainer): + """Create an iterator class for offline training procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + This buffer must be populated with experiences for offline RL. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. :param int update_per_epoch: the number of policy network updates, so-called gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. Because offline-RL doesn't have env_step, the env_step - is always 0 here. - :param bool resume_from_log: resume gradient_step and other metadata from existing - tensorboard log. Default to False. + :param function test_fn: a hook called at the beginning of testing in each + epoch. + It can be used to perform custom additional operations, with the signature + ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, + with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> + None``; you can save whatever you want. Because offline-RL doesn't have + env_step, the env_step is always 0 here. + :param bool resume_from_log: resume gradient_step and other metadata from + existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during updating/testing. - Default to a logger that doesn't log anything. 
+ :param function reward_metric: a function with signature ``f(rewards: + np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape + (num_episode,)``, used in multi-agent RL. We need to return a single scalar + for each episode's result to monitor training in the multi-agent RL + setting. This function specifies what is the desired metric, e.g., the + reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. """ - start_epoch, gradient_step = 0, 0 - if resume_from_log: - start_epoch, _, gradient_step = logger.restore_data() - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - if test_collector is not None: - test_c: Collector = test_collector - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, - gradient_step, reward_metric + __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:]) + + def __init__( + self, + policy: BasePolicy, + buffer: ReplayBuffer, + test_collector: Optional[Collector], + max_epoch: int, + update_per_epoch: int, + episode_per_test: int, + batch_size: int, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + ): + super().__init__( + learning_type="offline", + policy=policy, + buffer=buffer, + test_collector=test_collector, + max_epoch=max_epoch, + update_per_epoch=update_per_epoch, + step_per_epoch=update_per_epoch, + episode_per_test=episode_per_test, + batch_size=batch_size, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + save_checkpoint_fn=save_checkpoint_fn, + resume_from_log=resume_from_log, + reward_metric=reward_metric, + logger=logger, + verbose=verbose, ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - for epoch in range(1 + start_epoch, 1 + max_epoch): - policy.train() - with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: - for _ in t: - gradient_step += 1 - losses = policy.update(batch_size, buffer) - data = {"gradient_step": str(gradient_step)} - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - # test - if test_collector is not None: - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - if stop_fn and stop_fn(best_reward): - break + def policy_update_fn( + self, data: 
Dict[str, Any], result: Optional[Dict[str, Any]] = None + ) -> None: + """Perform one off-line policy update.""" + assert self.buffer + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.buffer) + data.update({"gradient_step": str(self.gradient_step)}) + self.log_update_data(data, losses) - if test_collector is None and save_fn: - save_fn(policy) - if test_collector is None: - return gather_info(start_time, None, None, 0.0, 0.0) - else: - return gather_info( - start_time, None, test_collector, best_reward, best_reward_std - ) +def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + """Wrapper for offline_trainer run method. + + It is identical to ``OfflineTrainer(...).run()``. + + :return: See :func:`~tianshou.trainer.gather_info`. + """ + return OfflineTrainer(*args, **kwargs).run() + + +offline_trainer_iter = OfflineTrainer diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 9b8727b24..c3580397a 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,193 +1,130 @@ -import time -from collections import defaultdict -from typing import Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np -import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode -from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config +from tianshou.trainer.base import BaseTrainer +from tianshou.utils import BaseLogger, LazyLogger -def offpolicy_trainer( - policy: BasePolicy, - train_collector: Collector, - test_collector: Optional[Collector], - max_epoch: int, - step_per_epoch: int, - step_per_collect: int, - episode_per_test: int, - batch_size: int, - update_per_step: Union[int, float] = 1, - train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, - test_in_train: bool = True, -) -> Dict[str, Union[float, str]]: - """A wrapper for off-policy trainer procedure. - - The "step" in trainer means an environment step (a.k.a. transition). +class OffpolicyTrainer(BaseTrainer): + """Create an iterator wrapper for off-policy training procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. :param int step_per_epoch: the number of transitions collected per epoch. - :param int step_per_collect: the number of transitions the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatedly in each epoch. 
+ :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in the - policy network. - :param int/float update_per_step: the number of times the policy network would be - updated per transition after (step_per_collect) transitions are collected, - e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will - be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are - collected by the collector. Default to 1. - :param function train_fn: a hook called at the beginning of training in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata from - existing tensorboard log. Default to False. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int/float update_per_step: the number of times the policy network would + be updated per transition after (step_per_collect) transitions are + collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256 + , policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256 + transitions are collected by the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with + the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; + you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. 
We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to + return a single scalar for each episode's result to monitor training in the + multi-agent RL setting. This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. :param BaseLogger logger: A logger that logs statistics during training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. + :param bool test_in_train: whether to test in the training phase. + Default to True. """ - start_epoch, env_step, gradient_step = 0, 0, 0 - if resume_from_log: - start_epoch, env_step, gradient_step = logger.restore_data() - last_rew, last_len = 0.0, 0 - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - train_collector.reset_stat() - test_in_train = test_in_train and ( - train_collector.policy == policy and test_collector is not None - ) - if test_collector is not None: - test_c: Collector = test_collector # for mypy - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, - reward_metric + __doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:]) + + def __init__( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Optional[Collector], + max_epoch: int, + step_per_epoch: int, + step_per_collect: int, + episode_per_test: int, + batch_size: int, + update_per_step: Union[int, float] = 1, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, + ): + super().__init__( + learning_type="offpolicy", + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=max_epoch, + step_per_epoch=step_per_epoch, + step_per_collect=step_per_collect, + episode_per_test=episode_per_test, + batch_size=batch_size, + update_per_step=update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + save_checkpoint_fn=save_checkpoint_fn, + resume_from_log=resume_from_log, + reward_metric=reward_metric, + logger=logger, + verbose=verbose, + test_in_train=test_in_train, ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - for epoch in range(1 + start_epoch, 1 + max_epoch): - # train - policy.train() - with tqdm.tqdm( - total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ) as t: - while t.n < t.total: - if train_fn: - train_fn(epoch, env_step) - result = 
train_collector.collect(n_step=step_per_collect) - if result["n/ep"] > 0 and reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - env_step += int(result["n/st"]) - t.update(result["n/st"]) - logger.log_train_data(result, env_step) - last_rew = result['rew'] if result["n/ep"] > 0 else last_rew - last_len = result['len'] if result["n/ep"] > 0 else last_len - data = { - "env_step": str(env_step), - "rew": f"{last_rew:.2f}", - "len": str(int(last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), - } - if result["n/ep"] > 0: - if test_in_train and stop_fn and stop_fn(result["rew"]): - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step - ) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn - ) - t.set_postfix(**data) - return gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"] - ) - else: - policy.train() - for _ in range(round(update_per_step * result["n/st"])): - gradient_step += 1 - losses = policy.update(batch_size, train_collector.buffer) - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - if t.n <= t.total: - t.update() - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # test - if test_collector is not None: - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - if stop_fn and stop_fn(best_reward): - break + def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: + """Perform off-policy updates.""" + assert self.train_collector is not None + for _ in range(round(self.update_per_step * result["n/st"])): + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.train_collector.buffer) + self.log_update_data(data, losses) - if test_collector is None and save_fn: - save_fn(policy) - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) +def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + """Wrapper for OffPolicyTrainer run method. + + It is identical to ``OffpolicyTrainer(...).run()``. + + :return: See :func:`~tianshou.trainer.gather_info`. 
+ """ + return OffpolicyTrainer(*args, **kwargs).run() + + +offpolicy_trainer_iter = OffpolicyTrainer diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 251c55637..46b195a70 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,209 +1,147 @@ -import time -from collections import defaultdict -from typing import Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np -import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode -from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config - - -def onpolicy_trainer( - policy: BasePolicy, - train_collector: Collector, - test_collector: Optional[Collector], - max_epoch: int, - step_per_epoch: int, - repeat_per_collect: int, - episode_per_test: int, - batch_size: int, - step_per_collect: Optional[int] = None, - episode_per_collect: Optional[int] = None, - train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, - test_in_train: bool = True, -) -> Dict[str, Union[float, str]]: - """A wrapper for on-policy trainer procedure. - - The "step" in trainer means an environment step (a.k.a. transition). +from tianshou.trainer.base import BaseTrainer +from tianshou.utils import BaseLogger, LazyLogger + + +class OnpolicyTrainer(BaseTrainer): + """Create an iterator wrapper for on-policy training procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. :param int step_per_epoch: the number of transitions collected per epoch. - :param int repeat_per_collect: the number of repeat time for policy learning, for - example, set it to 2 means the policy needs to learn each given batch data - twice. + :param int repeat_per_collect: the number of repeat time for policy learning, + for example, set it to 2 means the policy needs to learn each given batch + data twice. :param int episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in the - policy network. - :param int step_per_collect: the number of transitions the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatedly in each epoch. 
- :param int episode_per_collect: the number of episodes the collector would collect - before the network update, i.e., trainer will collect "episode_per_collect" - episodes and do some policy network update repeatedly in each epoch. - :param function train_fn: a hook called at the beginning of training in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata from - existing tensorboard log. Default to False. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. + :param int episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, + with the signature ``f(epoch: int, env_step: int, gradient_step: int) + -> None``; you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + np.ndarray with shape (num_episode,)``, used in multi-agent RL. 
+ We need to return a single scalar for each episode's result to monitor + training in the multi-agent RL setting. This function specifies what is the + desired metric, e.g., the reward of agent 1 or the average reward over + all agents. :param BaseLogger logger: A logger that logs statistics during training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. + :param bool test_in_train: whether to test in the training phase. Default to + True. .. note:: Only either one of step_per_collect and episode_per_collect can be specified. """ - start_epoch, env_step, gradient_step = 0, 0, 0 - if resume_from_log: - start_epoch, env_step, gradient_step = logger.restore_data() - last_rew, last_len = 0.0, 0 - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - train_collector.reset_stat() - test_in_train = test_in_train and ( - train_collector.policy == policy and test_collector is not None - ) - - if test_collector is not None: - test_c: Collector = test_collector # for mypy - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, - reward_metric + + __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:]) + + def __init__( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Optional[Collector], + max_epoch: int, + step_per_epoch: int, + repeat_per_collect: int, + episode_per_test: int, + batch_size: int, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, + ): + super().__init__( + learning_type="onpolicy", + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=max_epoch, + step_per_epoch=step_per_epoch, + repeat_per_collect=repeat_per_collect, + episode_per_test=episode_per_test, + batch_size=batch_size, + step_per_collect=step_per_collect, + episode_per_collect=episode_per_collect, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + save_checkpoint_fn=save_checkpoint_fn, + resume_from_log=resume_from_log, + reward_metric=reward_metric, + logger=logger, + verbose=verbose, + test_in_train=test_in_train, ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - - for epoch in range(1 + start_epoch, 1 + max_epoch): - # train - policy.train() - with tqdm.tqdm( - total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ) as t: - while t.n < t.total: - if train_fn: - train_fn(epoch, env_step) - result = train_collector.collect( - n_step=step_per_collect, n_episode=episode_per_collect - ) - if result["n/ep"] > 0 and reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - env_step += 
int(result["n/st"]) - t.update(result["n/st"]) - logger.log_train_data(result, env_step) - last_rew = result['rew'] if result["n/ep"] > 0 else last_rew - last_len = result['len'] if result["n/ep"] > 0 else last_len - data = { - "env_step": str(env_step), - "rew": f"{last_rew:.2f}", - "len": str(int(last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), - } - if result["n/ep"] > 0: - if test_in_train and stop_fn and stop_fn(result["rew"]): - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step - ) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn - ) - t.set_postfix(**data) - return gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"] - ) - else: - policy.train() - losses = policy.update( - 0, - train_collector.buffer, - batch_size=batch_size, - repeat=repeat_per_collect - ) - train_collector.reset_buffer(keep_statistics=True) - step = max( - [1] + [len(v) for v in losses.values() if isinstance(v, list)] - ) - gradient_step += step - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - if t.n <= t.total: - t.update() - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # test - if test_collector is not None: - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - if stop_fn and stop_fn(best_reward): - break - - if test_collector is None and save_fn: - save_fn(policy) - - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std + + def policy_update_fn( + self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None + ) -> None: + """Perform one on-policy update.""" + assert self.train_collector is not None + losses = self.policy.update( + 0, + self.train_collector.buffer, + batch_size=self.batch_size, + repeat=self.repeat_per_collect, ) + self.train_collector.reset_buffer(keep_statistics=True) + step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)]) + self.gradient_step += step + self.log_update_data(data, losses) + + +def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + """Wrapper for OnpolicyTrainer run method. + + It is identical to ``OnpolicyTrainer(...).run()``. + + :return: See :func:`~tianshou.trainer.gather_info`. + """ + return OnpolicyTrainer(*args, **kwargs).run() + + +onpolicy_trainer_iter = OnpolicyTrainer
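The functional entry points kept by this patch (``onpolicy_trainer``, ``offpolicy_trainer``, ``offline_trainer``) are now thin wrappers that are identical to ``Trainer(...).run()``, while the trainer classes themselves can be iterated epoch by epoch. Below is a minimal sketch of both call styles, assuming ``policy``, ``train_collector`` and ``test_collector`` have already been constructed as in the test scripts above; the hyperparameter values are illustrative only:

::

    from tianshou.trainer import OffpolicyTrainer, offpolicy_trainer

    # Iterator style: one (epoch, epoch_stat, info) tuple is yielded per epoch,
    # so custom logic (plots, extra checkpoints, early exit) can run in between.
    trainer = OffpolicyTrainer(
        policy,            # assumed to exist, e.g. a TD3Policy as in test_td3.py
        train_collector,   # assumed to exist
        test_collector,    # assumed to exist
        max_epoch=10,
        step_per_epoch=5000,
        step_per_collect=10,
        episode_per_test=10,
        batch_size=64,
        update_per_step=0.1,
    )
    for epoch, epoch_stat, info in trainer:
        print(f"Epoch #{epoch}:", epoch_stat)

    # Functional style: identical to OffpolicyTrainer(...).run(), which drains
    # the iterator and returns the final info dict from gather_info().
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        max_epoch=10,
        step_per_epoch=5000,
        step_per_collect=10,
        episode_per_test=10,
        batch_size=64,
        update_per_step=0.1,
    )
    print(result["best_reward"])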