Skip to content

Commit

Permalink
Add Trainers as generators (#559)
Browse files Browse the repository at this point in the history
The new proposed feature is to have trainers as generators.
The usage pattern is:

```python
trainer = OnPolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch, including stat
- info dict: the usual dict out of the non-generator version of the trainer

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnPolicyTrainer(...)
trainer2 = OnPolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
  • Loading branch information
jamartinh and Trinkle23897 authored Mar 17, 2022
1 parent 2336a7d commit 10d9190
Show file tree
Hide file tree
Showing 14 changed files with 864 additions and 488 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, numpy, sys
print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
import tianshou, gym, torch, numpy, sys
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
```
4 changes: 1 addition & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ lint:
flake8 ${LINT_PATHS} --count --show-source --statistics

format:
# sort imports
$(call check_install, isort)
isort ${LINT_PATHS}
# reformat using yapf
$(call check_install, yapf)
yapf -ir ${LINT_PATHS}

Expand Down Expand Up @@ -57,6 +55,6 @@ doc-clean:

clean: doc-clean

commit-checks: format lint mypy check-docstyle spelling
commit-checks: lint check-codestyle mypy check-docstyle spelling

.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
44 changes: 43 additions & 1 deletion docs/api/tianshou.trainer.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,49 @@
tianshou.trainer
================

.. automodule:: tianshou.trainer

On-policy
---------

.. autoclass:: tianshou.trainer.OnpolicyTrainer
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: tianshou.trainer.onpolicy_trainer

.. autoclass:: tianshou.trainer.onpolicy_trainer_iter


Off-policy
----------

.. autoclass:: tianshou.trainer.OffpolicyTrainer
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: tianshou.trainer.offpolicy_trainer

.. autoclass:: tianshou.trainer.offpolicy_trainer_iter


Offline
-------

.. autoclass:: tianshou.trainer.OfflineTrainer
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: tianshou.trainer.offline_trainer

.. autoclass:: tianshou.trainer.offline_trainer_iter


utils
-----

.. autofunction:: tianshou.trainer.test_episode

.. autofunction:: tianshou.trainer.gather_info
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ fqf
iqn
qrdqn
rl
offpolicy
onpolicy
quantile
quantiles
dqn
param
async
subprocess
deque
nn
equ
cql
Expand Down
20 changes: 20 additions & 0 deletions docs/tutorials/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training metho

Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
::

trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
do_something_with_policy()
query_something_about_policy()
make_a_plot_with(epoch_stat)
display(info)

# or even iterate on several trainers at the same time

trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
compare_results(result1, result2, ...)


.. _pseudocode:

Expand Down
14 changes: 10 additions & 4 deletions test/continuous/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
Expand Down Expand Up @@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
print("Fail to restore policy and optim.")

# trainer
result = onpolicy_trainer(
trainer = OnpolicyTrainer(
policy,
train_collector,
test_collector,
Expand All @@ -173,10 +173,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
)
assert stop_fn(result['best_reward'])

for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)

assert stop_fn(info["best_reward"])

if __name__ == '__main__':
pprint.pprint(result)
pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
Expand Down
2 changes: 1 addition & 1 deletion test/continuous/test_sac_with_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--reward-threshold', type=float, default=None)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
Expand Down
19 changes: 12 additions & 7 deletions test/continuous/test_td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
Expand Down Expand Up @@ -135,8 +135,8 @@ def save_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold

# trainer
result = offpolicy_trainer(
# Iterator trainer
trainer = OffpolicyTrainer(
policy,
train_collector,
test_collector,
Expand All @@ -148,12 +148,17 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
assert stop_fn(result['best_reward'])
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)

if __name__ == '__main__':
pprint.pprint(result)
assert stop_fn(info["best_reward"])

if __name__ == "__main__":
pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
Expand Down
16 changes: 11 additions & 5 deletions test/offline/test_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
Expand Down Expand Up @@ -195,7 +195,7 @@ def watch():
collector.collect(n_episode=1, render=1 / 35)

# trainer
result = offline_trainer(
trainer = OfflineTrainer(
policy,
buffer,
test_collector,
Expand All @@ -207,11 +207,17 @@ def watch():
stop_fn=stop_fn,
logger=logger,
)
assert stop_fn(result['best_reward'])

for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)

assert stop_fn(info["best_reward"])

# Let's watch its performance!
if __name__ == '__main__':
pprint.pprint(result)
if __name__ == "__main__":
pprint.pprint(info)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
Expand Down
30 changes: 24 additions & 6 deletions tianshou/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
"""Trainer package."""

# isort:skip_file

from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
from tianshou.trainer.offline import offline_trainer
from tianshou.trainer.base import BaseTrainer
from tianshou.trainer.offline import (
OfflineTrainer,
offline_trainer,
offline_trainer_iter,
)
from tianshou.trainer.offpolicy import (
OffpolicyTrainer,
offpolicy_trainer,
offpolicy_trainer_iter,
)
from tianshou.trainer.onpolicy import (
OnpolicyTrainer,
onpolicy_trainer,
onpolicy_trainer_iter,
)
from tianshou.trainer.utils import gather_info, test_episode

__all__ = [
"BaseTrainer",
"offpolicy_trainer",
"offpolicy_trainer_iter",
"OffpolicyTrainer",
"onpolicy_trainer",
"onpolicy_trainer_iter",
"OnpolicyTrainer",
"offline_trainer",
"offline_trainer_iter",
"OfflineTrainer",
"test_episode",
"gather_info",
]
Loading

0 comments on commit 10d9190

Please sign in to comment.