diff --git a/CHANGELOG.md b/CHANGELOG.md
index 35db33ab4..807a9dad6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,14 +8,18 @@
     - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
     - Add methods `to_numpy_` and `to_torch_`. #1098, #1117
     - Add `__eq__` (semantic equality check). #1098
+    - `keys()` deprecated in favor of `get_keys()` (needed to make iteration consistent with naming) #1105.
   - `data.collector`:
     - `Collector`:
+      - Introduced `BaseCollector` as a base class for all collectors. #1123
       - Add method `close` #1063
       - Method `reset` is now more granular (new flags controlling behavior). #1063
     - `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063
 - `trainer`:
   - Trainers can now control whether collectors should be reset prior to training. #1063
-- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
+- policy:
+  - Introduced attribute `is_within_training_step`, which is controlled by the trainer. #1123
+  - The policy is now automatically set to `eval` mode when collecting and to `train` mode when updating. #1123
 - `highlevel`:
   - `SamplingConfig`:
     - Add support for `batch_size=None`. #1077
@@ -33,12 +37,14 @@
 - `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
   - The module `evaluation.launchers` for parallelization is currently in alpha state.
 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
-- `utils.net`:
-  - `continuous.Critic`:
+- `utils`:
+  - `net.continuous.Critic`:
     - Add flag `apply_preprocess_net_to_obs_only` to allow the preprocessing network to be applied to the observations only (without the actions concatenated), which is essential for the case where we want to reuse the actor's preprocessing network #1128
+  - `torch_utils` (new module)
+    - Added context managers `torch_train_mode` and `policy_within_training_step` #1123

 ### Fixes
 - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
@@ -65,20 +71,26 @@
   instead of just `nn.Module`. #1032
 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102

 ### Breaking Changes
-
-- Removed `.data` attribute from `Collector` and its child classes. #1063
-- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
-expicitly or pass `reset_before_collect=True` . #1063
+- `data`:
+  - `Collector`:
+    - Removed `.data` attribute. #1063
+    - Collectors no longer reset the environment on initialization.
+      Instead, the user might have to call `reset` explicitly or pass `reset_before_collect=True`. #1063
+    - Removed `no_grad` argument from `collect` method (was unused in tianshou). #1123
+  - `Batch`:
+    - Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`.
+      Can be considered a bugfix. #1063
+    - The methods `to_numpy` and `to_torch` are no longer in-place
+      (use `to_numpy_` or `to_torch_` instead). #1098, #1117
+- Logging:
+  - `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
+  - Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074
 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
-- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix.
#1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 -- The methods `to_numpy` and `to_torch` in `Batch` is not in-place anymore (use `to_numpy_` or `to_torch_` instead). #1098, #1117 - `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074 - `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074 -- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074 -- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074 ### Tests diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 59d6fd207..0ce6df154 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -18,9 +18,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "# !pip install tianshou gym" diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 00f7f27b9..eed8ea344 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -74,7 +74,8 @@ ")\n", "from tianshou.utils import RunningMeanStd\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" ] }, { @@ -644,7 +645,10 @@ "source": [ "obs, info = env.reset()\n", "for i in range(3, 10):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", + " # For retrieving actions to be used for training, we set the policy to training mode,\n", + " # but the wrapped torch module should be in eval mode.\n", + " with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n", + " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", " obs_next, rew, _, truncated, info = env.step(act)\n", " # pretend this episode never end\n", " terminated = False\n", @@ -695,7 +699,11 @@ }, "source": [ "#### Updates\n", - "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train." + "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n", + "\n", + "However, we need to manually set the torch module to training mode prior to that, \n", + "and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n", + "but users need to consider it when calling `.update` outside of the trainer." ] }, { @@ -711,16 +719,11 @@ "outputs": [], "source": [ "# 0 means sample all data from the buffer\n", - "policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "enqlFQLSJrQl" - }, - "source": [ - "Not that difficult, right?" 
+ "\n", + "# For updating the policy, the policy should be in training mode\n", + "# and the wrapped torch module should also be in training mode (unlike when collecting data).\n", + "with policy_within_training_step(policy), torch_train_mode(policy):\n", + " policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" ] }, { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index 75aea471c..d5423bd01 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -54,7 +54,6 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": { "editable": true, "id": "do-xZ-8B7nVH", @@ -64,9 +63,12 @@ "tags": [ "hide-cell", "remove-output" - ] + ], + "ExecuteTime": { + "end_time": "2024-05-06T15:34:02.969675Z", + "start_time": "2024-05-06T15:34:00.747309Z" + } }, - "outputs": [], "source": [ "%%capture\n", "\n", @@ -78,14 +80,20 @@ "from tianshou.policy import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" - ] + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-06T15:34:07.536452Z", + "start_time": "2024-05-06T15:34:03.636670Z" + } + }, "source": [ "train_env_num = 4\n", "buffer_size = (\n", @@ -123,7 +131,9 @@ "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", "test_collector = Collector(policy, test_envs)\n", "train_collector = Collector(policy, train_envs, replayBuffer)" - ] + ], + "outputs": [], + "execution_count": 2 }, { "cell_type": "markdown", @@ -154,11 +164,19 @@ "\n", "n_episode = 10\n", "for _i in range(n_episode):\n", - " evaluation_result = test_collector.collect(n_episode=n_episode)\n", + " # for test collector, we set the wrapped torch module to evaluation mode\n", + " # by default, the policy object itself is not within the training step\n", + " with torch_train_mode(policy, enabled=False):\n", + " evaluation_result = test_collector.collect(n_episode=n_episode)\n", " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", + " # for collecting data for training, the policy object should be within the training step\n", + " # (affecting e.g. 
whether the policy is stochastic or deterministic)\n", + " with policy_within_training_step(policy):\n", + " train_collector.collect(n_step=2000)\n", + " # 0 means taking all data stored in train_collector.buffer\n", + " # for updating the policy, the wrapped torch module should be in training mode\n", + " with torch_train_mode(policy):\n", + " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", " train_collector.reset_buffer(keep_statistics=True)" ] }, diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index d3d4f080c..b1a8b18d9 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -74,7 +74,9 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" subdir_refs = [ f"{f}/index" for f in files_in_dir - if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_") + if os.path.isdir(os.path.join(src_root, f)) + and not f.startswith("_") + and not f.startswith(".") ] package_index_rst_path = os.path.join( rst_root, diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index fc04a219a..d611ab196 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index f237c5a33..eeb9bccce 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -204,7 +204,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 127a14b24..58aff46ac 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -175,7 +175,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 8b1625275..c6090523d 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -172,7 +172,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index f1a89ef40..dd75de7fb 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -229,7 +229,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index dfb96419e..b9731316e 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -166,7 +166,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - 
policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 7b341c0a1..952d35f07 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -200,7 +200,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index f06964c28..4d01a88aa 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -216,7 +216,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 96e61b612..365c073fa 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -144,7 +144,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 8b1e8ca8d..c817831b1 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 2c071bc1c..66e5f316d 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -207,7 +207,6 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 9e5db5833..f9bbd6fa6 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -141,7 +141,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 5c093093e..7617b7b43 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -153,7 +153,6 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 4f1a82b12..3ba22a40c 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -80,7 +80,6 @@ def stop_fn(mean_rewards: float) -> bool: print(f"Finished training in {result.timing.total_time} seconds") # watch performance - policy.eval() policy.set_eps(eps_test) collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 35e359770..eacf4c78f 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -18,7 +18,6 @@ def main() -> None: DQNExperimentBuilder( EnvFactoryRegistered( task="CartPole-v1", - seed=0, venv_type=VectorEnvType.DUMMY, train_seed=0, test_seed=10, diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 2d013a01b..42e5bc2c9 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -264,7 +264,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index cd6ceec89..bbf68c2fa 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -238,7 +238,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index ea6ab8f24..194d9b5de 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -219,7 +219,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index b2a40878b..db90babb0 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -168,7 +168,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 8a379da92..4d8530a53 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -216,7 +216,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 00042884f..7c3f268c8 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -224,7 +224,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index ae46b220c..8951b03ac 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -196,7 +196,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 109f1cc46..ff7e34099 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -196,7 +196,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index a0bd567ff..af1398380 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -190,7 +190,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 6b6dfdc8c..6cc8bb212 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -188,7 +188,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 219593343..eefdfcc65 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -221,7 +221,6 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 1fc0dc7e3..3af40cc7f 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -187,7 +187,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 40d91c1bb..b2c0c8705 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -171,7 +171,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index a4b31c4fb..8b6320a79 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -188,7 +188,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index bb7822ea9..39aee31d5 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -145,7 +145,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 80b233cb7..9ed18262a 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -206,7 +206,6 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) @@ -229,7 +228,6 @@ def watch() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 7ca8ae2fb..90d6b159c 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -344,7 +344,6 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) @@ -367,7 +366,6 @@ def watch() -> None: watch() # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index c2152a711..e03deed80 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -142,7 +142,6 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) @@ -165,7 +164,6 @@ def watch() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 4d6159ff5..6b448b320 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -191,7 +191,6 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) @@ -217,7 +216,6 @@ def watch() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 4211585af..25ad80487 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -168,7 +168,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index f5abf0b6f..7fc09f690 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -231,7 +231,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") diff --git a/test/base/test_action_space_sampling.py b/test/base/test_action_space_sampling.py index fbbf25c04..e3d82767a 100644 --- a/test/base/test_action_space_sampling.py +++ b/test/base/test_action_space_sampling.py @@ -48,10 +48,3 @@ def test_shmem_vec_env_action_space() -> None: action2 = [ac_space.sample() for ac_space in envs.action_space] assert action1 == action2 - - -if __name__ == "__main__": - test_gym_env_action_space() - test_dummy_vec_env_action_space() - test_subproc_vec_env_action_space() - test_shmem_vec_env_action_space() diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 5e90dfb66..86d4af500 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -749,16 +749,3 @@ def test_to_torch_() -> None: assert id_batch == id(batch) assert isinstance(batch.b, torch.Tensor) assert isinstance(batch.c.d, torch.Tensor) - - -if __name__ == "__main__": - test_batch() - test_batch_over_batch() - test_batch_over_batch_to_torch() - test_utils_to_torch_numpy() - test_batch_pickle() - test_batch_from_to_numpy_without_copy() - test_batch_standard_compatibility() - 
test_batch_cat_and_stack() - test_batch_copy() - test_batch_empty() diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 5488ff365..1b3593db3 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,7 +1,8 @@ import os import pickle import tempfile -from timeit import timeit +from test.base.env import MoveToRightEnv, MyGoalEnv +from typing import cast import h5py import numpy as np @@ -22,11 +23,6 @@ ) from tianshou.data.utils.converter import to_hdf5 -if __name__ == "__main__": - from env import MoveToRightEnv, MyGoalEnv -else: # pytest - from test.base.env import MoveToRightEnv, MyGoalEnv - def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: env = MoveToRightEnv(size) @@ -386,25 +382,25 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf[tmp_indices].obs - obs_next = buf[tmp_indices].obs_next - rew = buf[tmp_indices].rew - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf[tmp_indices].obs) + obs_next_buf = cast(Batch, buf[tmp_indices].obs_next) + rew_in_buf = buf[tmp_indices].rew + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == g[0]) assert np.all(g_next == g_next[0]) - assert np.all(rew == (ag_next == g).astype(np.float32)) + assert np.all(rew_in_buf == (ag_next == g).astype(np.float32)) tmp_indices = buf.next(tmp_indices) # Check that goals are correctly restored buf._restore_cache() tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf[tmp_indices].obs - obs_next = buf[tmp_indices].obs_next - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf[tmp_indices].obs) + obs_next_buf = cast(Batch, buf[tmp_indices].obs_next) + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == env_size) assert np.all(g_next == g_next[0]) assert np.all(g == g[0]) @@ -416,24 +412,24 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf2[tmp_indices].obs - obs_next = buf2[tmp_indices].obs_next - rew = buf2[tmp_indices].rew - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf2[tmp_indices].obs) + obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next) + rew_buf = buf2[tmp_indices].rew + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == g_next) - assert np.all(rew == (ag_next == g).astype(np.float32)) + assert np.all(rew_buf == (ag_next == g).astype(np.float32)) tmp_indices = buf2.next(tmp_indices) # Check that goals are correctly restored buf2._restore_cache() tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = 
buf2[tmp_indices].obs - obs_next = buf2[tmp_indices].obs_next - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf2[tmp_indices].obs) + obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next) + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == env_size) assert np.all(g_next == g_next[0]) assert np.all(g == g[0]) @@ -447,7 +443,6 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) buf._index = 5 # shifted start index buf.future_p = 1 - action_list = [1] * 10 for ep_len in [5, 10]: obs, _ = env.reset() for i in range(ep_len): @@ -607,24 +602,6 @@ def test_segtree() -> None: index = tree.get_prefix_sum_idx(scalar) assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() - # profile - if __name__ == "__main__": - size = 100000 - bsz = 64 - naive = np.random.rand(size) - tree = SegmentTree(size) - tree[np.arange(size)] = naive - - def sample_npbuf() -> np.ndarray: - return np.random.choice(size, bsz, p=naive / naive.sum()) - - def sample_tree() -> int | np.ndarray: - scalar = np.random.rand(bsz) * tree.reduce() - return tree.get_prefix_sum_idx(scalar) - - print("npbuf", timeit(sample_npbuf, setup=sample_npbuf, number=1000)) - print("tree", timeit(sample_tree, setup=sample_tree, number=1000)) - def test_pickle() -> None: size = 100 @@ -1053,6 +1030,7 @@ def test_multibuf_stack() -> None: size, ) obs, info = env.reset(options={"state": 1}) + obs = cast(np.ndarray, obs) for i in range(18): obs_next, rew, terminated, truncated, info = env.step(1) done = terminated or truncated @@ -1080,7 +1058,8 @@ def test_multibuf_stack() -> None: assert np.all(buf4.truncated == buf5.truncated) obs = obs_next if done: - obs, info = env.reset(options={"state": 1}) + # obs is an array, but the env is malformed, so we can't properly type it + obs, info = env.reset(options={"state": 1}) # type: ignore[assignment] # check the `add` order is correct assert np.allclose( buf4.obs.reshape(-1), @@ -1401,21 +1380,3 @@ def test_custom_key() -> None: ): assert batch.__dict__[key].is_empty() assert sampled_batch.__dict__[key].is_empty() - - -if __name__ == "__main__": - test_replaybuffer() - test_ignore_obs_next() - test_stack() - test_segtree() - test_priortized_replaybuffer() - test_update() - test_pickle() - test_hdf5() - test_replaybuffermanager() - test_cachedbuffer() - test_multibuf_stack() - test_multibuf_hdf5() - test_from_data() - test_herreplaybuffer() - test_custom_key() diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6baa6abf3..d03a54df7 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Sequence +from test.base.env import MoveToRightEnv, NXEnv from typing import Any import gymnasium as gym @@ -25,11 +26,6 @@ except ImportError: envpool = None -if __name__ == "__main__": - from env import MoveToRightEnv, NXEnv -else: # pytest - from test.base.env import MoveToRightEnv, NXEnv - class MaxActionPolicy(BasePolicy): def __init__( @@ -222,11 +218,11 @@ def test_collector() -> None: c_dummy_venv_4_envs.collect(n_episode=4, random=True) # test corner case - with pytest.raises(TypeError): + with pytest.raises(ValueError): Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) - with pytest.raises(TypeError): + with 
pytest.raises(ValueError): Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) - with pytest.raises(TypeError): + with pytest.raises(ValueError): c_dummy_venv_4_envs.collect() def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]: @@ -264,7 +260,7 @@ def test_collect_without_argument_gives_error( async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens - with pytest.raises(TypeError): + with pytest.raises(ValueError): c1.collect() def test_collect_one_episode_async( @@ -963,13 +959,3 @@ def test_async_collector_with_vector_env() -> None: assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens) c2r = c1.collect(n_step=20) assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens) - - -if __name__ == "__main__": - test_collector() - test_collector_with_dict_state() - test_collector_with_multi_agent() - test_collector_with_atari_setting() - test_collector_envpool_gym_reset_return_info() - test_collector_with_vector_env() - test_async_collector_with_vector_env() diff --git a/test/base/test_env.py b/test/base/test_env.py index a476ec5a9..1a33e861c 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,6 +1,7 @@ import sys import time from collections.abc import Callable +from test.base.env import MoveToRightEnv, NXEnv from typing import Any, Literal import gymnasium as gym @@ -22,11 +23,6 @@ from tianshou.env.venvs import BaseVectorEnv from tianshou.utils import RunningMeanStd -if __name__ == "__main__": - from env import MoveToRightEnv, NXEnv -else: # pytest - from test.base.env import MoveToRightEnv, NXEnv - try: import envpool except ImportError: @@ -190,19 +186,6 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: for info in infos: assert recurse_comp(infos[0], info) - if __name__ == "__main__": - t = [0.0] * len(venv) - for i, e in enumerate(venv): - t[i] = time.time() - e.reset() - for a in action_list: - done = e.step(np.array([a] * num))[2] - if sum(done) > 0: - e.reset(np.where(done)[0]) - t[i] = time.time() - t[i] - for i, v in enumerate(venv): - print(f"{type(v)}: {t[i]:.6f}s") - def assert_get(v: BaseVectorEnv, expected: list) -> None: assert v.get_env_attr("size") == expected assert v.get_env_attr("size", id=0) == [expected[0]] @@ -437,17 +420,3 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None: for _, v in _info.items(): if not isinstance(v, dict): assert v.shape[0] == num_envs - - -if __name__ == "__main__": - test_venv_norm_obs() - test_venv_wrapper_gym() - test_venv_wrapper_envpool() - test_venv_wrapper_envpool_gym_reset_return_info() - test_env_obs_dtype() - test_vecenv() - test_attr_unwrapped() - test_async_env() - test_async_check_id() - test_env_reset_optional_kwargs() - test_gym_wrappers() diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 657100554..ce8a93640 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -268,8 +268,3 @@ def test_finite_subproc_vector_env() -> None: test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() - - -if __name__ == "__main__": - test_finite_dummy_vector_env() - test_finite_subproc_vector_env() diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 7c3aacc07..4d26905c3 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -64,6 +64,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: class TestPolicyBasics: def 
test_get_action(self, policy: PPOPolicy) -> None: + policy.is_within_training_step = False sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 23f50fb22..ab4430b85 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,10 +1,7 @@ -from timeit import timeit - import numpy as np import torch from tianshou.data import Batch, ReplayBuffer, to_numpy -from tianshou.data.types import BatchWithReturnsProtocol from tianshou.policy import BasePolicy @@ -142,28 +139,6 @@ def test_episodic_returns(size: int = 2560) -> None: ) assert np.allclose(returns, ground_truth) - if __name__ == "__main__": - buf = ReplayBuffer(size) - batch = Batch( - terminated=np.random.randint(100, size=size) == 0, - truncated=np.zeros(size), - rew=np.random.random(size), - ) - for b in iter(batch): - b.obs = b.act = 1 - buf.add(b) - indices = buf.sample_indices(0) - - def vanilla() -> Batch: - return compute_episodic_return_base(batch, gamma=0.1) - - def optimized() -> tuple[np.ndarray, np.ndarray]: - return fn(batch, buf, indices, gamma=0.1, gae_lambda=1.0) - - cnt = 3000 - print("GAE vanilla", timeit(vanilla, setup=vanilla, number=cnt)) - print("GAE optim ", timeit(optimized, setup=optimized, number=cnt)) - def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: # return the next reward @@ -356,41 +331,3 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) - - if __name__ == "__main__": - buf = ReplayBuffer(size) - for i in range(int(size * 1.5)): - buf.add( - Batch( - obs=0, - act=0, - rew=i + 1, - terminated=np.random.randint(3) == 0, - truncated=i % 33 == 0, - info={}, - ), - ) - batch, indices = buf.sample(256) - - def vanilla() -> np.ndarray: - return compute_nstep_return_base(3, 0.1, buf, indices) - - def optimized() -> BatchWithReturnsProtocol: - return BasePolicy.compute_nstep_return( - batch, - buf, - indices, - target_q_fn, - gamma=0.1, - n_step=3, - ) - - cnt = 3000 - print("nstep vanilla", timeit(vanilla, setup=vanilla, number=cnt)) - print("nstep optim ", timeit(optimized, setup=optimized, number=cnt)) - - -if __name__ == "__main__": - test_nstep_returns() - test_nstep_returns_with_timelimit() - test_episodic_returns() diff --git a/test/base/test_utils.py b/test/base/test_utils.py index bd14ffe2a..f8e5938cb 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,10 +1,12 @@ import numpy as np import torch +from torch import nn from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic +from tianshou.utils.torch_utils import torch_train_mode def test_noise() -> None: @@ -132,9 +134,17 @@ def test_lr_schedulers() -> None: ) -if __name__ == "__main__": - test_noise() - test_moving_average() - test_rms() - test_net() - test_lr_schedulers() +def test_in_eval_mode() -> None: + module = nn.Linear(3, 4) + module.train() + with torch_train_mode(module, False): + assert not module.training + assert module.training + + +def test_in_train_mode() -> None: + module = nn.Linear(3, 4) + module.eval() + with torch_train_mode(module): + assert module.training + assert not module.training diff --git a/test/continuous/test_ddpg.py 
b/test/continuous/test_ddpg.py index a17c3b513..1aedadabf 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -133,16 +132,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_ddpg() diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 8e0a50d2c..98803a9d2 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -155,16 +154,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_npg() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 5a522dedb..15d834096 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -191,20 +190,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: assert stop_fn(epoch_stat.info_stat.best_reward) - if __name__ == "__main__": - pprint.pprint(epoch_stat) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_ppo(args) - - -if __name__ == "__main__": - test_ppo() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 697b59e98..f627f7e4f 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -164,16 +163,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_redq() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index d13b03d85..77a403359 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -161,7 +161,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) # here we define an imitation collector with a trivial policy - policy.eval() if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal il_net = Net( @@ -205,7 +204,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - -if __name__ == "__main__": - test_sac_with_il() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index ea55da052..21a2cf40d 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -155,17 +155,3 @@ def stop_fn(mean_rewards: float) -> bool: # print(info) assert stop_fn(epoch_stat.info_stat.best_reward) - - if __name__ == "__main__": - pprint.pprint(epoch_stat.info_stat) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_td3() diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index ae788d1cc..8841891bf 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -155,16 +154,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_trpo() diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index f60857ea4..2fd41aff8 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,15 +1,14 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np -import pytest import torch from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer @@ -25,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=20000) @@ -60,29 +59,35 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: - # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make( - args.task, - env_type="gymnasium", - num_envs=args.training_num, - seed=args.seed, - ) - test_envs = envpool.make( - args.task, - env_type="gymnasium", - num_envs=args.test_num, - seed=args.seed, - ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + if envpool is not None: + train_envs = env = envpool.make( + args.task, + env_type="gymnasium", + num_envs=args.training_num, + seed=args.seed, + ) + test_envs = envpool.make( + args.task, + env_type="gymnasium", + num_envs=args.test_num, + seed=args.seed, + ) + else: + env = gym.make(args.task) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs.seed(args.seed) + test_envs.seed(args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) @@ -141,18 +146,8 @@ def stop_fn(mean_rewards: float) -> bool: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - policy.eval() # here we define an imitation collector with a trivial policy - # if args.task == 'CartPole-v0': + # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) @@ -162,9 +157,23 @@ def stop_fn(mean_rewards: float) -> bool: optim=optim, action_space=env.action_space, ) + if envpool is not None: + il_env = envpool.make( + args.task, + env_type="gymnasium", + num_envs=args.test_num, + seed=args.seed, + ) + else: + il_env = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)], + context="fork", + ) + il_env.seed(args.seed) + il_test_collector = Collector( il_policy, - envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed), + il_env, ) train_collector.reset() result = OffpolicyTrainer( @@ -181,16 +190,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - il_policy.eval() - collector = Collector(il_policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_a2c_with_il() diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 1089d4ba0..91c66bac0 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,5 +1,4 @@ import argparse -import pprint import gymnasium as gym import numpy as np @@ -129,7 +128,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OffpolicyTrainer( + OffpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, @@ -143,18 +142,3 @@ def stop_fn(mean_rewards: float) -> bool: test_fn=test_fn, stop_fn=stop_fn, ).run() - - # assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- policy.eval() - policy.set_eps(args.eps_test) - test_envs.seed(args.seed) - test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) - collector_stats.pprint_asdict() - - -if __name__ == "__main__": - test_bdq(get_args()) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 4d25d430b..8b34ddb4b 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,7 +1,6 @@ import argparse import os import pickle -import pprint import gymnasium as gym import numpy as np @@ -25,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -68,7 +67,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -202,16 +201,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_c51_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True @@ -223,7 +212,3 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_c51(args) - - -if __name__ == "__main__": - test_c51(get_args()) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b62a93c3f..773004f2c 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -24,7 +23,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -62,7 +61,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -155,23 +154,9 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_pdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 args.seed = 1 test_dqn(args) - - -if __name__ == "__main__": - test_dqn(get_args()) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5c24518bb..193179097 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -19,7 +18,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -55,7 +54,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -131,16 +130,3 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_drqn(get_args()) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 8ff9eeb7a..743293be0 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -25,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -67,7 +66,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -172,22 +171,8 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_pfqf(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_fqf(args) - - -if __name__ == "__main__": - test_fqf(get_args()) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 765bbf9bd..f7ea67adb 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -25,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.05) @@ -67,7 +66,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -168,22 +167,8 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_piqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_iqn(args) - - -if __name__ == "__main__": - test_iqn(get_args()) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 95db43c23..60d0eb469 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -20,7 +19,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=20000) @@ -51,7 +50,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -124,16 +123,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_pg() diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 132cbea5a..27fe6f517 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -23,7 +22,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--buffer-size", type=int, default=20000) @@ -64,7 +63,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -151,16 +150,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_ppo() diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 6485637e8..76d7d429d 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -20,7 +19,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -60,10 +59,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - if args.task == "CartPole-v0" and env.spec: + if args.task == "CartPole-v1" and env.spec: env.spec.reward_threshold = 190 # lower the goal if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -157,22 +156,8 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_qrdqn(args) - - -if __name__ == "__main__": - test_pqrdqn(get_args()) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index ff4ef1c1e..0a73d4b77 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,7 +1,6 @@ import argparse import os import pickle -import pprint import gymnasium as gym import numpy as np @@ -22,7 +21,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -69,7 +68,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -219,16 +218,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True @@ -240,7 +229,3 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_rainbow(args) - - -if __name__ == "__main__": - test_rainbow(get_args()) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index b2f466f3d..f16e59daf 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -21,7 +20,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=20000) @@ -60,7 +59,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 170} # lower the goal + default_reward_threshold = {"CartPole-v1": 170} # lower the goal args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -142,16 +141,3 @@ def stop_fn(mean_rewards: float) -> bool: test_in_train=False, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_discrete_sac() diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index ddfce7b23..4a131e5fd 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -7,7 +7,7 @@ class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( - task="CartPole-v0", + task="CartPole-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 9a4206e18..5ef0bba65 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -21,7 +20,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -79,7 +78,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -197,17 +196,3 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_dqn_icm(get_args()) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index ebf93cd5a..77f9a40e1 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -22,7 +21,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--buffer-size", type=int, default=20000) @@ -83,7 +82,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -189,16 +188,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_ppo() diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 72742b785..995aef698 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import numpy as np import pytest @@ -45,7 +44,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +@pytest.mark.skipif( + envpool is None, + reason="EnvPool is not installed. If on linux, please install it (e.g. as poetry extra)", +) def test_psrl(args: argparse.Namespace = get_args()) -> None: # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) @@ -116,18 +118,4 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, test_in_train=False, ).run() - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - policy.eval() - test_envs.seed(args.seed) - test_collector.reset() - stats = test_collector.collect(n_episode=args.test_num, render=args.render) - stats.pprint_asdict() - elif env.spec.reward_threshold: - assert result.best_reward >= env.spec.reward_threshold - - -if __name__ == "__main__": - test_psrl() + assert result.best_reward >= args.reward_threshold diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 93877944e..19ba653e5 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -19,12 +19,12 @@ def expert_file_name() -> str: - return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl") + return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl") def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 190} + default_reward_threshold = {"CartPole-v1": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 1839d863a..8b31c1969 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -2,7 +2,7 @@ import datetime import os import pickle -import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -19,11 +19,6 @@ from tianshou.utils.net.continuous import VAE, Critic, Perturbation from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -189,7 
+184,6 @@ def watch() -> None: policy.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) - policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) @@ -208,16 +202,3 @@ def watch() -> None: show_progress=args.show_progress, ).run() assert stop_fn(result.best_reward) - - # Let's watch its performance! - if __name__ == "__main__": - pprint.pprint(result) - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_bcq() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 1e31b1feb..41d67151a 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,6 +3,7 @@ import os import pickle import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -19,11 +20,6 @@ from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -205,19 +201,3 @@ def stop_fn(mean_rewards: float) -> bool: # print(info) assert stop_fn(epoch_stat.info_stat.best_reward) - - # Let's watch its performance! - if __name__ == "__main__": - pprint.pprint(epoch_stat.info_stat) - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render) - if collector_result.returns_stat and collector_result.lens_stat: - print( - f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", - ) - - -if __name__ == "__main__": - test_cql() diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 77790808b..6e8e8784a 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -1,14 +1,14 @@ import argparse import os import pickle -import pprint +from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteBCQPolicy from tianshou.trainer import OfflineTrainer @@ -17,15 +17,10 @@ from tianshou.utils.net.discrete import Actor from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_cartpole_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_cartpole_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -61,7 +56,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: args.state_shape = 
space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 185} + default_reward_threshold = {"CartPole-v1": 185} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -101,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: imitation_logits_penalty=args.imitation_logits_penalty, ) # buffer + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -165,22 +161,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: test_discrete_bcq() args.resume = True test_discrete_bcq(args) - - -if __name__ == "__main__": - test_discrete_bcq(get_args()) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 7323eac13..f2a60e00c 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -1,14 +1,14 @@ import argparse import os import pickle -import pprint +from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteCQLPolicy from tianshou.trainer import OfflineTrainer @@ -16,15 +16,10 @@ from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_cartpole_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_cartpole_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -58,7 +53,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 170} + default_reward_threshold = {"CartPole-v1": 170} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -90,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: min_q_weight=args.min_q_weight, ).to(args.device) # buffer + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and 
os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -126,17 +122,3 @@ def stop_fn(mean_rewards: float) -> bool: ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_discrete_cql(get_args()) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index b3cb64616..bc54dd9d0 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -1,14 +1,14 @@ import argparse import os import pickle -import pprint +from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteCRRPolicy from tianshou.trainer import OfflineTrainer @@ -17,15 +17,10 @@ from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_cartpole_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_cartpole_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) @@ -56,7 +51,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 180} + default_reward_threshold = {"CartPole-v1": 180} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -94,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -130,16 +126,3 @@ def stop_fn(mean_rewards: float) -> bool: ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_discrete_crr(get_args()) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 256140c41..ea13f484c 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -1,7 +1,7 @@ import argparse import os import pickle -import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -18,11 +18,6 @@ from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -226,16 +221,3 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: save_checkpoint_fn=save_checkpoint_fn, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_gail() diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 18778563c..fa01444ab 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -2,7 +2,7 @@ import datetime import os import pickle -import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -20,11 +20,6 @@ from tianshou.utils.net.continuous import Actor, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -193,16 +188,3 @@ def stop_fn(mean_rewards: float) -> bool: # print(info) assert stop_fn(epoch_stat.info_stat.best_reward) - - # Let's watch its performance! 
- if __name__ == "__main__": - pprint.pprint(epoch_stat.info_stat) - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - test_td3_bc() diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index abd0c889a..c57522df0 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -188,7 +188,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - policy.eval() [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 54d606602..38de81173 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) collector_result.pprint_asdict() diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 2432ca531..36f1d2d3a 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -1,22 +1,14 @@ import argparse -import pprint +import pytest from pistonball import get_args, train_agent, watch +@pytest.mark.skip(reason="Performance bound was never tested, no point in running this for now") def test_piston_ball(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return - result, agent = train_agent(args) + train_agent(args) # assert result.best_reward >= args.win_rate - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - watch(args, agent) - - -if __name__ == "__main__": - test_piston_ball(get_args()) diff --git a/test/pettingzoo/test_pistonball_continuous.py b/test/pettingzoo/test_pistonball_continuous.py index bb2979bc6..b96a29c0c 100644 --- a/test/pettingzoo/test_pistonball_continuous.py +++ b/test/pettingzoo/test_pistonball_continuous.py @@ -1,5 +1,4 @@ import argparse -import pprint import pytest from pistonball_continuous import get_args, train_agent, watch @@ -13,12 +12,3 @@ def test_piston_ball_continuous(args: argparse.Namespace = get_args()) -> None: result, agent = train_agent(args) # assert result.best_reward >= 30.0 - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
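# Sketch of how the deleted "watch its performance" blocks can be reproduced by hand if
# needed. This is an illustration, not code from the patch: `watch_policy` is a
# hypothetical helper, and an explicit `policy.eval()` call is no longer required because
# `Collector.collect` switches the torch module to eval mode itself (see the collector
# changes further down in this diff).
import gymnasium as gym

from tianshou.data import Collector
from tianshou.policy import BasePolicy


def watch_policy(policy: BasePolicy, task: str = "CartPole-v1", render: float = 0.0) -> None:
    env = gym.make(task)
    collector = Collector(policy, env)
    stats = collector.collect(n_episode=1, render=render, reset_before_collect=True)
    stats.pprint_asdict()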
- watch(args, agent) - - -if __name__ == "__main__": - test_piston_ball_continuous(get_args()) diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index 44aa86b9f..0f0c237c8 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -1,5 +1,4 @@ import argparse -import pprint from tic_tac_toe import get_args, train_agent, watch @@ -11,12 +10,3 @@ def test_tic_tac_toe(args: argparse.Namespace = get_args()) -> None: result, agent = train_agent(args) assert result.best_reward >= args.win_rate - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - watch(args, agent) - - -if __name__ == "__main__": - test_tic_tac_toe(get_args()) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 7ed631912..966c9e04c 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -228,7 +228,6 @@ def watch( ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) - policy.eval() policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 623079890..c84c2ec7d 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -24,7 +24,13 @@ SequenceSummaryStats, TimingStats, ) -from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase +from tianshou.data.collector import ( + Collector, + AsyncCollector, + CollectStats, + CollectStatsBase, + BaseCollector, +) __all__ = [ "Batch", @@ -50,4 +56,5 @@ "InfoStats", "SequenceSummaryStats", "TimingStats", + "BaseCollector", ] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 345d50b03..6773a6383 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,5 +1,7 @@ +import logging import time import warnings +from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass from typing import Any, Self, TypeVar, cast @@ -7,11 +9,11 @@ import gymnasium as gym import numpy as np import torch +from overrides import override from tianshou.data import ( Batch, CachedReplayBuffer, - PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, SequenceSummaryStats, @@ -25,6 +27,9 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin +from tianshou.utils.torch_utils import torch_train_mode + +log = logging.getLogger(__name__) @dataclass(kw_only=True) @@ -122,23 +127,12 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class Collector: - """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` - as the default buffer. - :param exploration_noise: determine whether the action needs to be modified - with the corresponding policy's exploration noise. If so, "policy. 
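# Sketch: with `BaseCollector` exported from `tianshou.data`, code that only needs the
# collection interface can be annotated against the abstract base class instead of the
# concrete `Collector` (the helper below is hypothetical):
from tianshou.data import BaseCollector


def prefill(collector: BaseCollector, n_step: int) -> None:
    """Collect a fixed number of transitions into the attached buffer."""
    collector.collect(n_step=n_step, reset_before_collect=True)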
- exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action. Default to False. +class BaseCollector(ABC): + """Used to collect data from a vector environment into a buffer using a given policy. .. note:: - Please make sure the given environment has a time limitation if using n_episode + Please make sure the given environment has a time limitation if using `n_episode` collect option. .. note:: @@ -150,72 +144,70 @@ class Collector: def __init__( self, policy: BasePolicy, - env: gym.Env | BaseVectorEnv, + env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: - super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy - self.env = DummyVectorEnv([lambda: env]) - else: - self.env = env # type: ignore - self.env_num = len(self.env) - self.exploration_noise = exploration_noise - self.buffer = self._assign_buffer(buffer) + env = DummyVectorEnv([lambda: env]) # type: ignore + + if buffer is None: + buffer = VectorReplayBuffer(len(env), len(env)) + + self.buffer: ReplayBuffer = buffer self.policy = policy + self.env = cast(BaseVectorEnv, env) + self.exploration_noise = exploration_noise + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + self._action_space = self.env.action_space + self._is_closed = False - self._pre_collect_obs_RO: np.ndarray | None = None - self._pre_collect_info_R: np.ndarray | None = None - self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + self._validate_buffer() - self._is_closed = False - self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + def _validate_buffer(self) -> None: + buf = self.buffer + # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. + # We should probably rename the manager + if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.", + ) + if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.", + ) + # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance + if not isinstance(buf, ReplayBufferManager): + if buf.maxsize == 0: + raise ValueError("Buffer maxsize should be greater than 0.") + if self.env_num > 1: + raise ValueError( + f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). 
" + f"Please use the corresponding VectorReplayBuffer instead.", + ) + + @property + def env_num(self) -> int: + return len(self.env) + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space def close(self) -> None: """Close the collector and the environment.""" self.env.close() - self._pre_collect_obs_RO = None - self._pre_collect_info_R = None self._is_closed = True - @property - def is_closed(self) -> bool: - """Return True if the collector is closed.""" - return self._is_closed - - def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: - """Check if the buffer matches the constraint.""" - if buffer is None: - buffer = VectorReplayBuffer(self.env_num, self.env_num) - elif isinstance(buffer, ReplayBufferManager): - assert buffer.buffer_num >= self.env_num - if isinstance(buffer, CachedReplayBuffer): - assert buffer.cached_buffer_num >= self.env_num - else: # ReplayBuffer or PrioritizedReplayBuffer - assert buffer.maxsize > 0 - if self.env_num > 1: - if isinstance(buffer, ReplayBuffer): - buffer_type = "ReplayBuffer" - vector_type = "VectorReplayBuffer" - if isinstance(buffer, PrioritizedReplayBuffer): - buffer_type = "PrioritizedReplayBuffer" - vector_type = "PrioritizedVectorReplayBuffer" - raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect " - f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" - f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.", - ) - return buffer - def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached @@ -223,13 +215,15 @@ def reset( :param reset_stats: if true, reset the statistics attached to the collector. :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) + :return: The initial observation and info from the environment. """ - self.reset_env(gym_reset_kwargs=gym_reset_kwargs) + obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: self.reset_buffer() if reset_stats: self.reset_stat() self._is_closed = False + return obs_NO, info_N def reset_stat(self) -> None: """Reset the statistic variables.""" @@ -242,24 +236,165 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: - """Reset the environments and the initial obs, info, and hidden state of the collector.""" + ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environments and the initial obs, info, and hidden state of the collector. + + :return: The initial observation and info from the (vectorized) environment. + """ gym_reset_kwargs = gym_reset_kwargs or {} - self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs) + obs_NO, info_N = self.env.reset(**gym_reset_kwargs) # TODO: hack, wrap envpool envs such that they don't return a dict - if isinstance(self._pre_collect_info_R, dict): # type: ignore[unreachable] + if isinstance(info_N, dict): # type: ignore[unreachable] # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict # with array entries instead of an array of dicts # We use Batch to turn it into an array of dicts - self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable] + info_N = _dict_of_arr_to_arr_of_dicts(info_N) # type: ignore[unreachable] + return obs_NO, info_N + + @abstractmethod + def _collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + pass + + @torch.no_grad() + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + """Collect a specified number of steps or episodes. + + To ensure an unbiased sampling result with the n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + + :param n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect. + :param random: whether to use random policy for collecting data. + :param render: the sleep time between rendering consecutive frames. + :param reset_before_collect: whether to reset the environment before collecting data. + (The collector needs the initial obs and info to function properly.) + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Only used if reset_before_collect is True. + + .. note:: + + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. + + :return: The collected stats + """ + # check that exactly one of n_step or n_episode is set and that the other is larger than 0 + self._validate_n_step_n_episode(n_episode, n_step) + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + with torch_train_mode(self.policy, False): + return self._collect( + n_step=n_step, + n_episode=n_episode, + random=random, + render=render, + gym_reset_kwargs=gym_reset_kwargs, + ) + + def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: + if not n_step and not n_episode: + raise ValueError( + f"Only one of n_step and n_episode should be set to a value larger than zero " + f"but got {n_step=}, {n_episode=}.", + ) + if n_step is None and n_episode is None: + raise ValueError( + "Exactly one of n_step and n_episode should be set but got None for both.", + ) + if n_step and n_step % self.env_num != 0: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}), " + "which may cause extra transitions being collected into the buffer.", + ) + if n_episode and self.env_num > n_episode: + warnings.warn( + f"{n_episode=} should be larger than {self.env_num=} to " + f"collect at least one trajectory in each environment.", + ) + + +class Collector(BaseCollector): + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # N - number of envs, always fixed and >= R. + # R - number ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. 
+ # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + """ + super().__init__(policy, env, buffer, exploration_noise=exploration_noise) + self._pre_collect_obs_RO: np.ndarray | None = None + self._pre_collect_info_R: np.ndarray | None = None + self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + self._is_closed = False + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + @override + def close(self) -> None: + super().close() + self._pre_collect_obs_RO = None + self._pre_collect_info_R = None + + @override + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + # We assume that R = N when reset is called. + # TODO: there is currently no mechanism that ensures this and it's a public method! + self._pre_collect_obs_RO = obs_NO + self._pre_collect_info_R = info_N self._pre_collect_hidden_state_RH = None + return obs_NO, info_N def _compute_action_policy_hidden( self, random: bool, ready_env_ids_R: np.ndarray, - use_grad: bool, last_obs_RO: np.ndarray, last_info_R: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, @@ -281,11 +416,10 @@ def _compute_action_policy_hidden( info_batch = _HACKY_create_info_batch(last_info_R) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - with torch.set_grad_enabled(use_grad): - act_batch_RA = self.policy( - obs_batch_R, - last_hidden_state_RH, - ) + act_batch_RA = self.policy( + obs_batch_R, + last_hidden_state_RH, + ) act_RA = to_numpy(act_batch_RA.act) if self.exploration_noise: @@ -309,89 +443,29 @@ def _compute_action_policy_hidden( return act_RA, act_normalized_RA, policy_R, hidden_state_RH # TODO: reduce complexity, remove the noqa - def collect( + def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, - reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of steps or episodes. 
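# Sketch of the resulting usage, assuming `policy` is an already constructed BasePolicy
# (placeholder). It illustrates the constraints enforced above: a vectorized env needs a
# VectorReplayBuffer with buffer_num >= env_num, exactly one of n_step/n_episode may be
# passed to collect(), and reset_before_collect=True replaces the reset that previously
# happened implicitly on construction.
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
buffer = VectorReplayBuffer(total_size=20_000, buffer_num=len(train_envs))
collector = Collector(policy, train_envs, buffer, exploration_noise=True)

# reset() now returns the initial observation and info of the vectorized env.
obs, info = collector.reset()
# n_step should be a multiple of the number of envs to avoid collecting extra transitions.
stats = collector.collect(n_step=64)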
- - To ensure an unbiased sampling result with the n_episode option, this function will - first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` - episodes, they will be collected evenly from each env. - - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. - :param render: the sleep time between rendering consecutive frames. - :param no_grad: whether to retain gradient in policy.forward(). - :param reset_before_collect: whether to reset the environment before collecting data. - (The collector needs the initial obs and info to function properly.) - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Only used if reset_before_collect is True. - - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. - - :return: The collected stats - """ - # NAMING CONVENTION (mostly suffixes): - # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, - # the corresponding env is either reset or removed from the ready envs. - # R - number ready env ids. Note that this might change when envs get idle. - # This can only happen in n_episode case, see explanation in the corresponding block. - # For n_step, we always use all envs to collect the data, while for n_episode, - # R will be at most n_episode at the beginning, but can decrease during the collection. - # O - dimension(s) of observations - # A - dimension(s) of actions - # H - dimension(s) of hidden state - # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. - # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. - # Only used in n_episode case. Then, R becomes R-S. - - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} + # TODO: can't do it init since AsyncCollector is currently a subclass of Collector + if self.env.is_async: + raise ValueError( + f"Please use {AsyncCollector.__name__} for asynchronous environments. " + f"Env class: {self.env.__class__.__name__}.", + ) - # Input validation - assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: - assert n_episode is None, ( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}." - ) - assert n_step > 0 - if n_step % self.env_num != 0: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}), " - "which may cause extra transitions being collected into the buffer.", - ) ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: - assert n_episode > 0 - if self.env_num > n_episode: - warnings.warn( - f"{n_episode=} should be larger than {self.env_num=} to " - f"collect at least one trajectory in each environment.", - ) ready_env_ids_R = np.arange(min(self.env_num, n_episode)) else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) + raise ValueError("Either n_step or n_episode should be set.") start_time = time.time() - - if reset_before_collect: - self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: raise ValueError( "Initial obs and info should not be None. 
" @@ -433,7 +507,6 @@ def collect( ) = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, - use_grad=use_grad, last_obs_RO=last_obs_RO, last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, @@ -482,7 +555,8 @@ def collect( step_count += len(ready_env_ids_R) # preparing for the next iteration - # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + # obs_next, info and hidden_state will be modified inplace in the code below, + # so we copy to not affect the data in the buffer last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) last_hidden_state_RH = copy(hidden_state_RH) @@ -500,6 +574,7 @@ def collect( # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. + gym_reset_kwargs = gym_reset_kwargs or {} obs_reset_DO, info_reset_D = self.env.reset( env_id=env_ind_global_D, **gym_reset_kwargs, @@ -577,8 +652,8 @@ def collect( collect_speed=step_count / collect_time, ) + @staticmethod def _reset_hidden_state_based_on_type( - self, env_ind_local_D: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, ) -> None: @@ -596,8 +671,7 @@ def _reset_hidden_state_based_on_type( class AsyncCollector(Collector): """Async Collector handles async vector environment. - The arguments are exactly the same as :class:`~tianshou.data.Collector`, please - refer to :class:`~tianshou.data.Collector` for more detailed explanation. + Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation. """ def __init__( @@ -607,6 +681,12 @@ def __init__( buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: + if not env.is_async: + # TODO: raise an exception? + log.error( + f"Please use {Collector.__name__} if not using async venv. " + f"Env class: {env.__class__.__name__}", + ) # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( @@ -627,22 +707,15 @@ def __init__( self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) self._current_policy_in_all_envs_E: Batch | None = None + @override def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: - """Reset the environment, statistics, and data needed to start the collection. - - :param reset_buffer: if true, reset the replay buffer attached - to the collector. - :param reset_stats: if true, reset the statistics attached to the collector. - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. 
Defaults to None (extra keyword arguments) - """ + ) -> tuple[np.ndarray, np.ndarray]: # This sets the _pre_collect attrs - super().reset( + result = super().reset( reset_buffer=reset_buffer, reset_stats=reset_stats, gym_reset_kwargs=gym_reset_kwargs, @@ -655,69 +728,27 @@ def reset( self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) self._current_action_in_all_envs_EA = np.empty(self.env_num) self._current_policy_in_all_envs_E = None + return result - def collect( + @override + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + # we need to step through the envs and wait until they are ready to be able to interact with them + if self.env.waiting_id: + self.env.step(None, id=self.env.waiting_id) + return super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + + @override + def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, - reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of steps or episodes with async env setting. - - This function does not collect an exact number of transitions specified by n_step or - n_episode. Instead, to support the asynchronous setting, it may collect more transitions - than requested by n_step or n_episode and save them into the buffer. - - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy_R for collecting data. Default - to False. - :param render: the sleep time between rendering consecutive frames. - Default to None (no rendering). - :param no_grad: whether to retain gradient in policy_R.forward(). Default to - True (no gradient retaining). - :param reset_before_collect: whether to reset the environment before - collecting data. It has only an effect if n_episode is not None, i.e. - if one wants to collect a fixed number of episodes. - (The collector needs the initial obs and info to function properly.) - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) - - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. - - :return: A dataclass object - """ - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} - - # collect at least n_step or n_episode - if n_step is not None: - assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." 
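# Sketch: AsyncCollector is intended for asynchronous vector envs, i.e. envs constructed
# with wait_num and/or timeout so that only a subset of workers is ready at any time.
# `make_env` and `policy` are placeholders, and the exact env construction is an assumption.
from tianshou.data import AsyncCollector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv

envs = SubprocVectorEnv([make_env for _ in range(8)], wait_num=4, timeout=0.2)
buffer = VectorReplayBuffer(total_size=20_000, buffer_num=len(envs))
collector = AsyncCollector(policy, envs, buffer)
# In the async setting, somewhat more than n_step transitions may end up in the buffer.
stats = collector.collect(n_step=80, reset_before_collect=True)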
-            )
-            assert n_step > 0
-        elif n_episode is not None:
-            assert n_episode > 0
-        else:
-            raise TypeError(
-                "Please specify at least one (either n_step or n_episode) "
-                "in AsyncCollector.collect().",
-            )
-
-        if reset_before_collect:
-            # first we need to step all envs to be able to interact with them
-            if self.env.waiting_id:
-                self.env.step(None, id=self.env.waiting_id)
-            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
-
         start_time = time.time()
 
         step_count = 0
@@ -775,7 +806,6 @@ def collect(
             ) = self._compute_action_policy_hidden(
                 random=random,
                 ready_env_ids_R=ready_env_ids_R,
-                use_grad=use_grad,
                 last_obs_RO=last_obs_RO,
                 last_info_R=last_info_R,
                 last_hidden_state_RH=last_hidden_state_RH,
@@ -847,12 +877,12 @@ def collect(
             num_collected_episodes += num_episodes_done_this_iter
 
             # preparing for the next iteration
-            # todo do we need the copy stuff (tests pass also without)
            # todo seem we can get rid of this last_sth stuff altogether
             last_obs_RO = copy(obs_next_RO)
             last_info_R = copy(info_R)
-            last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R])  # type: ignore[index]
-
+            last_hidden_state_RH = copy(
+                self._current_hidden_state_in_all_envs_EH[ready_env_ids_R],  # type: ignore[index]
+            )
             if num_episodes_done_this_iter:
                 env_ind_local_D = np.where(done_R)[0]
                 env_ind_global_D = ready_env_ids_R[env_ind_local_D]
@@ -862,6 +892,7 @@ def collect(
 
                 # now we copy obs_next_RO to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
+                gym_reset_kwargs = gym_reset_kwargs or {}
                 obs_reset_DO, info_reset_D = self.env.reset(
                     env_id=env_ind_global_D,
                     **gym_reset_kwargs,
diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py
index b1ce8362e..b77318602 100644
--- a/tianshou/data/stats.py
+++ b/tianshou/data/stats.py
@@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin):
     """The statistics of the last call to the training collector."""
     test_collect_stat: Optional["CollectStats"]
     """The statistics of the last call to the test collector."""
-    training_stat: "TrainingStats"
-    """The statistics of the last model update step."""
+    training_stat: Optional["TrainingStats"]
+    """The statistics of the last model update step.
+    Can be None if no model update is performed, typically in the last training iteration."""
     info_stat: InfoStats
     """The information of the collector."""
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index 8f1758e62..ac35ccf3f 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -6,7 +6,6 @@
 import numpy as np
 
 from tianshou.env.utils import gym_new_venv_step_type
-from tianshou.utils import deprecation
 
 
 class EnvWorker(ABC):
@@ -27,6 +26,7 @@ def get_env_attr(self, key: str) -> Any:
     def set_env_attr(self, key: str, value: Any) -> None:
         pass
 
+    @abstractmethod
     def send(self, action: np.ndarray | None) -> None:
         """Send action signal to low-level worker.
 
@@ -34,17 +34,6 @@ def send(self, action: np.ndarray | None) -> None:
         it indicates "step" signal. The paired return value from "recv"
         function is determined by such kind of different signal.
         """
-        if hasattr(self, "send_action"):
-            deprecation(
-                "send_action will soon be deprecated. "
-                "Please use send and recv for your own EnvWorker.",
-            )
-        if action is None:
-            self.is_reset = True
-            self.result = self.reset()
-        else:
-            self.is_reset = False
-            self.send_action(action)
 
     def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]:
         """Receive result from low-level worker.
@@ -54,13 +43,6 @@ def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]:
         info) or (obs, rew, terminated, truncated, info), based on whether
         the environment is using the old step API or the new one.
         """
-        if hasattr(self, "get_result"):
-            deprecation(
-                "get_result will soon be deprecated. "
-                "Please use send and recv for your own EnvWorker.",
-            )
-        if not self.is_reset:
-            self.result = self.get_result()
         return self.result
 
     @abstractmethod
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index f71a7f981..c1313262e 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -6,6 +6,7 @@
 import gymnasium
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.data.collector import BaseCollector
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import (
@@ -94,7 +95,7 @@ def create_train_test_collector(
         policy: BasePolicy,
         envs: Environments,
         reset_collectors: bool = True,
-    ) -> tuple[Collector, Collector]:
+    ) -> tuple[BaseCollector, BaseCollector]:
         """:param policy:
         :param envs:
         :param reset_collectors: Whether to reset the collectors before returning them.
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 8fc21cfad..99aadc23f 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -345,7 +345,6 @@ def _watch_agent(
         env: BaseVectorEnv,
         render: float,
     ) -> None:
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
         result = collector.collect(n_episode=num_episodes, render=render)
diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py
index 1a8d64872..c32ef9cbc 100644
--- a/tianshou/highlevel/world.py
+++ b/tianshou/highlevel/world.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Optional
 
 if TYPE_CHECKING:
-    from tianshou.data import Collector
+    from tianshou.data import BaseCollector
     from tianshou.highlevel.env import Environments
     from tianshou.highlevel.logger import TLogger
     from tianshou.policy import BasePolicy
@@ -16,8 +16,8 @@ class World:
     envs: "Environments"
     policy: "BasePolicy"
-    train_collector: "Collector"
-    test_collector: "Collector"
+    train_collector: "BaseCollector"
+    test_collector: "BaseCollector"
     logger: "TLogger"
     persist_directory: str
     restore_directory: str | None
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index e853efd26..b7ae5f23d 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -25,6 +25,7 @@
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode
 
 logger = logging.getLogger(__name__)
 
@@ -213,20 +214,41 @@ def __init__(
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
-        self._action_type: Literal["discrete", "continuous"]
         if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
-            self._action_type = "discrete"
+            action_type = "discrete"
         elif isinstance(action_space, Box):
-            self._action_type = "continuous"
+            action_type = "continuous"
         else:
             raise ValueError(f"Unsupported action space: {action_space}.")
+        self._action_type = cast(Literal["discrete", "continuous"], action_type)
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
+        self.is_within_training_step = False
+        """
+        Flag indicating whether we are currently within a training step,
+        which encompasses data collection for training (in online RL algorithms)
+        and the policy update (gradient steps).
+
+        It can be used, for example, to determine whether a flag for deterministic evaluation should
+        indeed be applied, because within a training step we typically always want to apply stochastic evaluation
+        (even if such a flag is enabled), as well as stochastic action computation for Q-targets (e.g. in SAC-based
+        algorithms).
+
+        This flag should normally remain False and should be set to True only by the algorithm which performs
+        training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer,
+        the user should ensure that this flag is set correctly before calling update or learn.
+        """
         self._compile()
 
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        # TODO Use setstate function once merged
+        if "is_within_training_step" not in state:
+            state["is_within_training_step"] = False
+        self.__dict__ = state
+
     @property
     def action_type(self) -> Literal["discrete", "continuous"]:
         return self._action_type
@@ -505,13 +527,22 @@ def update(
         """
         # TODO: when does this happen?
         # -> this happens never in practice as update is either called with a collector buffer or an assert before
+
+        if not self.is_within_training_step:
+            raise RuntimeError(
+                f"update() was called outside of a training step as signalled by {self.is_within_training_step=}. "
+                f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
+                f"flag yourself. You can do this, e.g., by using the context manager {policy_within_training_step.__name__}.",
+            )
+
         if buffer is None:
             return TrainingStats()  # type: ignore[return-value]
         start_time = time.time()
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        training_stat = self.learn(batch, **kwargs)
+        with torch_train_mode(self):
+            training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index e9f9b3b4a..d1ce28da9 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -107,10 +107,11 @@ def forward(  # type: ignore
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 9a148feb7..80bcff672 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -197,10 +197,11 @@ def forward(
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)
 
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py
index f9793f4db..25f299733 100644
--- a/tianshou/policy/modelfree/redq.py
+++ b/tianshou/policy/modelfree/redq.py
@@ -153,7 +153,7 @@ def forward(  # type: ignore
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 3b3975473..a5a05c0fd 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -175,7 +175,7 @@ def forward(  # type: ignore
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 675112fae..242f2b028 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -10,14 +10,13 @@
 from tianshou.data import (
     AsyncCollector,
-    Collector,
     CollectStats,
     EpochStats,
     InfoStats,
     ReplayBuffer,
     SequenceSummaryStats,
 )
-from tianshou.data.collector import CollectStatsBase
+from tianshou.data.collector import BaseCollector, CollectStatsBase
 from tianshou.policy import BasePolicy
 from tianshou.policy.base import TrainingStats
 from tianshou.trainer.utils import gather_info, test_episode
@@ -26,10 +25,10 @@
     DummyTqdm,
     LazyLogger,
     MovAvg,
-    deprecation,
     tqdm_config,
 )
 from tianshou.utils.logging import set_numerical_fields_to_precision
+from tianshou.utils.torch_utils import policy_within_training_step
 
 log = logging.getLogger(__name__)
 
@@ -76,7 +75,7 @@ class BaseTrainer(ABC):
         signature ``f(num_epoch: int, step_idx: int) -> None``.
     :param save_best_fn: a hook called when the undiscounted average mean reward in
         evaluation phase gets better, with the signature
-        ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
+        ``f(policy: BasePolicy) -> None``.
     :param save_checkpoint_fn: a function to save training process and return the saved
         checkpoint path, with the signature ``f(epoch: int, env_step: int,
         gradient_step: int) -> str``; you can save whatever you want.
@@ -153,8 +152,8 @@ def __init__(
         policy: BasePolicy,
         max_epoch: int,
         batch_size: int | None,
-        train_collector: Collector | None = None,
-        test_collector: Collector | None = None,
+        train_collector: BaseCollector | None = None,
+        test_collector: BaseCollector | None = None,
         buffer: ReplayBuffer | None = None,
         step_per_epoch: int | None = None,
         repeat_per_collect: int | None = None,
@@ -173,16 +172,7 @@ def __init__(
         verbose: bool = True,
         show_progress: bool = True,
         test_in_train: bool = True,
-        save_fn: Callable[[BasePolicy], None] | None = None,
     ):
-        if save_fn:
-            deprecation(
-                "save_fn in trainer is marked as deprecated and will be "
-                "removed in the future. Please use save_best_fn instead.",
-            )
-            assert save_best_fn is None
-            save_best_fn = save_fn
-
         self.policy = policy
 
         if buffer is not None:
@@ -269,7 +259,6 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No
         assert self.episode_per_test is not None
         assert not isinstance(self.test_collector, AsyncCollector)  # Issue 700
         test_result = test_episode(
-            self.policy,
             self.test_collector,
             self.test_fn,
             self.start_epoch,
@@ -309,17 +298,15 @@ def __next__(self) -> EpochStats:
         if self.stop_fn_flag:
             raise StopIteration
 
-        # set policy in train mode
-        self.policy.train()
-
         progress = tqdm.tqdm if self.show_progress else DummyTqdm
 
         # perform n step_per_epoch
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-                if self.train_collector is not None:
-                    train_stat, self.stop_fn_flag = self.train_step()
+                train_stat, update_stat, self.stop_fn_flag = self.training_step()
+
+                if isinstance(train_stat, CollectStats):
                     pbar_data_dict = {
                         "env_step": str(self.env_step),
                         "rew": f"{self.last_rew:.2f}",
                         "len": str(int(self.last_len)),
                         "n/ep": str(train_stat.n_collected_episodes),
                         "n/st": str(train_stat.n_collected_steps),
                     }
                     t.update(train_stat.n_collected_steps)
-                    if self.stop_fn_flag:
-                        t.set_postfix(**pbar_data_dict)
-                        break
                 else:
                     pbar_data_dict = {}
-                    assert self.buffer, "No train_collector or buffer specified"
-                    train_stat = CollectStatsBase(
-                        n_collected_episodes=len(self.buffer),
-                    )
                     t.update()
 
-                update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
                 pbar_data_dict["gradient_step"] = str(self._gradient_step)
-                t.set_postfix(**pbar_data_dict)
+                if self.stop_fn_flag:
+                    break
+
             if t.n <= t.total and not self.stop_fn_flag:
                 t.update()
@@ -379,7 +360,7 @@ def __next__(self) -> EpochStats:
             self.logger.log_info_data(asdict(info_stat), self.epoch)
 
         # in case trainer is used with run(), epoch_stat will not be returned
-        epoch_stat: EpochStats = EpochStats(
+        return EpochStats(
             epoch=self.epoch,
             train_collect_stat=train_stat,
             test_collect_stat=test_stat,
@@ -387,15 +368,12 @@ def __next__(self) -> EpochStats:
             info_stat=info_stat,
         )
 
-        return epoch_stat
-
     def test_step(self) -> tuple[CollectStats, bool]:
         """Perform one testing step."""
         assert self.episode_per_test is not None
         assert self.test_collector is not None
         stop_fn_flag = False
         test_stat = test_episode(
-            self.policy,
             self.test_collector,
             self.test_fn,
             self.epoch,
@@ -426,64 +404,110 @@ def test_step(self) -> tuple[CollectStats, bool]:
 
         return test_stat, stop_fn_flag
 
-    def train_step(self) -> tuple[CollectStats, bool]:
-        """Perform one training step.
+    def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
+        """Perform one training iteration.
 
-        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
-        Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
-        on it.
-        Finally, if the latter is also True, will set should_stop_training to True.
+        A training iteration includes collecting data (for online RL), determining whether to stop training,
+        and performing a policy update if the training iteration should continue.
+
+        :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
+            If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
+ """ + with policy_within_training_step(self.policy): + should_stop_training = False + + collect_stats: CollectStatsBase | CollectStats + if self.train_collector is not None: + collect_stats = self._collect_training_data() + should_stop_training = self._update_best_reward_and_return_should_stop_training( + collect_stats, + ) + else: + assert self.buffer is not None, "Either train_collector or buffer must be provided." + collect_stats = CollectStatsBase( + n_collected_episodes=len(self.buffer), + ) + + if not should_stop_training: + training_stats = self.policy_update_fn(collect_stats) + else: + training_stats = None + + return collect_stats, training_stats, should_stop_training - :return: A tuple of the training stats and a boolean indicating whether to stop training. + def _collect_training_data(self) -> CollectStats: + """Performs training data collection. + + :return: the data collection stats """ assert self.episode_per_test is not None assert self.train_collector is not None - should_stop_training = False if self.train_fn: self.train_fn(self.epoch, self.env_step) - result = self.train_collector.collect( + collect_stats = self.train_collector.collect( n_step=self.step_per_collect, n_episode=self.episode_per_collect, ) - self.env_step += result.n_collected_steps + self.env_step += collect_stats.n_collected_steps - if result.n_collected_episodes > 0: - assert result.returns_stat is not None # for mypy - assert result.lens_stat is not None # for mypy - self.last_rew = result.returns_stat.mean - self.last_len = result.lens_stat.mean + if collect_stats.n_collected_episodes > 0: + assert collect_stats.returns_stat is not None # for mypy + assert collect_stats.lens_stat is not None # for mypy + self.last_rew = collect_stats.returns_stat.mean + self.last_len = collect_stats.lens_stat.mean if self.reward_metric: # TODO: move inside collector - rew = self.reward_metric(result.returns) - result.returns = rew - result.returns_stat = SequenceSummaryStats.from_sequence(rew) + rew = self.reward_metric(collect_stats.returns) + collect_stats.returns = rew + collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) - self.logger.log_train_data(asdict(result), self.env_step) + self.logger.log_train_data(asdict(collect_stats), self.env_step) - if ( - result.n_collected_episodes > 0 - and self.test_in_train - and self.stop_fn - and self.stop_fn(result.returns_stat.mean) # type: ignore - ): - assert self.test_collector is not None - test_result = test_episode( - self.policy, - self.test_collector, - self.test_fn, - self.epoch, - self.episode_per_test, - self.logger, - self.env_step, - ) - assert test_result.returns_stat is not None # for mypy - if self.stop_fn(test_result.returns_stat.mean): - should_stop_training = True - self.best_reward = test_result.returns_stat.mean - self.best_reward_std = test_result.returns_stat.std - else: - self.policy.train() - return result, should_stop_training + return collect_stats + + # TODO (maybe): separate out side effect, simplify name? + def _update_best_reward_and_return_should_stop_training( + self, + collect_stats: CollectStats, + ) -> bool: + """If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data. + Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return + on it. + Finally, if the latter is also True, will return True. + + **NOTE:** has a side effect of updating the best reward and corresponding std. 
+
+
+        :param collect_stats: the data collection stats
+        :return: flag indicating whether to stop training
+        """
+        should_stop_training = False
+
+        # Because we need to evaluate the policy, we have to temporarily leave the "is_within_training_step" semantics
+        with policy_within_training_step(self.policy, enabled=False):
+            if (
+                collect_stats.n_collected_episodes > 0
+                and self.test_in_train
+                and self.stop_fn
+                and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
+            ):
+                assert self.test_collector is not None
+                assert self.episode_per_test is not None and self.episode_per_test > 0
+                test_result = test_episode(
+                    self.test_collector,
+                    self.test_fn,
+                    self.epoch,
+                    self.episode_per_test,
+                    self.logger,
+                    self.env_step,
+                )
+                assert test_result.returns_stat is not None  # for mypy
+                if self.stop_fn(test_result.returns_stat.mean):
+                    should_stop_training = True
+                    self.best_reward = test_result.returns_stat.mean
+                    self.best_reward_std = test_result.returns_stat.std
+
+        return should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 7a96ea06f..de730cee2 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -5,19 +5,17 @@
 import numpy as np
 
 from tianshou.data import (
-    Collector,
     CollectStats,
     InfoStats,
     SequenceSummaryStats,
     TimingStats,
 )
-from tianshou.policy import BasePolicy
+from tianshou.data.collector import BaseCollector
 from tianshou.utils import BaseLogger
 
 
 def test_episode(
-    policy: BasePolicy,
-    collector: Collector,
+    collector: BaseCollector,
     test_fn: Callable[[int, int | None], None] | None,
     epoch: int,
     n_episode: int,
@@ -27,7 +25,6 @@
 ) -> CollectStats:
     """A simple wrapper of testing policy in collector."""
     collector.reset(reset_stats=False)
-    policy.eval()
     if test_fn:
         test_fn(epoch, global_step)
     result = collector.collect(n_episode=n_episode)
@@ -47,28 +44,14 @@ def gather_info(
     gradient_step: int,
     best_reward: float,
     best_reward_std: float,
-    train_collector: Collector | None = None,
-    test_collector: Collector | None = None,
+    train_collector: BaseCollector | None = None,
+    test_collector: BaseCollector | None = None,
 ) -> InfoStats:
     """A simple wrapper of gathering information from collectors.
 
-    :return: A dataclass object with the following members (depending on available collectors):
-
-        * ``gradient_step`` the total number of gradient steps;
-        * ``best_reward`` the best reward over the test results;
-        * ``best_reward_std`` the standard deviation of best reward over the test results;
-        * ``train_step`` the total collected step of training collector;
-        * ``train_episode`` the total collected episode of training collector;
-        * ``test_step`` the total collected step of test collector;
-        * ``test_episode`` the total collected episode of test collector;
-        * ``timing`` the timing statistics, with the following members:
-            * ``total_time`` the total time elapsed;
-            * ``train_time`` the total time elapsed for learning training (collecting samples plus model update);
-            * ``train_time_collect`` the time for collecting transitions in the \
-                training collector;
-            * ``train_time_update`` the time for training models;
-            * ``test_time`` the time for testing;
-            * ``update_speed`` the speed of updating (env_step per second).
+    :return: InfoStats object with times computed based on the `start_time` and
+        episode/step counts read off the collectors. No computation of
+        expensive statistics is done here.
     """
     duration = max(0.0, time.time() - start_time)
     test_time = 0.0
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index 66a7a8db8..47a3c4497 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -12,11 +12,11 @@
     "MovAvg",
     "RunningMeanStd",
     "tqdm_config",
+    "deprecation",
     "DummyTqdm",
     "BaseLogger",
     "TensorboardLogger",
     "LazyLogger",
     "WandbLogger",
-    "deprecation",
     "MultipleLRSchedulers",
 ]
diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py
new file mode 100644
index 000000000..430d174e7
--- /dev/null
+++ b/tianshou/utils/torch_utils.py
@@ -0,0 +1,39 @@
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+from torch import nn
+
+if TYPE_CHECKING:
+    from tianshou.policy import BasePolicy
+
+
+@contextmanager
+def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
+    original_mode = module.training
+    try:
+        module.train(enabled)
+        yield
+    finally:
+        module.train(original_mode)
+
+
+@contextmanager
+def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `policy.is_within_training_step=enabled`.
+
+    Enabling this ensures that the policy is able to adapt its behavior,
+    allowing it to differentiate between training and inference/evaluation,
+    e.g., to sample actions instead of using the most probable action (where applicable).
+    Note that for rollout, which also happens within a training step, one would usually want
+    the wrapped torch module to be in evaluation mode, which can be achieved using
+    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both
+    within training step and in torch train mode.
+    """
+    original_mode = policy.is_within_training_step
+    try:
+        policy.is_within_training_step = enabled
+        yield
+    finally:
+        policy.is_within_training_step = original_mode
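For reference, a minimal usage sketch of the two context managers added in torch_utils.py above. The names policy, buffer, and obs_batch are placeholders for an already constructed BasePolicy instance, a populated ReplayBuffer, and a batch of observations; any algorithm-specific learn() keyword arguments are omitted.

from tianshou.data import Batch
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

# Updating a policy without a Trainer: update() now raises unless the policy is
# marked as being within a training step; the wrapped torch module is switched
# to train mode inside update() via torch_train_mode.
with policy_within_training_step(policy):
    train_stats = policy.update(sample_size=0, buffer=buffer)  # 0 samples the whole buffer

# Computing actions for a rollout that happens within a training step: keep the
# training-step semantics (stochastic action computation where applicable), but
# put the wrapped torch module into eval mode so that layers such as batch norm
# behave as in inference.
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    act = policy(Batch(obs=obs_batch)).act

Keeping the training-step flag separate from the torch train/eval mode is what allows stochastic action computation during training-time rollouts while the underlying module stays in evaluation mode; the Trainer classes manage both automatically.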