Add show_progress option for trainer (#641)
- Added a DummyTqdm class to utils: it replicates the tqdm interface used by the trainers, but does not show a progress bar;
- Added a show_progress argument to the base trainer: when show_progress == False, DummyTqdm is used in place of tqdm (see the sketch below).
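
A minimal sketch of the resulting behavior (an illustration, not part of the commit; it assumes this version of tianshou plus the tqdm package are installed, and run_epoch is a hypothetical stand-in for the trainer's epoch loop):

import tqdm

from tianshou.utils import DummyTqdm, tqdm_config


def run_epoch(step_per_epoch: int, show_progress: bool) -> None:
    # The same selection the base trainer now performs at the start of each epoch.
    progress = tqdm.tqdm if show_progress else DummyTqdm
    with progress(total=step_per_epoch, desc="Epoch #1", **tqdm_config) as t:
        while t.n < t.total:
            t.update(1)  # one step of collecting/updating
            t.set_postfix(loss="0.000")  # accepted but ignored by DummyTqdm


run_epoch(100, show_progress=True)   # prints an ascii progress bar
run_epoch(100, show_progress=False)  # runs silently; counters still advance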
michalgregor authored May 17, 2022
1 parent 53e6b04 commit c87b9f4
Showing 11 changed files with 78 additions and 15 deletions.
4 changes: 2 additions & 2 deletions examples/atari/atari_wrapper.py
@@ -8,13 +8,13 @@
 import gym
 import numpy as np
 
-from tianshou.env import ShmemVectorEnv
-
 try:
     import envpool
 except ImportError:
     envpool = None
 
+from tianshou.env import ShmemVectorEnv
+
 
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.
8 changes: 4 additions & 4 deletions examples/mujoco/mujoco_env.py
@@ -1,14 +1,14 @@
 import warnings
 
-import gym
-
-from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-
 try:
     import envpool
 except ImportError:
     envpool = None
 
+import gym
+
+from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
+
 
 def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
     """Wrapper function for Mujoco env.
4 changes: 2 additions & 2 deletions test/discrete/test_pg.py
@@ -80,7 +80,7 @@ def test_pg(args=get_args()):
         dist,
         args.gamma,
         reward_normalization=args.rew_norm,
-        action_space=env.action_space
+        action_space=env.action_space,
     )
     for m in net.modules():
         if isinstance(m, torch.nn.Linear):
@@ -116,7 +116,7 @@ def stop_fn(mean_rewards):
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
         save_best_fn=save_best_fn,
-        logger=logger
+        logger=logger,
     )
     assert stop_fn(result['best_reward'])
 
2 changes: 2 additions & 0 deletions test/offline/test_bcq.py
@@ -58,6 +58,7 @@ def get_args():
         help='watch the play of pre-trained policy only',
     )
     parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
+    parser.add_argument("--show-progress", action="store_true")
     args = parser.parse_known_args()[0]
     return args
 
@@ -209,6 +210,7 @@ def watch():
         save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
+        show_progress=args.show_progress,
     )
     assert stop_fn(result['best_reward'])
 
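A small aside (not part of the commit) on the flag's default: with action="store_true" the option defaults to False, so this test runs without a progress bar unless --show-progress is passed on the command line.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--show-progress", action="store_true")

assert parser.parse_known_args([])[0].show_progress is False
assert parser.parse_known_args(["--show-progress"])[0].show_progress is True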
21 changes: 19 additions & 2 deletions tianshou/trainer/base.py
@@ -9,7 +9,14 @@
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import BasePolicy
 from tianshou.trainer.utils import gather_info, test_episode
-from tianshou.utils import BaseLogger, LazyLogger, MovAvg, deprecation, tqdm_config
+from tianshou.utils import (
+    BaseLogger,
+    DummyTqdm,
+    LazyLogger,
+    MovAvg,
+    deprecation,
+    tqdm_config,
+)
 
 
 class BaseTrainer(ABC):
@@ -68,6 +75,8 @@ class BaseTrainer(ABC):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """
@@ -143,6 +152,7 @@ def __init__(
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
     ):
@@ -190,6 +200,7 @@ def __init__(
 
         self.reward_metric = reward_metric
         self.verbose = verbose
+        self.show_progress = show_progress
         self.test_in_train = test_in_train
         self.resume_from_log = resume_from_log
 
@@ -259,8 +270,14 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
         self.policy.train()
 
         epoch_stat: Dict[str, Any] = dict()
+
+        if self.show_progress:
+            progress = tqdm.tqdm
+        else:
+            progress = DummyTqdm
+
         # perform n step_per_epoch
-        with tqdm.tqdm(
+        with progress(
             total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
         ) as t:
             while t.n < t.total and not self.stop_fn_flag:
4 changes: 4 additions & 0 deletions tianshou/trainer/offline.py
@@ -49,6 +49,8 @@ class OfflineTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         updating/testing. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     """
 
     __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
@@ -70,6 +72,7 @@ def __init__(
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         **kwargs: Any,
     ):
         super().__init__(
@@ -90,6 +93,7 @@ def __init__(
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             **kwargs,
         )
 
4 changes: 4 additions & 0 deletions tianshou/trainer/offpolicy.py
@@ -57,6 +57,8 @@ class OffpolicyTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase.
         Default to True.
     """
@@ -83,6 +85,7 @@ def __init__(
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):
@@ -106,6 +109,7 @@ def __init__(
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )
4 changes: 4 additions & 0 deletions tianshou/trainer/onpolicy.py
@@ -60,6 +60,8 @@ class OnpolicyTrainer(BaseTrainer):
     :param BaseLogger logger: A logger that logs statistics during
         training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    :param bool show_progress: whether to display a progress bar when training.
+        Default to True.
     :param bool test_in_train: whether to test in the training phase. Default to
         True.
@@ -91,6 +93,7 @@ def __init__(
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        show_progress: bool = True,
         test_in_train: bool = True,
         **kwargs: Any,
     ):
@@ -115,6 +118,7 @@ def __init__(
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            show_progress=show_progress,
             test_in_train=test_in_train,
             **kwargs,
         )
3 changes: 2 additions & 1 deletion tianshou/utils/__init__.py
@@ -1,17 +1,18 @@
 """Utils package."""
 
-from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
 from tianshou.utils.lr_scheduler import MultipleLRSchedulers
+from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.warning import deprecation
 
 __all__ = [
     "MovAvg",
     "RunningMeanStd",
     "tqdm_config",
+    "DummyTqdm",
     "BaseLogger",
     "TensorboardLogger",
     "BasicLogger",
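Since tqdm_config only moved modules (from the deleted tianshou.utils.config to the new tianshou.utils.progress_bar) and is still re-exported from the package root, existing imports keep working. A quick check, assuming this commit is installed:

from tianshou.utils import DummyTqdm, tqdm_config

assert tqdm_config == {"dynamic_ncols": True, "ascii": True}
assert DummyTqdm(total=3).total == 3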
4 changes: 0 additions & 4 deletions tianshou/utils/config.py

This file was deleted.

35 changes: 35 additions & 0 deletions tianshou/utils/progress_bar.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+tqdm_config = {
+    "dynamic_ncols": True,
+    "ascii": True,
+}
+
+
+class DummyTqdm:
+    """A dummy tqdm class that keeps stats, but without the progress bar.
+
+    It supports ``__enter__`` and ``__exit__``, ``update`` and a dummy
+    ``set_postfix``, which is the interface that trainers use.
+
+    .. note::
+
+        Using ``disable=True`` in the tqdm config results in an infinite loop,
+        which is why this class was created. See the discussion at #641 for
+        details.
+    """
+
+    def __init__(self, total: int, **kwargs: Any):
+        self.total = total
+        self.n = 0
+
+    def set_postfix(self, **kwargs: Any) -> None:
+        pass
+
+    def update(self, n: int = 1) -> None:
+        self.n += n
+
+    def __enter__(self) -> "DummyTqdm":
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        pass
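A quick interface check (a sketch, assuming this commit is installed): DummyTqdm supports the calls the trainers make (context-manager entry and exit, update, set_postfix, and the n/total counters) while printing nothing.

from tianshou.utils.progress_bar import DummyTqdm

with DummyTqdm(total=5, desc="ignored") as t:  # extra kwargs are accepted and dropped
    while t.n < t.total:
        t.update(1)
        t.set_postfix(loss=0.0)  # a no-op, unlike tqdm's set_postfix
assert t.n == t.total  # stats were tracked, but no bar was drawn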