Adjust locations of setting the policy in train/eval mode (#1123)
Addresses #1122:
* Introduced a new flag `is_within_training_step`, which is enabled by
the training algorithm while it is inside a training step, where a
training step encompasses training data collection and policy updates.
Algorithms now use this flag, rather than the torch training flag (which
was previously abused for this purpose), to decide whether their
`deterministic_eval` setting should apply.
* The policy's train/eval mode (which should control torch-level
learning only) no longer needs to be set in user code in order to
control collector behaviour (this never made sense). The respective
calls have been removed.
* The policy should, in fact, always be in evaluation mode when
collecting data, as there is no reason to ever enable gradient
accumulation for any type of rollout. We therefore explicitly set the
policy to evaluation mode in `Collector.collect`. Moreover, computing
gradients during collection never makes sense, so the option to pass
`no_grad=False` was removed (see the usage sketch below).

Further changes:
- Base class for collectors: `BaseCollector`
- New util context managers `in_eval_mode` and `in_train_mode` for torch
modules.
- `reset` of collectors now returns `obs` and `info`.
- `no_grad` is no longer accepted as a kwarg of `collect`.
- Removed deprecations of `0.5.1` (will likely not affect anyone) and
the unused `warnings` module.
MischaPanch authored May 6, 2024
2 parents 9fbf28e + e94a5c0 commit 26b867e
Showing 103 changed files with 680 additions and 1,203 deletions.
34 changes: 23 additions & 11 deletions CHANGELOG.md
@@ -8,14 +8,18 @@
- Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
- Add methods `to_numpy_` and `to_torch_`. #1098, #1117
- Add `__eq__` (semantic equality check). #1098
- `keys()` deprecated in favor of `get_keys()` (needed to make iteration consistent with naming) #1105.
- `data.collector`:
- `Collector`:
- Introduced `BaseCollector` as a base class for all collectors. #1123
- Add method `close` #1063
- Method `reset` is now more granular (new flags controlling behavior). #1063
- `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063
- `trainer`:
- Trainers can now control whether collectors should be reset prior to training. #1063
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
- policy:
- introduced attribute `in_training_step` that is controlled by the trainer. #1123
- policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123
- `highlevel`:
- `SamplingConfig`:
- Add support for `batch_size=None`. #1077
@@ -33,12 +37,14 @@
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
- The module `evaluation.launchers` for parallelization is currently in alpha state.
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- `utils.net`:
- `continuous.Critic`:
- `utils`:
- `net.continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without
the actions concatenated), which is essential for the case where we want
to reuse the actor's preprocessing network #1128
- `torch_utils` (new module)
- Added context managers `torch_train_mode` and `policy_within_training_step` #1123

### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
@@ -65,20 +71,26 @@ instead of just `nn.Module`. #1032
- Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102

### Breaking Changes

- Removed `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
explicitly or pass `reset_before_collect=True`. #1063
- `data`:
- `Collector`:
- Removed `.data` attribute. #1063
- Collectors no longer reset the environment on initialization.
Instead, the user might have to call `reset` explicitly or pass `reset_before_collect=True`. #1063
- Removed `no_grad` argument from `collect` method (was unused in tianshou). #1123
- `Batch`:
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`.
Can be considered a bugfix. #1063
- The methods `to_numpy` and `to_torch` in `Batch` are not in-place anymore
(use `to_numpy_` or `to_torch_` instead). #1098, #1117
- Logging:
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- The methods `to_numpy` and `to_torch` in `Batch` is not in-place anymore (use `to_numpy_` or `to_torch_` instead). #1098, #1117
- `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
- `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074


### Tests
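
Editorial note: a minimal migration sketch for the `Batch` conversion change listed above (assuming the default arguments; exact signatures and return types may differ):

```python
import numpy as np

from tianshou.data import Batch

b = Batch(obs=np.zeros((2, 3)), act=np.array([0, 1]))

# Previously, to_torch()/to_numpy() modified the batch in place.
# They now return a new Batch; the in-place variants carry a trailing underscore.
b_torch = b.to_torch()  # `b` itself is left untouched
b.to_torch_()           # converts `b` in place
```
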
4 changes: 1 addition & 3 deletions docs/02_notebooks/L0_overview.ipynb
@@ -18,9 +18,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# !pip install tianshou gym"
29 changes: 16 additions & 13 deletions docs/02_notebooks/L4_Policy.ipynb
@@ -74,7 +74,8 @@
")\n",
"from tianshou.utils import RunningMeanStd\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
]
},
{
@@ -644,7 +645,10 @@
"source": [
"obs, info = env.reset()\n",
"for i in range(3, 10):\n",
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
" # For retrieving actions to be used for training, we set the policy to training mode,\n",
" # but the wrapped torch module should be in eval mode.\n",
" with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n",
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
" obs_next, rew, _, truncated, info = env.step(act)\n",
" # pretend this episode never end\n",
" terminated = False\n",
@@ -695,7 +699,11 @@
},
"source": [
"#### Updates\n",
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train."
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n",
"\n",
"However, we need to manually set the torch module to training mode prior to that, \n",
"and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n",
"but users need to consider it when calling `.update` outside of the trainer."
]
},
{
@@ -711,16 +719,11 @@
"outputs": [],
"source": [
"# 0 means sample all data from the buffer\n",
"policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "enqlFQLSJrQl"
},
"source": [
"Not that difficult, right?"
"\n",
"# For updating the policy, the policy should be in training mode\n",
"# and the wrapped torch module should also be in training mode (unlike when collecting data).\n",
"with policy_within_training_step(policy), torch_train_mode(policy):\n",
" policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
]
},
{
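
Editorial note: the two notebook cells changed above boil down to the following pattern (a sketch using the notebook's own objects, i.e. `policy`, `obs`, and `dummy_buffer` are assumed to be defined as in L4_Policy):

```python
import numpy as np

from tianshou.data import Batch
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

# Acting in order to gather training data: within a training step,
# but the wrapped torch module stays in eval mode.
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()

# Updating the policy: within a training step, with the torch module in train mode.
with policy_within_training_step(policy), torch_train_mode(policy):
    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6)
```
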
44 changes: 31 additions & 13 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -54,7 +54,6 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"id": "do-xZ-8B7nVH",
@@ -64,9 +63,12 @@
"tags": [
"hide-cell",
"remove-output"
]
],
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
}
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
@@ -78,14 +80,20 @@
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
]
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:07.536452Z",
"start_time": "2024-05-06T15:34:03.636670Z"
}
},
"source": [
"train_env_num = 4\n",
"buffer_size = (\n",
@@ -123,7 +131,9 @@
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
]
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
@@ -154,11 +164,19 @@
"\n",
"n_episode = 10\n",
"for _i in range(n_episode):\n",
" evaluation_result = test_collector.collect(n_episode=n_episode)\n",
" # for test collector, we set the wrapped torch module to evaluation mode\n",
" # by default, the policy object itself is not within the training step\n",
" with torch_train_mode(policy, enabled=False):\n",
" evaluation_result = test_collector.collect(n_episode=n_episode)\n",
" print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
" # for collecting data for training, the policy object should be within the training step\n",
" # (affecting e.g. whether the policy is stochastic or deterministic)\n",
" with policy_within_training_step(policy):\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" # for updating the policy, the wrapped torch module should be in training mode\n",
" with torch_train_mode(policy):\n",
" policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
" train_collector.reset_buffer(keep_statistics=True)"
]
},
4 changes: 3 additions & 1 deletion docs/autogen_rst.py
@@ -74,7 +74,9 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
subdir_refs = [
f"{f}/index"
for f in files_in_dir
if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_")
if os.path.isdir(os.path.join(src_root, f))
and not f.startswith("_")
and not f.startswith(".")
]
package_index_rst_path = os.path.join(
rst_root,
1 change: 0 additions & 1 deletion examples/atari/atari_c51.py
@@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
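
Editorial note: the single-line removals in this and the following example scripts all follow from the same change — `Collector.collect` now switches the policy's torch module to eval mode (and never computes gradients) on its own, so the manual `policy.eval()` call is dropped. A sketch of the resulting watch/evaluation pattern, using the variable names of these examples:

```python
# policy.eval() is no longer needed here: Collector.collect handles the
# eval-mode switch and gradient-free execution internally.
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f"Mean return: {result.returns.mean()}")
```
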
1 change: 0 additions & 1 deletion examples/atari/atari_dqn.py
@@ -204,7 +204,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_fqf.py
@@ -175,7 +175,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_iqn.py
@@ -172,7 +172,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_ppo.py
@@ -229,7 +229,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
1 change: 0 additions & 1 deletion examples/atari/atari_qrdqn.py
@@ -166,7 +166,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_rainbow.py
@@ -200,7 +200,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_sac.py
@@ -216,7 +216,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
1 change: 0 additions & 1 deletion examples/box2d/acrobot_dualdqn.py
@@ -144,7 +144,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
1 change: 0 additions & 1 deletion examples/box2d/bipedal_bdq.py
@@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
1 change: 0 additions & 1 deletion examples/box2d/bipedal_hardcore_sac.py
@@ -207,7 +207,6 @@ def stop_fn(mean_rewards: float) -> bool:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/box2d/lunarlander_dqn.py
@@ -141,7 +141,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
1 change: 0 additions & 1 deletion examples/box2d/mcc_sac.py
@@ -153,7 +153,6 @@ def stop_fn(mean_rewards: float) -> bool:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/discrete/discrete_dqn.py
@@ -80,7 +80,6 @@ def stop_fn(mean_rewards: float) -> bool:
print(f"Finished training in {result.timing.total_time} seconds")

# watch performance
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=100, render=1 / 35)
1 change: 0 additions & 1 deletion examples/discrete/discrete_dqn_hl.py
@@ -18,7 +18,6 @@ def main() -> None:
DQNExperimentBuilder(
EnvFactoryRegistered(
task="CartPole-v1",
seed=0,
venv_type=VectorEnvType.DUMMY,
train_seed=0,
test_seed=10,
1 change: 0 additions & 1 deletion examples/inverse/irl_gail.py
@@ -264,7 +264,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/fetch_her_ddpg.py
@@ -238,7 +238,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_a2c.py
@@ -219,7 +219,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_ddpg.py
@@ -168,7 +168,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_npg.py
@@ -216,7 +216,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)