Adjust locations of setting the policy in train/eval mode #1123

Merged · merged 35 commits · May 6, 2024

Changes from all commits (35 commits)
e499bed
add is_eval attribute to policy and set this attribute as well as tra…
Apr 24, 2024
8cb17de
update examples
Apr 24, 2024
49c750f
update tests
Apr 24, 2024
829fd9c
Deleted long deprecated functionality, removed unused warning module
Apr 26, 2024
7d59302
Added in_eval/in_train mode contextmanager
Apr 26, 2024
12d4262
Tests: removed all instances of `if __name__ == ...` in tests
Apr 26, 2024
4b619c5
Collector: extracted interface BaseCollector, minor simplifications
Apr 26, 2024
69f07a8
Tests: fixed typing issues by declaring union types and no longer reu…
Apr 26, 2024
07a97c7
Merge branch 'refs/heads/thuml-master' into policy-train-eval
Apr 26, 2024
2eaf1f3
Use the new BaseCollector interface for annotations
Apr 26, 2024
c28508b
Changelog
Apr 26, 2024
6aa33b1
Formatting
Apr 26, 2024
e2e8a69
Changelog [skip-ci]
Apr 26, 2024
4592271
Docstring add return [skip-ci]
Apr 26, 2024
a2b9d7c
Changelog [skip-ci]
Apr 26, 2024
4f16494
Set torch train mode in BasePolicy.update instead of in each .learn i…
opcode81 May 2, 2024
ca4dad1
BaseTrainer: Refactoring
opcode81 May 2, 2024
18f2361
Fix invalid kwarg
opcode81 May 2, 2024
ca69e79
Change the way in which deterministic evaluation is controlled:
opcode81 May 2, 2024
c35be8d
Establish backward compatibility by implementing __setstate__
opcode81 May 2, 2024
6927ead
BatchPolicy: check that `self.is_within_training_step` is True on update
May 5, 2024
f876198
Formatting
May 5, 2024
c5d0e16
Collector: removed unnecessary no-grad flag from interfaces. Breaking
May 5, 2024
26a6cca
Improved docstrings, added asserts to make mypy happy
May 5, 2024
82f425e
Collector: move @override, removed docstrings from overridden methods
May 5, 2024
4e38aeb
Merge branch 'refs/heads/thuml-master' into policy-train-eval
May 5, 2024
a8e9df3
Bugfix: allow for training_stat to be None instead of asserting not-None
May 5, 2024
3577969
Clean up handling of an Experiment's name (and, by extension, a run's…
opcode81 Apr 30, 2024
024b80e
Improve creation of multiple seeded experiments:
opcode81 Apr 30, 2024
2abb4da
Reinstated warning module
May 5, 2024
d8e5631
Extended changelog, slightly improved structure
May 5, 2024
f059b65
Merge branch 'refs/heads/thuml-master' into policy-train-eval
May 5, 2024
6a5b3c8
Docstrings, skip hidden files in autogen_rst
May 5, 2024
78ea013
Tests: fixed test_psrl.py: use args.reward_threshold instead of spec
May 6, 2024
e94a5c0
New context manager: policy_within_training_step
May 6, 2024
34 changes: 23 additions & 11 deletions CHANGELOG.md
@@ -8,14 +8,18 @@
- Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
- Add methods `to_numpy_` and `to_torch_`. #1098, #1117
- Add `__eq__` (semantic equality check). #1098
- `keys()` deprecated in favor of `get_keys()` (needed to make iteration consistent with naming) #1105.
- `data.collector`:
- `Collector`:
- Introduced `BaseCollector` as a base class for all collectors. #1123
- Add method `close` #1063
- Method `reset` is now more granular (new flags controlling behavior). #1063
- `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063
- `trainer`:
- Trainers can now control whether collectors should be reset prior to training. #1063
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
- policy:
- introduced attribute `in_training_step` that is controlled by the trainer. #1123
- policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123
- `highlevel`:
- `SamplingConfig`:
- Add support for `batch_size=None`. #1077
@@ -33,12 +37,14 @@
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
- The module `evaluation.launchers` for parallelization is currently in alpha state.
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- `utils.net`:
- `continuous.Critic`:
- `utils`:
- `net.continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without
the actions concatenated), which is essential for the case where we want
to reuse the actor's preprocessing network #1128
- `torch_utils` (new module)
- Added context managers `torch_train_mode` and `policy_within_training_step` #1123

### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
@@ -65,20 +71,26 @@ instead of just `nn.Module`. #1032
- Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102

### Breaking Changes

- Removed `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
expicitly or pass `reset_before_collect=True` . #1063
- `data`:
- `Collector`:
- Removed `.data` attribute. #1063
- Collectors no longer reset the environment on initialization.
Instead, the user might have to call `reset` explicitly or pass `reset_before_collect=True`. #1063
- Removed `no_grad` argument from `collect` method (was unused in tianshou). #1123
- `Batch`:
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`.
Can be considered a bugfix. #1063
- The methods `to_numpy` and `to_torch` are not in-place anymore
(use `to_numpy_` or `to_torch_` instead). #1098, #1117
- Logging:
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- The methods `to_numpy` and `to_torch` in `Batch` is not in-place anymore (use `to_numpy_` or `to_torch_` instead). #1098, #1117
- `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
- `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074


### Tests
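As a quick orientation on the changelog entries above (the new `torch_utils` context managers and the automatic train/eval handling of the policy), here is a minimal usage sketch based on the notebook changes included in this PR; `policy`, `dummy_buffer` and `obs` are assumed to exist as in docs/02_notebooks/L4_Policy.ipynb, and the sample/batch sizes are placeholders.

import numpy as np

from tianshou.data import Batch
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

# Selecting actions for training data: the policy is within the training step,
# but the wrapped torch module stays in eval mode.
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()

# Updating the policy outside of a trainer: declare the training step and put the
# wrapped torch module into train mode; Tianshou trainers do both automatically.
with policy_within_training_step(policy), torch_train_mode(policy):
    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6)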
4 changes: 1 addition & 3 deletions docs/02_notebooks/L0_overview.ipynb
@@ -18,9 +18,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# !pip install tianshou gym"
29 changes: 16 additions & 13 deletions docs/02_notebooks/L4_Policy.ipynb
@@ -74,7 +74,8 @@
")\n",
"from tianshou.utils import RunningMeanStd\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
]
},
{
@@ -644,7 +645,10 @@
"source": [
"obs, info = env.reset()\n",
"for i in range(3, 10):\n",
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
" # For retrieving actions to be used for training, we set the policy to training mode,\n",
" # but the wrapped torch module should be in eval mode.\n",
" with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n",
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
" obs_next, rew, _, truncated, info = env.step(act)\n",
" # pretend this episode never end\n",
" terminated = False\n",
@@ -695,7 +699,11 @@
},
"source": [
"#### Updates\n",
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train."
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n",
"\n",
"However, we need to manually set the torch module to training mode prior to that, \n",
"and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n",
"but users need to consider it when calling `.update` outside of the trainer."
]
},
{
@@ -711,16 +719,11 @@
"outputs": [],
"source": [
"# 0 means sample all data from the buffer\n",
"policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "enqlFQLSJrQl"
},
"source": [
"Not that difficult, right?"
"\n",
"# For updating the policy, the policy should be in training mode\n",
"# and the wrapped torch module should also be in training mode (unlike when collecting data).\n",
"with policy_within_training_step(policy), torch_train_mode(policy):\n",
" policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
]
},
{
44 changes: 31 additions & 13 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -54,7 +54,6 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"id": "do-xZ-8B7nVH",
@@ -64,9 +63,12 @@
"tags": [
"hide-cell",
"remove-output"
]
],
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
}
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
Expand All @@ -78,14 +80,20 @@
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
]
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:07.536452Z",
"start_time": "2024-05-06T15:34:03.636670Z"
}
},
"source": [
"train_env_num = 4\n",
"buffer_size = (\n",
@@ -123,7 +131,9 @@
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
]
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
@@ -154,11 +164,19 @@
"\n",
"n_episode = 10\n",
"for _i in range(n_episode):\n",
" evaluation_result = test_collector.collect(n_episode=n_episode)\n",
" # for test collector, we set the wrapped torch module to evaluation mode\n",
" # by default, the policy object itself is not within the training step\n",
" with torch_train_mode(policy, enabled=False):\n",
" evaluation_result = test_collector.collect(n_episode=n_episode)\n",
" print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
" # for collecting data for training, the policy object should be within the training step\n",
" # (affecting e.g. whether the policy is stochastic or deterministic)\n",
" with policy_within_training_step(policy):\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" # for updating the policy, the wrapped torch module should be in training mode\n",
" with torch_train_mode(policy):\n",
" policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
" train_collector.reset_buffer(keep_statistics=True)"
]
},
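For contrast with the manual loop in the notebook diff above: when the built-in trainer is used instead, none of the context managers have to be entered by hand, because the trainer toggles the training-step flag and the torch train/eval mode itself (see the #1123 changelog entries). A rough sketch under that assumption, reusing the policy and collectors from this notebook; the hyperparameter values are placeholders and are not part of this diff.

from tianshou.trainer import OnpolicyTrainer

result = OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=1,
    repeat_per_collect=1,
    episode_per_test=10,
    step_per_collect=2000,
    batch_size=512,
).run()  # mode switching is handled internally by the trainer
print(result)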
4 changes: 3 additions & 1 deletion docs/autogen_rst.py
@@ -74,7 +74,9 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
subdir_refs = [
f"{f}/index"
for f in files_in_dir
if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_")
if os.path.isdir(os.path.join(src_root, f))
and not f.startswith("_")
and not f.startswith(".")
]
package_index_rst_path = os.path.join(
rst_root,
1 change: 0 additions & 1 deletion examples/atari/atari_c51.py
@@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
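The same one-line change recurs in every example script below: the manual `policy.eval()` call in the watch/evaluation helpers is dropped, because the policy's torch module is now switched to eval mode automatically during collection (see the changelog above). A hedged sketch of the resulting pattern, assuming `args`, `policy`, `test_envs` and `test_collector` exist as in the surrounding scripts:

def watch() -> None:
    print("Setup test envs ...")
    # no explicit policy.eval() anymore; the collector handles torch eval mode
    policy.set_eps(args.eps_test)  # test-time exploration rate (DQN-family policies)
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Mean return: {result.returns.mean()}")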
1 change: 0 additions & 1 deletion examples/atari/atari_dqn.py
@@ -204,7 +204,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_fqf.py
@@ -175,7 +175,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_iqn.py
@@ -172,7 +172,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_ppo.py
@@ -229,7 +229,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
1 change: 0 additions & 1 deletion examples/atari/atari_qrdqn.py
@@ -166,7 +166,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_rainbow.py
@@ -200,7 +200,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
1 change: 0 additions & 1 deletion examples/atari/atari_sac.py
@@ -216,7 +216,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# watch agent's performance
def watch() -> None:
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
1 change: 0 additions & 1 deletion examples/box2d/acrobot_dualdqn.py
@@ -144,7 +144,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
1 change: 0 additions & 1 deletion examples/box2d/bipedal_bdq.py
@@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
1 change: 0 additions & 1 deletion examples/box2d/bipedal_hardcore_sac.py
@@ -207,7 +207,6 @@ def stop_fn(mean_rewards: float) -> bool:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/box2d/lunarlander_dqn.py
@@ -141,7 +141,6 @@ def test_fn(epoch: int, env_step: int | None) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
1 change: 0 additions & 1 deletion examples/box2d/mcc_sac.py
@@ -153,7 +153,6 @@ def stop_fn(mean_rewards: float) -> bool:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/discrete/discrete_dqn.py
@@ -80,7 +80,6 @@ def stop_fn(mean_rewards: float) -> bool:
print(f"Finished training in {result.timing.total_time} seconds")

# watch performance
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=100, render=1 / 35)
1 change: 0 additions & 1 deletion examples/discrete/discrete_dqn_hl.py
@@ -18,7 +18,6 @@ def main() -> None:
DQNExperimentBuilder(
EnvFactoryRegistered(
task="CartPole-v1",
seed=0,
venv_type=VectorEnvType.DUMMY,
train_seed=0,
test_seed=10,
1 change: 0 additions & 1 deletion examples/inverse/irl_gail.py
@@ -264,7 +264,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/fetch_her_ddpg.py
@@ -238,7 +238,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_a2c.py
@@ -219,7 +219,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_ddpg.py
@@ -168,7 +168,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_npg.py
@@ -216,7 +216,6 @@ def save_best_fn(policy: BasePolicy) -> None:
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)