From e499bed8b031592c4da5c4f897ab1ffeec74c03a Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Wed, 24 Apr 2024 17:06:42 +0200 Subject: [PATCH 01/32] add is_eval attribute to policy and set this attribute as well as train mode in appropriate places --- tianshou/data/collector.py | 24 +++++++++++++++++++++++ tianshou/highlevel/experiment.py | 3 +-- tianshou/policy/base.py | 2 ++ tianshou/policy/modelfree/a2c.py | 3 +++ tianshou/policy/modelfree/bdq.py | 3 +++ tianshou/policy/modelfree/c51.py | 2 ++ tianshou/policy/modelfree/discrete_sac.py | 8 ++++---- tianshou/policy/modelfree/dqn.py | 2 ++ tianshou/policy/modelfree/fqf.py | 2 ++ tianshou/policy/modelfree/iqn.py | 2 ++ tianshou/policy/modelfree/pg.py | 8 ++++---- tianshou/policy/modelfree/ppo.py | 3 +++ tianshou/policy/modelfree/qrdqn.py | 2 ++ tianshou/policy/modelfree/redq.py | 2 +- tianshou/policy/modelfree/sac.py | 5 ++++- tianshou/trainer/base.py | 8 -------- tianshou/trainer/utils.py | 5 +---- 17 files changed, 60 insertions(+), 24 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 345d50b03..1a60bd083 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -318,6 +318,7 @@ def collect( no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, + is_eval: bool = False, ) -> CollectStats: """Collect a specified number of steps or episodes. @@ -334,6 +335,7 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. + :param is_eval: whether to collect data in evaluation mode. .. note:: @@ -356,6 +358,13 @@ def collect( # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. + # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy + # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on + # policy.deterministic_eval) + self.policy.eval() + pre_collect_is_eval = self.policy.is_eval + self.policy.is_eval = is_eval + use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} @@ -568,6 +577,9 @@ def collect( # reset envs and the _pre_collect fields self.reset_env(gym_reset_kwargs) # todo still necessary? + # set the policy back to pre collect mode + self.policy.is_eval = pre_collect_is_eval + return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), @@ -665,6 +677,7 @@ def collect( no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, + is_eval: bool = False, ) -> CollectStats: """Collect a specified number of steps or episodes with async env setting. @@ -686,6 +699,7 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) + :param is_eval: whether to collect data in evaluation mode. .. 
note:: @@ -694,6 +708,13 @@ def collect( :return: A dataclass object """ + # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy + # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on + # policy.deterministic_eval) + self.policy.eval() + pre_collect_is_eval = self.policy.is_eval + self.policy.is_eval = is_eval + use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} @@ -902,6 +923,9 @@ def collect( # persist for future collect iterations self._ready_env_ids_R = ready_env_ids_R + # set the policy back to pre collect mode + self.policy.is_eval = pre_collect_is_eval + return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 6f9eb7c00..71b8159a0 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -335,10 +335,9 @@ def _watch_agent( env: BaseVectorEnv, render: float, ) -> None: - policy.eval() collector = Collector(policy, env) collector.reset() - result = collector.collect(n_episode=num_episodes, render=render) + result = collector.collect(n_episode=num_episodes, render=render, is_eval=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info( diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 77602a02b..450bf5228 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -225,6 +225,8 @@ def __init__( self.action_scaling = action_scaling self.action_bound_method = action_bound_method self.lr_scheduler = lr_scheduler + # whether the policy is in evaluation mode + self.is_eval = False # TODO: remove in favor of kwarg in compute_action/forward? 
self._compile() @property diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index d41ccb463..f055c4122 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -165,6 +165,9 @@ def learn( # type: ignore *args: Any, **kwargs: Any, ) -> TA2CTrainingStats: + # set policy in train mode + self.train() + losses, actor_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index d7196a92b..80c17bef7 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -163,6 +163,9 @@ def forward( return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: + # set policy in train mode + self.train() + if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 5bfdba0c1..d406dda56 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -117,6 +117,8 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: return target_dist.sum(-1) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: + # set policy in train mode + self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index e9f9b3b4a..d236e442e 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -107,10 +107,7 @@ def forward( # type: ignore ) -> Batch: logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Categorical(logits=logits_BA) - if self.deterministic_eval and not self.training: - act_B = dist.mode - else: - act_B = dist.sample() + act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample() return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: @@ -127,6 +124,9 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q.sum(dim=-1) + self.alpha * dist.entropy() def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore + # set policy in train mode + self.train() + weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index e0ada0733..d2e7910bd 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -210,6 +210,8 @@ def forward( return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: + # set policy in train mode + self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9c87f9cac..c4a1a2d01 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -153,6 +153,8 @@ def forward( # type: ignore return cast(FQFBatchProtocol, result) def learn(self, batch: 
RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: + # set policy in train mode + self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() weight = batch.pop("weight", 1.0) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 75d76a2dd..f868ce41e 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -131,6 +131,8 @@ def forward( return cast(QuantileRegressionBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: + # set policy in train mode + self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 9a148feb7..b86540da4 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -197,10 +197,7 @@ def forward( # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked dist = self.dist_fn(action_dist_input_BD) - if self.deterministic_eval and not self.training: - act_B = dist.mode - else: - act_B = dist.sample() + act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample() # act is of dimension BA in continuous case and of dimension B in discrete result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) return cast(DistBatchProtocol, result) @@ -214,6 +211,9 @@ def learn( # type: ignore *args: Any, **kwargs: Any, ) -> TPGTrainingStats: + # set policy in train mode + self.train() + losses = [] split_batch_size = batch_size or -1 for _ in range(repeat): diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 196cd72e4..298711475 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -151,6 +151,9 @@ def learn( # type: ignore *args: Any, **kwargs: Any, ) -> TPPOTrainingStats: + # set policy in train mode + self.train() + losses, clip_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for step in range(repeat): diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 71c36de0c..9f3e1626c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -105,6 +105,8 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc return super().compute_q_value(logits.mean(2), mask) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats: + # set policy in train mode + self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index f9793f4db..a216cf9fe 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -153,7 +153,7 @@ def forward( # type: ignore ) -> Batch: (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Independent(Normal(loc_B, scale_B), 1) - if self.deterministic_eval and not self.training: + if self.deterministic_eval and self.is_eval: act_B = dist.mode else: act_B = dist.rsample() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 3b3975473..f9ddf7cbc 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -175,7 +175,7 @@ def forward( # type: ignore ) -> DistLogProbBatchProtocol: (loc_B, scale_B), hidden_BH = 
self.actor(batch.obs, state=state, info=batch.info) dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) - if self.deterministic_eval and not self.training: + if self.deterministic_eval and self.is_eval: act_B = dist.mode else: act_B = dist.rsample() @@ -213,6 +213,9 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: ) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore + # set policy in train mode + self.train() + # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 675112fae..fa67d3bfb 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -269,7 +269,6 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No assert self.episode_per_test is not None assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 test_result = test_episode( - self.policy, self.test_collector, self.test_fn, self.start_epoch, @@ -309,9 +308,6 @@ def __next__(self) -> EpochStats: if self.stop_fn_flag: raise StopIteration - # set policy in train mode - self.policy.train() - progress = tqdm.tqdm if self.show_progress else DummyTqdm # perform n step_per_epoch @@ -395,7 +391,6 @@ def test_step(self) -> tuple[CollectStats, bool]: assert self.test_collector is not None stop_fn_flag = False test_stat = test_episode( - self.policy, self.test_collector, self.test_fn, self.epoch, @@ -468,7 +463,6 @@ def train_step(self) -> tuple[CollectStats, bool]: ): assert self.test_collector is not None test_result = test_episode( - self.policy, self.test_collector, self.test_fn, self.epoch, @@ -481,8 +475,6 @@ def train_step(self) -> tuple[CollectStats, bool]: should_stop_training = True self.best_reward = test_result.returns_stat.mean self.best_reward_std = test_result.returns_stat.std - else: - self.policy.train() return result, should_stop_training # TODO: move moving average computation and logging into its own logger diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 7a96ea06f..4d990db4b 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -11,12 +11,10 @@ SequenceSummaryStats, TimingStats, ) -from tianshou.policy import BasePolicy from tianshou.utils import BaseLogger def test_episode( - policy: BasePolicy, collector: Collector, test_fn: Callable[[int, int | None], None] | None, epoch: int, @@ -27,10 +25,9 @@ def test_episode( ) -> CollectStats: """A simple wrapper of testing policy in collector.""" collector.reset(reset_stats=False) - policy.eval() if test_fn: test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode) + result = collector.collect(n_episode=n_episode, is_eval=True) if reward_metric: # TODO: move into collector rew = reward_metric(result.returns) result.returns = rew From 8cb17de190e03631a3c8375b140364bdcd9255f7 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Wed, 24 Apr 2024 17:06:54 +0200 Subject: [PATCH 02/32] update examples --- examples/atari/atari_c51.py | 9 ++++++--- examples/atari/atari_dqn.py | 9 ++++++--- examples/atari/atari_fqf.py | 9 ++++++--- examples/atari/atari_iqn.py | 9 ++++++--- examples/atari/atari_ppo.py | 9 ++++++--- examples/atari/atari_qrdqn.py | 9 ++++++--- examples/atari/atari_rainbow.py | 9 ++++++--- examples/atari/atari_sac.py | 9 ++++++--- examples/box2d/acrobot_dualdqn.py | 7 
+++++-- examples/box2d/bipedal_bdq.py | 7 +++++-- examples/box2d/bipedal_hardcore_sac.py | 7 +++++-- examples/box2d/lunarlander_dqn.py | 7 +++++-- examples/box2d/mcc_sac.py | 7 +++++-- examples/discrete/discrete_dqn.py | 1 - examples/inverse/irl_gail.py | 7 +++++-- examples/mujoco/fetch_her_ddpg.py | 7 +++++-- examples/mujoco/mujoco_a2c.py | 7 +++++-- examples/mujoco/mujoco_ddpg.py | 7 +++++-- examples/mujoco/mujoco_npg.py | 7 +++++-- examples/mujoco/mujoco_ppo.py | 7 +++++-- examples/mujoco/mujoco_redq.py | 7 +++++-- examples/mujoco/mujoco_reinforce.py | 7 +++++-- examples/mujoco/mujoco_sac.py | 7 +++++-- examples/mujoco/mujoco_td3.py | 7 +++++-- examples/mujoco/mujoco_trpo.py | 7 +++++-- examples/offline/atari_bcq.py | 3 +-- examples/offline/atari_cql.py | 3 +-- examples/offline/atari_crr.py | 3 +-- examples/offline/atari_il.py | 3 +-- examples/offline/d4rl_bcq.py | 10 ++++++---- examples/offline/d4rl_cql.py | 10 ++++++---- examples/offline/d4rl_il.py | 10 ++++++---- examples/offline/d4rl_td3_bc.py | 10 ++++++---- examples/vizdoom/vizdoom_c51.py | 9 ++++++--- examples/vizdoom/vizdoom_ppo.py | 9 ++++++--- 35 files changed, 168 insertions(+), 87 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index fc04a219a..16694af3d 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -162,7 +162,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -175,14 +174,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index f237c5a33..aed46a2a3 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -204,7 +204,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -217,14 +216,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 127a14b24..92a614007 100644 --- 
a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -175,7 +175,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -188,14 +187,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 8b1625275..3c4a695bc 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -172,7 +172,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -185,14 +184,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index f1a89ef40..3c2c68d9b 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -229,7 +229,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -241,14 +240,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index dfb96419e..c821fad40 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -166,7 +166,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: 
# watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -179,14 +178,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 7b341c0a1..0d14f5d9f 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -200,7 +200,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -215,14 +214,18 @@ def watch() -> None: beta=args.beta, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index f06964c28..a8d759ad8 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -216,7 +216,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -228,14 +227,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 96e61b612..ed6a65ba4 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -144,11 +144,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
- policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 8b1e8ca8d..da425e623 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -162,11 +162,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 2c071bc1c..8dfbed69f 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -207,10 +207,13 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 9e5db5833..4bd289644 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -141,11 +141,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 5c093093e..257641187 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -153,10 +153,13 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 4f1a82b12..3ba22a40c 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -80,7 +80,6 @@ def stop_fn(mean_rewards: float) -> bool: print(f"Finished training in {result.timing.total_time} seconds") # watch performance - policy.eval() policy.set_eps(eps_test) collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 2d013a01b..3c1035927 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -264,10 +264,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index cd6ceec89..01a89b610 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -238,10 +238,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index ea6ab8f24..0195702a1 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -219,10 +219,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index b2a40878b..98852dc6b 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -168,10 +168,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 8a379da92..c2a193a22 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -216,10 +216,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 00042884f..9e20c467c 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -224,10 +224,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index ae46b220c..791005c73 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -196,10 +196,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 109f1cc46..8912ee64a 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -196,10 +196,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index a0bd567ff..6c0bbdc95 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -190,10 +190,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 6b6dfdc8c..8e0c31576 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -188,10 +188,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 219593343..96da0cfe6 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -221,10 +221,13 @@ def save_best_fn(policy: BasePolicy) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 1fc0dc7e3..80687d163 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -187,12 +187,11 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 40d91c1bb..61071453f 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -171,12 +171,11 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index a4b31c4fb..ad5c2312f 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -188,11 +188,10 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index bb7822ea9..495b630b3 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -145,11 +145,10 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 80b233cb7..983424730 100644 --- 
a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -206,9 +206,8 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -229,10 +228,13 @@ def watch() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 7ca8ae2fb..4e6a127aa 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -344,9 +344,8 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -367,10 +366,13 @@ def watch() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index c2152a711..a8a6a66bb 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -142,9 +142,8 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -165,10 +164,13 @@ def watch() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 4d6159ff5..24138a4b8 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -191,9 +191,8 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -217,10 +216,13 @@ def watch() -> None: watch() # Let's watch its performance! 
- policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 4211585af..356605151 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -168,7 +168,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -181,14 +180,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index f5abf0b6f..907b0ebfe 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -231,7 +231,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -243,14 +242,18 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: From 49c750fb0920b34221a249205b3f39c5a5a2a81b Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Wed, 24 Apr 2024 17:06:59 +0200 Subject: [PATCH 03/32] update tests --- test/base/test_policy.py | 1 + test/continuous/test_ddpg.py | 3 +- test/continuous/test_npg.py | 3 +- test/continuous/test_ppo.py | 3 +- test/continuous/test_redq.py | 3 +- test/continuous/test_sac_with_il.py | 1 - test/continuous/test_td3.py | 3 +- test/continuous/test_trpo.py | 3 +- test/discrete/test_a2c_with_il.py | 73 +++++++++++++++--------- test/discrete/test_bdq.py | 7 ++- test/discrete/test_c51.py | 8 +-- test/discrete/test_dqn.py | 8 +-- test/discrete/test_drqn.py | 8 +-- test/discrete/test_fqf.py | 8 +-- test/discrete/test_iqn.py | 8 +-- test/discrete/test_pg.py | 8 +-- test/discrete/test_ppo.py | 8 +-- test/discrete/test_qrdqn.py | 10 ++-- test/discrete/test_rainbow.py | 8 +-- test/discrete/test_sac.py | 8 +-- test/highlevel/env_factory.py | 2 +- test/modelbased/test_dqn_icm.py | 
7 +-- test/modelbased/test_ppo_icm.py | 7 +-- test/modelbased/test_psrl.py | 3 +- test/offline/gather_cartpole_data.py | 8 +-- test/offline/test_bcq.py | 6 +- test/offline/test_cql.py | 3 +- test/offline/test_discrete_bcq.py | 7 +-- test/offline/test_discrete_cql.py | 7 +-- test/offline/test_discrete_crr.py | 7 +-- test/offline/test_gail.py | 3 +- test/offline/test_td3_bc.py | 3 +- test/pettingzoo/pistonball.py | 3 +- test/pettingzoo/pistonball_continuous.py | 3 +- test/pettingzoo/tic_tac_toe.py | 3 +- 35 files changed, 128 insertions(+), 126 deletions(-) diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 7c3aacc07..f286156ed 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -64,6 +64,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: class TestPolicyBasics: def test_get_action(self, policy: PPOPolicy) -> None: + policy.is_eval = True sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index a17c3b513..4b776ce8e 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -138,9 +138,8 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 8e0a50d2c..c60aa2ba0 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -160,9 +160,8 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 5a522dedb..33bcc55f3 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -195,9 +195,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: pprint.pprint(epoch_stat) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 697b59e98..a97b4432a 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -169,9 +169,8 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index d13b03d85..0bf80e50c 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -161,7 +161,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) # here we define an imitation collector with a trivial policy - policy.eval() if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal il_net = Net( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index ea55da052..d5efc5746 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -160,10 +160,9 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(epoch_stat.info_stat) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index ae788d1cc..c0debf71c 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -160,9 +160,8 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index f60857ea4..a2b9fb419 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -4,12 +4,12 @@ import gymnasium as gym import numpy as np -import pytest import torch from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer @@ -25,7 +25,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=20000) @@ -60,29 +60,35 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: - # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make( - args.task, - env_type="gymnasium", - num_envs=args.training_num, - seed=args.seed, - ) - test_envs = envpool.make( - args.task, - env_type="gymnasium", - num_envs=args.test_num, - seed=args.seed, - ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + if envpool is 
not None: + train_envs = env = envpool.make( + args.task, + env_type="gymnasium", + num_envs=args.training_num, + seed=args.seed, + ) + test_envs = envpool.make( + args.task, + env_type="gymnasium", + num_envs=args.test_num, + seed=args.seed, + ) + else: + env = gym.make(args.task) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs.seed(args.seed) + test_envs.seed(args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) @@ -145,14 +151,13 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) - policy.eval() # here we define an imitation collector with a trivial policy - # if args.task == 'CartPole-v0': + # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) @@ -162,9 +167,23 @@ def stop_fn(mean_rewards: float) -> bool: optim=optim, action_space=env.action_space, ) + if envpool is not None: + il_env = envpool.make( + args.task, + env_type="gymnasium", + num_envs=args.test_num, + seed=args.seed, + ) + else: + il_env = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)], + context="fork", + ) + il_env.seed(args.seed) + il_test_collector = Collector( il_policy, - envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed), + il_env, ) train_collector.reset() result = OffpolicyTrainer( @@ -186,9 +205,9 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - il_policy.eval() collector = Collector(il_policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 1089d4ba0..6c93b0230 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -148,11 +148,14 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
- policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) collector_stats.pprint_asdict() diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 4d25d430b..3abd45ba6 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -25,7 +25,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -68,7 +68,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -206,10 +206,10 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b62a93c3f..cd8b3f50d 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -24,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -159,10 +159,10 @@ def test_fn(epoch: int, env_step: int | None) -> None: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5c24518bb..bae341f38 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -19,7 +19,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -55,7 +55,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -136,9 +136,9 @@ def test_fn(epoch: int, env_step: int | None) -> None: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 8ff9eeb7a..399672ea4 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -25,7 +25,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -67,7 +67,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -176,10 +176,10 @@ def test_fn(epoch: int, env_step: int | None) -> None: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 765bbf9bd..f11ea466a 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -25,7 +25,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.05) @@ -67,7 +67,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -172,10 +172,10 @@ def test_fn(epoch: int, env_step: int | None) -> None: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 95db43c23..51142eb2e 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -20,7 +20,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=20000) @@ -51,7 +51,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -129,9 +129,9 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 132cbea5a..03f6f8df3 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -23,7 +23,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--buffer-size", type=int, default=20000) @@ -64,7 +64,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -156,9 +156,9 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 6485637e8..6d39d9a0d 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -20,7 +20,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -60,10 +60,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - if args.task == "CartPole-v0" and env.spec: + if args.task == "CartPole-v1" and env.spec: env.spec.reward_threshold = 190 # lower the goal if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -161,10 +161,10 @@ def test_fn(epoch: int, env_step: int | None) -> None: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index ff4ef1c1e..0dabd9e90 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -22,7 +22,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -69,7 +69,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -223,10 +223,10 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index b2f466f3d..cdc53d749 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -21,7 +21,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=20000) @@ -60,7 +60,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 170} # lower the goal + default_reward_threshold = {"CartPole-v1": 170} # lower the goal args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -147,9 +147,9 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector.reset() + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index ddfce7b23..4a131e5fd 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -7,7 +7,7 @@ class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( - task="CartPole-v0", + task="CartPole-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 9a4206e18..01e6572a2 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -21,7 +21,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.05) @@ -79,7 +79,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -202,10 +202,9 @@ def test_fn(epoch: int, env_step: int | None) -> None: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index ebf93cd5a..fcb541e9d 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -22,7 +22,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--buffer-size", type=int, default=20000) @@ -83,7 +83,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -194,9 +194,8 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! 
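Taken together, the watch blocks above all land on the same pattern: the manual `policy.eval()` call is dropped and evaluation behaviour is requested through the collect call itself. A minimal sketch of that pattern (illustrative only, not part of the patch; it assumes an already trained `policy` object and the `is_eval` keyword used in these hunks):

    import gymnasium as gym
    from tianshou.data import Collector

    env = gym.make("CartPole-v1")
    # value-based policies additionally call policy.set_eps(eps_test) first,
    # exactly as in the DQN/C51-style scripts above
    collector = Collector(policy, env)  # `policy` as built earlier in the script
    collector.reset()
    stats = collector.collect(n_episode=1, render=0.0, is_eval=True)
    print(stats)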
env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 72742b785..ca7a38895 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -120,10 +120,9 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - stats = test_collector.collect(n_episode=args.test_num, render=args.render) + stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) stats.pprint_asdict() elif env.spec.reward_threshold: assert result.best_reward >= env.spec.reward_threshold diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 93877944e..91ee284f9 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -19,12 +19,12 @@ def expert_file_name() -> str: - return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl") + return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl") def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--eps-test", type=float, default=0.05) @@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 190} + default_reward_threshold = {"CartPole-v1": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -167,7 +167,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) collector.reset() - collector_stats = collector.collect(n_step=args.buffer_size) + collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 1839d863a..660c607d2 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -189,9 +189,8 @@ def watch() -> None: policy.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) # trainer result = OfflineTrainer( @@ -213,9 +212,8 @@ def watch() -> None: if __name__ == "__main__": pprint.pprint(result) env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 1e31b1feb..53aaf1efa 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -210,9 
+210,8 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(epoch_stat.info_stat) env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render) + collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True) if collector_result.returns_stat and collector_result.lens_stat: print( f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 77790808b..e151633a3 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -25,7 +25,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -61,7 +61,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 185} + default_reward_threshold = {"CartPole-v1": 185} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -169,10 +169,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 7323eac13..309f3b6e2 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -24,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -58,7 +58,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 170} + default_reward_threshold = {"CartPole-v1": 170} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -131,10 +131,9 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index b3cb64616..c3f254917 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -25,7 +25,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) @@ -56,7 +56,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 180} + default_reward_threshold = {"CartPole-v1": 180} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold if env.spec else None, @@ -135,9 +135,8 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 256140c41..df0f7bfb1 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -231,9 +231,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 18778563c..43762ddfd 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -198,9 +198,8 @@ def stop_fn(mean_rewards: float) -> bool: if __name__ == "__main__": pprint.pprint(epoch_stat.info_stat) env = gym.make(args.task) - policy.eval() collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render) + collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) print(collector_stats) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index abd0c889a..7fc913451 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -188,8 +188,7 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - policy.eval() [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render) + result = collector.collect(n_episode=1, render=args.render, is_eval=True) result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 54d606602..010047058 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - policy.eval() collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render) + collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True) collector_result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 7ed631912..b63636ae7 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -228,8 +228,7 @@ def watch( ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) - policy.eval() policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render) + result = collector.collect(n_episode=1, render=args.render, is_eval=True) result.pprint_asdict() From 829fd9c7a5a06039f4ad6791675e5d100fe8f0dd Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 14:29:16 +0200 Subject: [PATCH 04/32] Deleted long deprecated functionality, removed unused warning module There's better ways to deal with deprecations that we shall use in the future --- tianshou/env/worker/base.py | 20 +------------------- tianshou/trainer/base.py | 12 +----------- tianshou/utils/__init__.py | 2 -- tianshou/utils/warning.py | 8 -------- 4 files changed, 2 insertions(+), 40 deletions(-) delete mode 100644 tianshou/utils/warning.py diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 
8f1758e62..ac35ccf3f 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -6,7 +6,6 @@ import numpy as np from tianshou.env.utils import gym_new_venv_step_type -from tianshou.utils import deprecation class EnvWorker(ABC): @@ -27,6 +26,7 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: pass + @abstractmethod def send(self, action: np.ndarray | None) -> None: """Send action signal to low-level worker. @@ -34,17 +34,6 @@ def send(self, action: np.ndarray | None) -> None: it indicates "step" signal. The paired return value from "recv" function is determined by such kind of different signal. """ - if hasattr(self, "send_action"): - deprecation( - "send_action will soon be deprecated. " - "Please use send and recv for your own EnvWorker.", - ) - if action is None: - self.is_reset = True - self.result = self.reset() - else: - self.is_reset = False - self.send_action(action) def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]: """Receive result from low-level worker. @@ -54,13 +43,6 @@ def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]: info) or (obs, rew, terminated, truncated, info), based on whether the environment is using the old step API or the new one. """ - if hasattr(self, "get_result"): - deprecation( - "get_result will soon be deprecated. " - "Please use send and recv for your own EnvWorker.", - ) - if not self.is_reset: - self.result = self.get_result() return self.result @abstractmethod diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index fa67d3bfb..3c8a27e21 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -26,7 +26,6 @@ DummyTqdm, LazyLogger, MovAvg, - deprecation, tqdm_config, ) from tianshou.utils.logging import set_numerical_fields_to_precision @@ -76,7 +75,7 @@ class BaseTrainer(ABC): signature ``f(num_epoch: int, step_idx: int) -> None``. :param save_best_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously. + ``f(policy: BasePolicy) -> None``. :param save_checkpoint_fn: a function to save training process and return the saved checkpoint path, with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> str``; you can save whatever you want. @@ -173,16 +172,7 @@ def __init__( verbose: bool = True, show_progress: bool = True, test_in_train: bool = True, - save_fn: Callable[[BasePolicy], None] | None = None, ): - if save_fn: - deprecation( - "save_fn in trainer is marked as deprecated and will be " - "removed in the future. 
Please use save_best_fn instead.", - ) - assert save_best_fn is None - save_best_fn = save_fn - self.policy = policy if buffer is not None: diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 66a7a8db8..2f46f9ca5 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -6,7 +6,6 @@ from tianshou.utils.lr_scheduler import MultipleLRSchedulers from tianshou.utils.progress_bar import DummyTqdm, tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd -from tianshou.utils.warning import deprecation __all__ = [ "MovAvg", @@ -17,6 +16,5 @@ "TensorboardLogger", "LazyLogger", "WandbLogger", - "deprecation", "MultipleLRSchedulers", ] diff --git a/tianshou/utils/warning.py b/tianshou/utils/warning.py deleted file mode 100644 index 93c5ccec3..000000000 --- a/tianshou/utils/warning.py +++ /dev/null @@ -1,8 +0,0 @@ -import warnings - -warnings.simplefilter("once", DeprecationWarning) - - -def deprecation(msg: str) -> None: - """Deprecation warning wrapper.""" - warnings.warn(msg, category=DeprecationWarning, stacklevel=2) From 7d593020959f99804d4d962b30672e29094b2a4f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 14:45:02 +0200 Subject: [PATCH 05/32] Added in_eval/in_train mode contextmanager --- test/base/test_utils.py | 22 ++++++++++++++++------ tianshou/trainer/utils.py | 2 +- tianshou/utils/torch_utils.py | 25 +++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 7 deletions(-) create mode 100644 tianshou/utils/torch_utils.py diff --git a/test/base/test_utils.py b/test/base/test_utils.py index bd14ffe2a..1d8e37c58 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,10 +1,12 @@ import numpy as np import torch +from torch import nn from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic +from tianshou.utils.torch_utils import in_eval_mode, in_train_mode def test_noise() -> None: @@ -132,9 +134,17 @@ def test_lr_schedulers() -> None: ) -if __name__ == "__main__": - test_noise() - test_moving_average() - test_rms() - test_net() - test_lr_schedulers() +def test_in_eval_mode(): + module = nn.Linear(3, 4) + module.train() + with in_eval_mode(module): + assert not module.training + assert module.training + + +def test_in_train_mode(): + module = nn.Linear(3, 4) + module.eval() + with in_train_mode(module): + assert module.training + assert not module.training \ No newline at end of file diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 4d990db4b..b790fddac 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -27,7 +27,7 @@ def test_episode( collector.reset(reset_stats=False) if test_fn: test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode, is_eval=True) + result = collector.collect(n_episode=n_episode, eval_mode=True) if reward_metric: # TODO: move into collector rew = reward_metric(result.returns) result.returns = rew diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py new file mode 100644 index 000000000..1676b524d --- /dev/null +++ b/tianshou/utils/torch_utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager + +from torch import nn + + +@contextmanager +def in_eval_mode(module: nn.Module) -> None: + """Temporarily switch to evaluation mode.""" + train = module.training + try: + module.eval() + yield + 
finally: + module.train(train) + + +@contextmanager +def in_train_mode(module: nn.Module) -> None: + """Temporarily switch to training mode.""" + train = module.training + try: + module.train() + yield + finally: + module.train(train) From 12d4262f80329720b6b495b89e95260eacaea931 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 14:58:58 +0200 Subject: [PATCH 06/32] Tests: removed all instances of `if __name__ == ...` in tests A test is not a script and should not be used as such Also marked pistonball test as skipped since it doesn't actually test anything --- test/base/test_action_space_sampling.py | 7 --- test/base/test_batch.py | 13 ---- test/base/test_buffer.py | 43 +------------ test/base/test_collector.py | 16 +---- test/base/test_env.py | 33 +--------- test/base/test_env_finite.py | 5 -- test/base/test_returns.py | 63 ------------------- test/continuous/test_ddpg.py | 13 ---- test/continuous/test_npg.py | 13 ---- test/continuous/test_ppo.py | 13 ---- test/continuous/test_redq.py | 13 ---- test/continuous/test_sac_with_il.py | 4 -- test/continuous/test_td3.py | 13 ---- test/continuous/test_trpo.py | 13 ---- test/discrete/test_a2c_with_il.py | 23 ------- test/discrete/test_bdq.py | 21 +------ test/discrete/test_c51.py | 15 ----- test/discrete/test_dqn.py | 15 ----- test/discrete/test_drqn.py | 14 ----- test/discrete/test_fqf.py | 15 ----- test/discrete/test_iqn.py | 15 ----- test/discrete/test_pg.py | 14 ----- test/discrete/test_ppo.py | 14 ----- test/discrete/test_qrdqn.py | 15 ----- test/discrete/test_rainbow.py | 15 ----- test/discrete/test_sac.py | 14 ----- test/modelbased/test_dqn_icm.py | 14 ----- test/modelbased/test_ppo_icm.py | 13 ---- test/modelbased/test_psrl.py | 16 +---- test/offline/test_bcq.py | 19 +----- test/offline/test_cql.py | 21 +------ test/offline/test_discrete_bcq.py | 20 +----- test/offline/test_discrete_cql.py | 20 +----- test/offline/test_discrete_crr.py | 19 +----- test/offline/test_gail.py | 19 +----- test/offline/test_td3_bc.py | 19 +----- test/pettingzoo/test_pistonball.py | 14 +---- test/pettingzoo/test_pistonball_continuous.py | 10 --- test/pettingzoo/test_tic_tac_toe.py | 10 --- 39 files changed, 15 insertions(+), 651 deletions(-) diff --git a/test/base/test_action_space_sampling.py b/test/base/test_action_space_sampling.py index fbbf25c04..e3d82767a 100644 --- a/test/base/test_action_space_sampling.py +++ b/test/base/test_action_space_sampling.py @@ -48,10 +48,3 @@ def test_shmem_vec_env_action_space() -> None: action2 = [ac_space.sample() for ac_space in envs.action_space] assert action1 == action2 - - -if __name__ == "__main__": - test_gym_env_action_space() - test_dummy_vec_env_action_space() - test_subproc_vec_env_action_space() - test_shmem_vec_env_action_space() diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 5e90dfb66..86d4af500 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -749,16 +749,3 @@ def test_to_torch_() -> None: assert id_batch == id(batch) assert isinstance(batch.b, torch.Tensor) assert isinstance(batch.c.d, torch.Tensor) - - -if __name__ == "__main__": - test_batch() - test_batch_over_batch() - test_batch_over_batch_to_torch() - test_utils_to_torch_numpy() - test_batch_pickle() - test_batch_from_to_numpy_without_copy() - test_batch_standard_compatibility() - test_batch_cat_and_stack() - test_batch_copy() - test_batch_empty() diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 5488ff365..40f450c8f 100644 --- a/test/base/test_buffer.py +++ 
b/test/base/test_buffer.py @@ -1,7 +1,7 @@ import os import pickle import tempfile -from timeit import timeit +from test.base.env import MoveToRightEnv, MyGoalEnv import h5py import numpy as np @@ -22,11 +22,6 @@ ) from tianshou.data.utils.converter import to_hdf5 -if __name__ == "__main__": - from env import MoveToRightEnv, MyGoalEnv -else: # pytest - from test.base.env import MoveToRightEnv, MyGoalEnv - def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: env = MoveToRightEnv(size) @@ -607,24 +602,6 @@ def test_segtree() -> None: index = tree.get_prefix_sum_idx(scalar) assert naive[:index].sum() <= scalar <= naive[: index + 1].sum() - # profile - if __name__ == "__main__": - size = 100000 - bsz = 64 - naive = np.random.rand(size) - tree = SegmentTree(size) - tree[np.arange(size)] = naive - - def sample_npbuf() -> np.ndarray: - return np.random.choice(size, bsz, p=naive / naive.sum()) - - def sample_tree() -> int | np.ndarray: - scalar = np.random.rand(bsz) * tree.reduce() - return tree.get_prefix_sum_idx(scalar) - - print("npbuf", timeit(sample_npbuf, setup=sample_npbuf, number=1000)) - print("tree", timeit(sample_tree, setup=sample_tree, number=1000)) - def test_pickle() -> None: size = 100 @@ -1401,21 +1378,3 @@ def test_custom_key() -> None: ): assert batch.__dict__[key].is_empty() assert sampled_batch.__dict__[key].is_empty() - - -if __name__ == "__main__": - test_replaybuffer() - test_ignore_obs_next() - test_stack() - test_segtree() - test_priortized_replaybuffer() - test_update() - test_pickle() - test_hdf5() - test_replaybuffermanager() - test_cachedbuffer() - test_multibuf_stack() - test_multibuf_hdf5() - test_from_data() - test_herreplaybuffer() - test_custom_key() diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6baa6abf3..08483bb9b 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Sequence +from test.base.env import MoveToRightEnv, NXEnv from typing import Any import gymnasium as gym @@ -25,11 +26,6 @@ except ImportError: envpool = None -if __name__ == "__main__": - from env import MoveToRightEnv, NXEnv -else: # pytest - from test.base.env import MoveToRightEnv, NXEnv - class MaxActionPolicy(BasePolicy): def __init__( @@ -963,13 +959,3 @@ def test_async_collector_with_vector_env() -> None: assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens) c2r = c1.collect(n_step=20) assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens) - - -if __name__ == "__main__": - test_collector() - test_collector_with_dict_state() - test_collector_with_multi_agent() - test_collector_with_atari_setting() - test_collector_envpool_gym_reset_return_info() - test_collector_with_vector_env() - test_async_collector_with_vector_env() diff --git a/test/base/test_env.py b/test/base/test_env.py index a476ec5a9..1a33e861c 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,6 +1,7 @@ import sys import time from collections.abc import Callable +from test.base.env import MoveToRightEnv, NXEnv from typing import Any, Literal import gymnasium as gym @@ -22,11 +23,6 @@ from tianshou.env.venvs import BaseVectorEnv from tianshou.utils import RunningMeanStd -if __name__ == "__main__": - from env import MoveToRightEnv, NXEnv -else: # pytest - from test.base.env import MoveToRightEnv, NXEnv - try: import envpool except ImportError: @@ -190,19 +186,6 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: for info in infos: assert 
recurse_comp(infos[0], info) - if __name__ == "__main__": - t = [0.0] * len(venv) - for i, e in enumerate(venv): - t[i] = time.time() - e.reset() - for a in action_list: - done = e.step(np.array([a] * num))[2] - if sum(done) > 0: - e.reset(np.where(done)[0]) - t[i] = time.time() - t[i] - for i, v in enumerate(venv): - print(f"{type(v)}: {t[i]:.6f}s") - def assert_get(v: BaseVectorEnv, expected: list) -> None: assert v.get_env_attr("size") == expected assert v.get_env_attr("size", id=0) == [expected[0]] @@ -437,17 +420,3 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None: for _, v in _info.items(): if not isinstance(v, dict): assert v.shape[0] == num_envs - - -if __name__ == "__main__": - test_venv_norm_obs() - test_venv_wrapper_gym() - test_venv_wrapper_envpool() - test_venv_wrapper_envpool_gym_reset_return_info() - test_env_obs_dtype() - test_vecenv() - test_attr_unwrapped() - test_async_env() - test_async_check_id() - test_env_reset_optional_kwargs() - test_gym_wrappers() diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 657100554..ce8a93640 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -268,8 +268,3 @@ def test_finite_subproc_vector_env() -> None: test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() - - -if __name__ == "__main__": - test_finite_dummy_vector_env() - test_finite_subproc_vector_env() diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 23f50fb22..ab4430b85 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,10 +1,7 @@ -from timeit import timeit - import numpy as np import torch from tianshou.data import Batch, ReplayBuffer, to_numpy -from tianshou.data.types import BatchWithReturnsProtocol from tianshou.policy import BasePolicy @@ -142,28 +139,6 @@ def test_episodic_returns(size: int = 2560) -> None: ) assert np.allclose(returns, ground_truth) - if __name__ == "__main__": - buf = ReplayBuffer(size) - batch = Batch( - terminated=np.random.randint(100, size=size) == 0, - truncated=np.zeros(size), - rew=np.random.random(size), - ) - for b in iter(batch): - b.obs = b.act = 1 - buf.add(b) - indices = buf.sample_indices(0) - - def vanilla() -> Batch: - return compute_episodic_return_base(batch, gamma=0.1) - - def optimized() -> tuple[np.ndarray, np.ndarray]: - return fn(batch, buf, indices, gamma=0.1, gae_lambda=1.0) - - cnt = 3000 - print("GAE vanilla", timeit(vanilla, setup=vanilla, number=cnt)) - print("GAE optim ", timeit(optimized, setup=optimized, number=cnt)) - def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: # return the next reward @@ -356,41 +331,3 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: ).pop("returns"), ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) - - if __name__ == "__main__": - buf = ReplayBuffer(size) - for i in range(int(size * 1.5)): - buf.add( - Batch( - obs=0, - act=0, - rew=i + 1, - terminated=np.random.randint(3) == 0, - truncated=i % 33 == 0, - info={}, - ), - ) - batch, indices = buf.sample(256) - - def vanilla() -> np.ndarray: - return compute_nstep_return_base(3, 0.1, buf, indices) - - def optimized() -> BatchWithReturnsProtocol: - return BasePolicy.compute_nstep_return( - batch, - buf, - indices, - target_q_fn, - gamma=0.1, - n_step=3, - ) - - cnt = 3000 - print("nstep vanilla", timeit(vanilla, setup=vanilla, number=cnt)) - print("nstep optim ", timeit(optimized, setup=optimized, number=cnt)) - - -if __name__ == "__main__": - 
test_nstep_returns() - test_nstep_returns_with_timelimit() - test_episodic_returns() diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 4b776ce8e..1aedadabf 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -133,15 +132,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_ddpg() diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index c60aa2ba0..98803a9d2 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -155,15 +154,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_npg() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 33bcc55f3..15d834096 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -191,19 +190,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: assert stop_fn(epoch_stat.info_stat.best_reward) - if __name__ == "__main__": - pprint.pprint(epoch_stat) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_ppo(args) - - -if __name__ == "__main__": - test_ppo() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index a97b4432a..f627f7e4f 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -164,15 +163,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
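The `in_eval_mode`/`in_train_mode` context managers added in tianshou/utils/torch_utils.py a few hunks above are used as follows (a hypothetical sketch, not part of the patch); the module's previous training flag is restored on exit:

    import torch
    from torch import nn
    from tianshou.utils.torch_utils import in_eval_mode

    net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Dropout(0.5), nn.Linear(64, 2))
    net.train()
    with in_eval_mode(net), torch.no_grad():
        q_values = net(torch.zeros(1, 4))  # dropout is inactive inside the block
    assert net.training  # the original training mode is restored afterwards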
- env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_redq() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 0bf80e50c..77a403359 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -204,7 +204,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - -if __name__ == "__main__": - test_sac_with_il() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index d5efc5746..21a2cf40d 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -155,16 +155,3 @@ def stop_fn(mean_rewards: float) -> bool: # print(info) assert stop_fn(epoch_stat.info_stat.best_reward) - - if __name__ == "__main__": - pprint.pprint(epoch_stat.info_stat) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_td3() diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index c0debf71c..8841891bf 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -155,15 +154,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_trpo() diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index a2b9fb419..2fd41aff8 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -147,15 +146,6 @@ def stop_fn(mean_rewards: float) -> bool: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal @@ -200,16 +190,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
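The ad-hoc profiling removed from test_buffer.py earlier in this patch can still be run outside the test suite when needed. A rough sketch under the same assumptions as the deleted code (SegmentTree importable from tianshou.data, arbitrary sizes):

    from timeit import timeit

    import numpy as np
    from tianshou.data import SegmentTree

    size, bsz = 100_000, 64
    weights = np.random.rand(size)
    tree = SegmentTree(size)
    tree[np.arange(size)] = weights

    def sample_npbuf() -> np.ndarray:
        # baseline: weighted sampling through np.random.choice
        return np.random.choice(size, bsz, p=weights / weights.sum())

    def sample_tree() -> np.ndarray:
        # prefix-sum sampling through the segment tree
        return tree.get_prefix_sum_idx(np.random.rand(bsz) * tree.reduce())

    print("npbuf", timeit(sample_npbuf, number=1000))
    print("tree ", timeit(sample_tree, number=1000))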
- env = gym.make(args.task) - collector = Collector(il_policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_a2c_with_il() diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 6c93b0230..91c66bac0 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,5 +1,4 @@ import argparse -import pprint import gymnasium as gym import numpy as np @@ -129,7 +128,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OffpolicyTrainer( + OffpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, @@ -143,21 +142,3 @@ def stop_fn(mean_rewards: float) -> bool: test_fn=test_fn, stop_fn=stop_fn, ).run() - - # assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - policy.set_eps(args.eps_test) - test_envs.seed(args.seed) - test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - is_eval=True, - ) - collector_stats.pprint_asdict() - - -if __name__ == "__main__": - test_bdq(get_args()) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 3abd45ba6..8b34ddb4b 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,7 +1,6 @@ import argparse import os import pickle -import pprint import gymnasium as gym import numpy as np @@ -202,16 +201,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_c51_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True @@ -223,7 +212,3 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_c51(args) - - -if __name__ == "__main__": - test_c51(get_args()) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index cd8b3f50d..773004f2c 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -155,23 +154,9 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_pdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 args.seed = 1 test_dqn(args) - - -if __name__ == "__main__": - test_dqn(get_args()) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index bae341f38..193179097 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -131,16 +130,3 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_drqn(get_args()) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 399672ea4..743293be0 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -172,22 +171,8 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_pfqf(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_fqf(args) - - -if __name__ == "__main__": - test_fqf(get_args()) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index f11ea466a..f7ea67adb 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -168,22 +167,8 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_piqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_iqn(args) - - -if __name__ == "__main__": - test_iqn(get_args()) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 51142eb2e..60d0eb469 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -124,16 +123,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
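With the `if __name__ == "__main__"` blocks removed, these modules are exercised only through the test runner, e.g. `pytest test/discrete/test_dqn.py` or `pytest test/discrete -k test_dqn` (invocations given for illustration).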
- env = gym.make(args.task) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_pg() diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 03f6f8df3..27fe6f517 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -151,16 +150,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_ppo() diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 6d39d9a0d..76d7d429d 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -157,22 +156,8 @@ def test_fn(epoch: int, env_step: int | None) -> None: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_qrdqn(args) - - -if __name__ == "__main__": - test_pqrdqn(get_args()) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 0dabd9e90..0a73d4b77 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,7 +1,6 @@ import argparse import os import pickle -import pprint import gymnasium as gym import numpy as np @@ -219,16 +218,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True @@ -240,7 +229,3 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_rainbow(args) - - -if __name__ == "__main__": - test_rainbow(get_args()) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index cdc53d749..f16e59daf 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -142,16 +141,3 @@ def stop_fn(mean_rewards: float) -> bool: test_in_train=False, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - collector.reset() - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_discrete_sac() diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 01e6572a2..5ef0bba65 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -197,16 +196,3 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_dqn_icm(get_args()) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index fcb541e9d..77f9a40e1 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -189,15 +188,3 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_ppo() diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index ca7a38895..41849992d 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import numpy as np import pytest @@ -116,17 +115,4 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, test_in_train=False, ).run() - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - test_envs.seed(args.seed) - test_collector.reset() - stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) - stats.pprint_asdict() - elif env.spec.reward_threshold: - assert result.best_reward >= env.spec.reward_threshold - - -if __name__ == "__main__": - test_psrl() + assert result.best_reward >= env.spec.reward_threshold diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 660c607d2..e28368e11 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -2,7 +2,7 @@ import datetime import os import pickle -import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -19,11 +19,6 @@ from tianshou.utils.net.continuous import VAE, Critic, Perturbation from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -207,15 +202,3 @@ def watch() -> None: show_progress=args.show_progress, ).run() assert stop_fn(result.best_reward) - - # Let's watch its performance! 
- if __name__ == "__main__": - pprint.pprint(result) - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_bcq() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 53aaf1efa..41d67151a 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,6 +3,7 @@ import os import pickle import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -19,11 +20,6 @@ from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -205,18 +201,3 @@ def stop_fn(mean_rewards: float) -> bool: # print(info) assert stop_fn(epoch_stat.info_stat.best_reward) - - # Let's watch its performance! - if __name__ == "__main__": - pprint.pprint(epoch_stat.info_stat) - env = gym.make(args.task) - collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True) - if collector_result.returns_stat and collector_result.lens_stat: - print( - f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", - ) - - -if __name__ == "__main__": - test_cql() diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index e151633a3..589809985 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -1,7 +1,7 @@ import argparse import os import pickle -import pprint +from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -17,11 +17,6 @@ from tianshou.utils.net.discrete import Actor from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_cartpole_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_cartpole_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -165,21 +160,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ).run() assert stop_fn(result.best_reward) - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: test_discrete_bcq() args.resume = True test_discrete_bcq(args) - - -if __name__ == "__main__": - test_discrete_bcq(get_args()) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 309f3b6e2..b62fdcc26 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -1,7 +1,7 @@ import argparse import os import pickle -import pprint +from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -16,11 +16,6 @@ from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_cartpole_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_cartpole_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -126,16 +121,3 @@ def stop_fn(mean_rewards: float) -> bool: ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_discrete_cql(get_args()) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index c3f254917..eee872556 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -1,7 +1,7 @@ import argparse import os import pickle -import pprint +from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -17,11 +17,6 @@ from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_cartpole_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_cartpole_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -130,15 +125,3 @@ def stop_fn(mean_rewards: float) -> bool: ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_discrete_crr(get_args()) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index df0f7bfb1..ea13f484c 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -1,7 +1,7 @@ import argparse import os import pickle -import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -18,11 +18,6 @@ from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -226,15 +221,3 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: save_checkpoint_fn=save_checkpoint_fn, ).run() assert stop_fn(result.best_reward) - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_gail() diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 43762ddfd..fa01444ab 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -2,7 +2,7 @@ import datetime import os import pickle -import pprint +from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -20,11 +20,6 @@ from tianshou.utils.net.continuous import Actor, Critic from tianshou.utils.space_info import SpaceInfo -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -193,15 +188,3 @@ def stop_fn(mean_rewards: float) -> bool: # print(info) assert stop_fn(epoch_stat.info_stat.best_reward) - - # Let's watch its performance! - if __name__ == "__main__": - pprint.pprint(epoch_stat.info_stat) - env = gym.make(args.task) - collector = Collector(policy, env) - collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True) - print(collector_stats) - - -if __name__ == "__main__": - test_td3_bc() diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 2432ca531..36f1d2d3a 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -1,22 +1,14 @@ import argparse -import pprint +import pytest from pistonball import get_args, train_agent, watch +@pytest.mark.skip(reason="Performance bound was never tested, no point in running this for now") def test_piston_ball(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return - result, agent = train_agent(args) + train_agent(args) # assert result.best_reward >= args.win_rate - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! 
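The pettingzoo tests above keep a manual watch path behind args.watch even after the inline __main__ blocks are dropped. A minimal sketch of exercising that path by hand (illustrative only; get_args, watch and test_piston_ball are the names defined in or imported by test_pistonball.py above, and the render value is made up here):

    # run from within the test_pistonball.py module context
    args = get_args()
    args.watch = True     # makes test_piston_ball dispatch to watch(args) and return early
    args.render = 0.1     # sleep between rendered frames; illustrative value only
    test_piston_ball(args)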
- watch(args, agent) - - -if __name__ == "__main__": - test_piston_ball(get_args()) diff --git a/test/pettingzoo/test_pistonball_continuous.py b/test/pettingzoo/test_pistonball_continuous.py index bb2979bc6..b96a29c0c 100644 --- a/test/pettingzoo/test_pistonball_continuous.py +++ b/test/pettingzoo/test_pistonball_continuous.py @@ -1,5 +1,4 @@ import argparse -import pprint import pytest from pistonball_continuous import get_args, train_agent, watch @@ -13,12 +12,3 @@ def test_piston_ball_continuous(args: argparse.Namespace = get_args()) -> None: result, agent = train_agent(args) # assert result.best_reward >= 30.0 - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - watch(args, agent) - - -if __name__ == "__main__": - test_piston_ball_continuous(get_args()) diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index 44aa86b9f..0f0c237c8 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -1,5 +1,4 @@ import argparse -import pprint from tic_tac_toe import get_args, train_agent, watch @@ -11,12 +10,3 @@ def test_tic_tac_toe(args: argparse.Namespace = get_args()) -> None: result, agent = train_agent(args) assert result.best_reward >= args.win_rate - - if __name__ == "__main__": - pprint.pprint(result) - # Let's watch its performance! - watch(args, agent) - - -if __name__ == "__main__": - test_tic_tac_toe(get_args()) From 4b619c51baecf85a99f5ad199f402fe6e006db90 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 16:46:03 +0200 Subject: [PATCH 07/32] Collector: extracted interface BaseCollector, minor simplifications Renamed is_eval kwarg --- examples/atari/atari_c51.py | 4 +- examples/atari/atari_dqn.py | 4 +- examples/atari/atari_fqf.py | 4 +- examples/atari/atari_iqn.py | 4 +- examples/atari/atari_ppo.py | 4 +- examples/atari/atari_qrdqn.py | 4 +- examples/atari/atari_rainbow.py | 4 +- examples/atari/atari_sac.py | 4 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_bdq.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/box2d/mcc_sac.py | 2 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/fetch_her_ddpg.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_ddpg.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_redq.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_sac.py | 2 +- examples/mujoco/mujoco_td3.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 +- examples/offline/atari_bcq.py | 2 +- examples/offline/atari_cql.py | 2 +- examples/offline/atari_crr.py | 2 +- examples/offline/atari_il.py | 2 +- examples/offline/d4rl_bcq.py | 4 +- examples/offline/d4rl_cql.py | 4 +- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 4 +- examples/vizdoom/vizdoom_c51.py | 4 +- examples/vizdoom/vizdoom_ppo.py | 4 +- test/offline/gather_cartpole_data.py | 2 +- test/offline/test_bcq.py | 2 +- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/data/collector.py | 464 ++++++++++++----------- tianshou/highlevel/experiment.py | 2 +- 41 files changed, 300 insertions(+), 272 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 16694af3d..98d1d0860 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -174,7 +174,7 @@ def watch() -> None: 
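The hunks that follow apply this commit's keyword rename mechanically at every Collector.collect call site: is_eval=True becomes eval_mode=True, with the call otherwise unchanged. A representative before/after sketch (test_collector and args stand for the objects already set up in these example scripts):

    # before this commit
    result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
    # after this commit
    result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)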
stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -184,7 +184,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index aed46a2a3..7d60654ac 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -216,7 +216,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -226,7 +226,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 92a614007..185cff145 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -187,7 +187,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -197,7 +197,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 3c4a695bc..5216d7c4e 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -184,7 +184,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -194,7 +194,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 3c2c68d9b..969d00aaf 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -240,7 +240,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -250,7 +250,7 @@ def 
watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index c821fad40..5b0258108 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -178,7 +178,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -188,7 +188,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 0d14f5d9f..5bb69a38b 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -214,7 +214,7 @@ def watch() -> None: beta=args.beta, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -224,7 +224,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index a8d759ad8..be13884e9 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -227,7 +227,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -237,7 +237,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index ed6a65ba4..ac2c800d6 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -150,7 +150,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index da425e623..7c89cd276 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -168,7 +168,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 8dfbed69f..214721356 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py 
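The atari watch() helpers above collect into a replay buffer in evaluation mode and dump it with buffer.save_hdf5(args.save_buffer_name). As a hedged sketch that is not part of this patch, such a dump can later be restored with the load_hdf5 counterpart before offline training; the classmethod and path used below are assumptions mirroring the save call above:

    from tianshou.data import VectorReplayBuffer

    # assumes the file was produced by buffer.save_hdf5(args.save_buffer_name) as in watch() above
    buffer = VectorReplayBuffer.load_hdf5(args.save_buffer_name)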
@@ -212,7 +212,7 @@ def stop_fn(mean_rewards: float) -> bool: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 4bd289644..8280e236f 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -147,7 +147,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 257641187..97f714983 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -158,7 +158,7 @@ def stop_fn(mean_rewards: float) -> bool: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 3c1035927..e30b3f651 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -269,7 +269,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 01a89b610..8b3439a8f 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -243,7 +243,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 0195702a1..834e0ced5 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -224,7 +224,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 98852dc6b..066c7d4c9 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -173,7 +173,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index c2a193a22..40cc443df 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -221,7 +221,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 9e20c467c..eb6817ef9 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -229,7 +229,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 791005c73..9101f8634 100755 --- a/examples/mujoco/mujoco_redq.py +++ 
b/examples/mujoco/mujoco_redq.py @@ -201,7 +201,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 8912ee64a..cbd4d3e1d 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -201,7 +201,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 6c0bbdc95..8d9d61ec1 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -195,7 +195,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 8e0c31576..d1b5e6921 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -193,7 +193,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 96da0cfe6..b5a91c0b5 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -226,7 +226,7 @@ def save_best_fn(policy: BasePolicy) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 80687d163..d4e985ea8 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -191,7 +191,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 61071453f..69ae433ff 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -175,7 +175,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index ad5c2312f..4e0771c0a 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -191,7 +191,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 495b630b3..928112c77 100644 --- 
a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -148,7 +148,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 983424730..308a872eb 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -207,7 +207,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -233,7 +233,7 @@ def watch() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 4e6a127aa..9d95d29ca 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -345,7 +345,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -371,7 +371,7 @@ def watch() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index a8a6a66bb..d64cfe9da 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -143,7 +143,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -169,7 +169,7 @@ def watch() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 24138a4b8..5719eba16 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -192,7 +192,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -221,7 +221,7 @@ def watch() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 356605151..8a52e5f5d 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -180,7 +180,7 @@ def watch() -> 
None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -190,7 +190,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 907b0ebfe..adbfb0584 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -242,7 +242,7 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -252,7 +252,7 @@ def watch() -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 91ee284f9..7d6aba1b2 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -167,7 +167,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) collector.reset() - collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True) + collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index e28368e11..20e4dd6c6 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -185,7 +185,7 @@ def watch() -> None: torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) # trainer result = OfflineTrainer( diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 7fc913451..27392f740 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non policy, _, _ = get_agents(args) [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render, is_eval=True) + result = collector.collect(n_episode=1, render=args.render, eval_mode=True) result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 010047058..ed085225a 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non ) policy, _, _ = get_agents(args) collector = Collector(policy, env) - collector_result = 
collector.collect(n_episode=1, render=args.render, is_eval=True) + collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True) collector_result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index b63636ae7..fc46e2a8f 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -230,5 +230,5 @@ def watch( policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render, is_eval=True) + result = collector.collect(n_episode=1, render=args.render, eval_mode=True) result.pprint_asdict() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1a60bd083..0b9fd6dd7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,5 +1,7 @@ +import logging import time import warnings +from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass from typing import Any, Self, TypeVar, cast @@ -7,11 +9,11 @@ import gymnasium as gym import numpy as np import torch +from overrides import override from tianshou.data import ( Batch, CachedReplayBuffer, - PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, SequenceSummaryStats, @@ -25,6 +27,9 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin +from tianshou.utils.torch_utils import in_eval_mode, in_train_mode + +log = logging.getLogger(__name__) @dataclass(kw_only=True) @@ -122,23 +127,12 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class Collector: - """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` - as the default buffer. - :param exploration_noise: determine whether the action needs to be modified - with the corresponding policy's exploration noise. If so, "policy. - exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action. Default to False. +class BaseCollector(ABC): + """Used to collect data from a vector environment into a buffer using a given policy. .. note:: - Please make sure the given environment has a time limitation if using n_episode + Please make sure the given environment has a time limitation if using `n_episode` collect option. .. 
note:: @@ -150,72 +144,70 @@ class Collector: def __init__( self, policy: BasePolicy, - env: gym.Env | BaseVectorEnv, + env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: - super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy - self.env = DummyVectorEnv([lambda: env]) - else: - self.env = env # type: ignore - self.env_num = len(self.env) - self.exploration_noise = exploration_noise - self.buffer = self._assign_buffer(buffer) + env = DummyVectorEnv([lambda: env]) # type: ignore + + if buffer is None: + buffer = VectorReplayBuffer(len(env), len(env)) + + self.buffer: ReplayBuffer = buffer self.policy = policy + self.env = cast(BaseVectorEnv, env) + self.exploration_noise = exploration_noise + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + self._action_space = self.env.action_space + self._is_closed = False - self._pre_collect_obs_RO: np.ndarray | None = None - self._pre_collect_info_R: np.ndarray | None = None - self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + self._validate_buffer() - self._is_closed = False - self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + def _validate_buffer(self) -> None: + buf = self.buffer + # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. + # We should probably rename the manager + if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.", + ) + if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.", + ) + # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance + if not isinstance(buf, ReplayBufferManager): + if buf.maxsize == 0: + raise ValueError("Buffer maxsize should be greater than 0.") + if self.env_num > 1: + raise ValueError( + f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). 
" + f"Please use the corresponding VectorReplayBuffer instead.", + ) + + @property + def env_num(self) -> int: + return len(self.env) + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space def close(self) -> None: """Close the collector and the environment.""" self.env.close() - self._pre_collect_obs_RO = None - self._pre_collect_info_R = None self._is_closed = True - @property - def is_closed(self) -> bool: - """Return True if the collector is closed.""" - return self._is_closed - - def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: - """Check if the buffer matches the constraint.""" - if buffer is None: - buffer = VectorReplayBuffer(self.env_num, self.env_num) - elif isinstance(buffer, ReplayBufferManager): - assert buffer.buffer_num >= self.env_num - if isinstance(buffer, CachedReplayBuffer): - assert buffer.cached_buffer_num >= self.env_num - else: # ReplayBuffer or PrioritizedReplayBuffer - assert buffer.maxsize > 0 - if self.env_num > 1: - if isinstance(buffer, ReplayBuffer): - buffer_type = "ReplayBuffer" - vector_type = "VectorReplayBuffer" - if isinstance(buffer, PrioritizedReplayBuffer): - buffer_type = "PrioritizedReplayBuffer" - vector_type = "PrioritizedVectorReplayBuffer" - raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect " - f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" - f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.", - ) - return buffer - def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached @@ -224,12 +216,13 @@ def reset( :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) """ - self.reset_env(gym_reset_kwargs=gym_reset_kwargs) + obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: self.reset_buffer() if reset_stats: self.reset_stat() self._is_closed = False + return obs_NO, info_N def reset_stat(self) -> None: """Reset the statistic variables.""" @@ -242,18 +235,168 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environments and the initial obs, info, and hidden state of the collector.""" gym_reset_kwargs = gym_reset_kwargs or {} - self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs) + obs_NO, info_N = self.env.reset(**gym_reset_kwargs) # TODO: hack, wrap envpool envs such that they don't return a dict - if isinstance(self._pre_collect_info_R, dict): # type: ignore[unreachable] + if isinstance(info_N, dict): # type: ignore[unreachable] # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict # with array entries instead of an array of dicts # We use Batch to turn it into an array of dicts - self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable] + info_N = _dict_of_arr_to_arr_of_dicts(info_N) # type: ignore[unreachable] + return obs_NO, info_N + @abstractmethod + def _collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + pass + + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + eval_mode: bool = False, + ) -> CollectStats: + """Collect a specified number of steps or episodes. + + To ensure an unbiased sampling result with the n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + + :param n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect. + :param random: whether to use random policy for collecting data. + :param render: the sleep time between rendering consecutive frames. + :param no_grad: whether to retain gradient in policy.forward(). + :param reset_before_collect: whether to reset the environment before collecting data. + (The collector needs the initial obs and info to function properly.) + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Only used if reset_before_collect is True. + :param eval_mode: whether to collect data in evaluation mode. Will + set the policy to training mode otherwise. + + .. note:: + + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. + + :return: The collected stats + """ + # check that exactly one of n_step or n_episode is set and that the other is larger than 0 + self._validate_n_step_n_episode(n_episode, n_step) + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + policy_mode_context = in_eval_mode if eval_mode else in_train_mode + with policy_mode_context(self.policy): + return self._collect( + n_step=n_step, + n_episode=n_episode, + random=random, + render=render, + no_grad=no_grad, + gym_reset_kwargs=gym_reset_kwargs, + ) + + def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: + if not n_step and not n_episode: + raise ValueError( + f"Only one of n_step and n_episode should be set to a value larger than zero " + f"but got {n_step=}, {n_episode=}.", + ) + if n_step is None and n_episode is None: + raise ValueError( + "Exactly one of n_step and n_episode should be set but got None for both.", + ) + if n_step and n_step % self.env_num != 0: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}), " + "which may cause extra transitions being collected into the buffer.", + ) + if n_episode and self.env_num > n_episode: + warnings.warn( + f"{n_episode=} should be larger than {self.env_num=} to " + f"collect at least one trajectory in each environment.", + ) + + +class Collector(BaseCollector): + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated). 
After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # N - number of envs, always fixed and >= R. + # R - number ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. + # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + + # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy + # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on + # policy.deterministic_eval) + + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + """ + super().__init__(policy, env, buffer, exploration_noise=exploration_noise) + self._pre_collect_obs_RO: np.ndarray | None = None + self._pre_collect_info_R: np.ndarray | None = None + self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + + self._is_closed = False + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + def close(self) -> None: + super().close() + self._pre_collect_obs_RO = None + self._pre_collect_info_R = None + + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environments and the initial obs, info, and hidden state of the collector.""" + obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + # We assume that R = N when reset is called. + # TODO: there is currently no mechanism that ensures this and it's a public method! + self._pre_collect_obs_RO = obs_NO + self._pre_collect_info_R = info_N self._pre_collect_hidden_state_RH = None + return obs_NO, info_N def _compute_action_policy_hidden( self, @@ -309,98 +452,30 @@ def _compute_action_policy_hidden( return act_RA, act_normalized_RA, policy_R, hidden_state_RH # TODO: reduce complexity, remove the noqa - def collect( + def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, no_grad: bool = True, - reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - is_eval: bool = False, ) -> CollectStats: - """Collect a specified number of steps or episodes. 
- - To ensure an unbiased sampling result with the n_episode option, this function will - first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` - episodes, they will be collected evenly from each env. - - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. - :param render: the sleep time between rendering consecutive frames. - :param no_grad: whether to retain gradient in policy.forward(). - :param reset_before_collect: whether to reset the environment before collecting data. - (The collector needs the initial obs and info to function properly.) - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Only used if reset_before_collect is True. - :param is_eval: whether to collect data in evaluation mode. - - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. - - :return: The collected stats - """ - # NAMING CONVENTION (mostly suffixes): - # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, - # the corresponding env is either reset or removed from the ready envs. - # R - number ready env ids. Note that this might change when envs get idle. - # This can only happen in n_episode case, see explanation in the corresponding block. - # For n_step, we always use all envs to collect the data, while for n_episode, - # R will be at most n_episode at the beginning, but can decrease during the collection. - # O - dimension(s) of observations - # A - dimension(s) of actions - # H - dimension(s) of hidden state - # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. - # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. - # Only used in n_episode case. Then, R becomes R-S. - - # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy - # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on - # policy.deterministic_eval) - self.policy.eval() - pre_collect_is_eval = self.policy.is_eval - self.policy.is_eval = is_eval - - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} + # TODO: can't do it init since AsyncCollector is currently a subclass of Collector + if self.env.is_async: + raise ValueError( + f"Please use {AsyncCollector.__name__} for asynchronous environments. " + f"Env class: {self.env.__class__.__name__}.", + ) - # Input validation - assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: - assert n_episode is None, ( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}." 
- ) - assert n_step > 0 - if n_step % self.env_num != 0: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}), " - "which may cause extra transitions being collected into the buffer.", - ) ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: - assert n_episode > 0 - if self.env_num > n_episode: - warnings.warn( - f"{n_episode=} should be larger than {self.env_num=} to " - f"collect at least one trajectory in each environment.", - ) ready_env_ids_R = np.arange(min(self.env_num, n_episode)) - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) - - start_time = time.time() - if reset_before_collect: - self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + use_grad = not no_grad + start_time = time.time() if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: raise ValueError( "Initial obs and info should not be None. " @@ -491,7 +566,8 @@ def collect( step_count += len(ready_env_ids_R) # preparing for the next iteration - # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + # obs_next, info and hidden_state will be modified inplace in the code below, + # so we copy to not affect the data in the buffer last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) last_hidden_state_RH = copy(hidden_state_RH) @@ -509,6 +585,7 @@ def collect( # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. + gym_reset_kwargs = gym_reset_kwargs or {} obs_reset_DO, info_reset_D = self.env.reset( env_id=env_ind_global_D, **gym_reset_kwargs, @@ -577,9 +654,6 @@ def collect( # reset envs and the _pre_collect fields self.reset_env(gym_reset_kwargs) # todo still necessary? - # set the policy back to pre collect mode - self.policy.is_eval = pre_collect_is_eval - return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), @@ -608,8 +682,7 @@ def _reset_hidden_state_based_on_type( class AsyncCollector(Collector): """Async Collector handles async vector environment. - The arguments are exactly the same as :class:`~tianshou.data.Collector`, please - refer to :class:`~tianshou.data.Collector` for more detailed explanation. + Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation. """ def __init__( @@ -619,6 +692,12 @@ def __init__( buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: + if not env.is_async: + # TODO: raise an exception? + log.error( + f"Please use {Collector.__name__} if not using async venv. " + f"Env class: {env.__class__.__name__}", + ) # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( @@ -644,7 +723,7 @@ def reset( reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached @@ -654,7 +733,7 @@ def reset( reset function. 
Defaults to None (extra keyword arguments) """ # This sets the _pre_collect attrs - super().reset( + result = super().reset( reset_buffer=reset_buffer, reset_stats=reset_stats, gym_reset_kwargs=gym_reset_kwargs, @@ -667,78 +746,29 @@ def reset( self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) self._current_action_in_all_envs_EA = np.empty(self.env_num) self._current_policy_in_all_envs_E = None + return result - def collect( + @override + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + # we need to step through the envs and wait until they are ready to be able to interact with them + if self.env.waiting_id: + self.env.step(None, id=self.env.waiting_id) + return super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + + @override + def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, no_grad: bool = True, - reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - is_eval: bool = False, ) -> CollectStats: - """Collect a specified number of steps or episodes with async env setting. - - This function does not collect an exact number of transitions specified by n_step or - n_episode. Instead, to support the asynchronous setting, it may collect more transitions - than requested by n_step or n_episode and save them into the buffer. - - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy_R for collecting data. Default - to False. - :param render: the sleep time between rendering consecutive frames. - Default to None (no rendering). - :param no_grad: whether to retain gradient in policy_R.forward(). Default to - True (no gradient retaining). - :param reset_before_collect: whether to reset the environment before - collecting data. It has only an effect if n_episode is not None, i.e. - if one wants to collect a fixed number of episodes. - (The collector needs the initial obs and info to function properly.) - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) - :param is_eval: whether to collect data in evaluation mode. - - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. - - :return: A dataclass object - """ - # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy - # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on - # policy.deterministic_eval) - self.policy.eval() - pre_collect_is_eval = self.policy.is_eval - self.policy.is_eval = is_eval - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} - - # collect at least n_step or n_episode - if n_step is not None: - assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." 
- ) - assert n_step > 0 - elif n_episode is not None: - assert n_episode > 0 - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) - - if reset_before_collect: - # first we need to step all envs to be able to interact with them - if self.env.waiting_id: - self.env.step(None, id=self.env.waiting_id) - self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - start_time = time.time() step_count = 0 @@ -868,12 +898,12 @@ def collect( num_collected_episodes += num_episodes_done_this_iter # preparing for the next iteration - # todo do we need the copy stuff (tests pass also without) # todo seem we can get rid of this last_sth stuff altogether last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) - last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]) # type: ignore[index] - + last_hidden_state_RH = copy( + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R], # type: ignore[index] + ) if num_episodes_done_this_iter: env_ind_local_D = np.where(done_R)[0] env_ind_global_D = ready_env_ids_R[env_ind_local_D] @@ -883,6 +913,7 @@ def collect( # now we copy obs_next_RO to obs, but since there might be # finished episodes, we have to reset finished envs first. + gym_reset_kwargs = gym_reset_kwargs or {} obs_reset_DO, info_reset_D = self.env.reset( env_id=env_ind_global_D, **gym_reset_kwargs, @@ -923,9 +954,6 @@ def collect( # persist for future collect iterations self._ready_env_ids_R = ready_env_ids_R - # set the policy back to pre collect mode - self.policy.is_eval = pre_collect_is_eval - return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 71b8159a0..df87aca34 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -337,7 +337,7 @@ def _watch_agent( ) -> None: collector = Collector(policy, env) collector.reset() - result = collector.collect(n_episode=num_episodes, render=render, is_eval=True) + result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info( From 69f07a8f12bce7003aad1e897476f422763da247 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 17:37:12 +0200 Subject: [PATCH 08/32] Tests: fixed typing issues by declaring union types and no longer reusing var names --- test/base/test_buffer.py | 50 ++++++++++++++++--------------- test/base/test_collector.py | 8 ++--- test/base/test_utils.py | 6 ++-- test/offline/test_discrete_bcq.py | 3 +- test/offline/test_discrete_cql.py | 3 +- test/offline/test_discrete_crr.py | 3 +- tianshou/utils/torch_utils.py | 5 ++-- 7 files changed, 42 insertions(+), 36 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 40f450c8f..1b3593db3 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -2,6 +2,7 @@ import pickle import tempfile from test.base.env import MoveToRightEnv, MyGoalEnv +from typing import cast import h5py import numpy as np @@ -381,25 +382,25 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf[tmp_indices].obs - obs_next = buf[tmp_indices].obs_next - rew = buf[tmp_indices].rew - g = 
obs.desired_goal.reshape(sample_sz, -1)[:, 0] - ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf[tmp_indices].obs) + obs_next_buf = cast(Batch, buf[tmp_indices].obs_next) + rew_in_buf = buf[tmp_indices].rew + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == g[0]) assert np.all(g_next == g_next[0]) - assert np.all(rew == (ag_next == g).astype(np.float32)) + assert np.all(rew_in_buf == (ag_next == g).astype(np.float32)) tmp_indices = buf.next(tmp_indices) # Check that goals are correctly restored buf._restore_cache() tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf[tmp_indices].obs - obs_next = buf[tmp_indices].obs_next - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf[tmp_indices].obs) + obs_next_buf = cast(Batch, buf[tmp_indices].obs_next) + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == env_size) assert np.all(g_next == g_next[0]) assert np.all(g == g[0]) @@ -411,24 +412,24 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf2[tmp_indices].obs - obs_next = buf2[tmp_indices].obs_next - rew = buf2[tmp_indices].rew - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf2[tmp_indices].obs) + obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next) + rew_buf = buf2[tmp_indices].rew + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == g_next) - assert np.all(rew == (ag_next == g).astype(np.float32)) + assert np.all(rew_buf == (ag_next == g).astype(np.float32)) tmp_indices = buf2.next(tmp_indices) # Check that goals are correctly restored buf2._restore_cache() tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = buf2[tmp_indices].obs - obs_next = buf2[tmp_indices].obs_next - g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + obs_in_buf = cast(Batch, buf2[tmp_indices].obs) + obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next) + g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == env_size) assert np.all(g_next == g_next[0]) assert np.all(g == g[0]) @@ -442,7 +443,6 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) buf._index = 5 # shifted start index buf.future_p = 1 - action_list = [1] * 10 for ep_len in [5, 10]: obs, _ = env.reset() for i in range(ep_len): @@ -1030,6 +1030,7 @@ def test_multibuf_stack() -> None: size, ) obs, info = env.reset(options={"state": 1}) + obs = cast(np.ndarray, obs) for i in range(18): obs_next, rew, terminated, truncated, info = env.step(1) done = terminated or 
truncated @@ -1057,7 +1058,8 @@ def test_multibuf_stack() -> None: assert np.all(buf4.truncated == buf5.truncated) obs = obs_next if done: - obs, info = env.reset(options={"state": 1}) + # obs is an array, but the env is malformed, so we can't properly type it + obs, info = env.reset(options={"state": 1}) # type: ignore[assignment] # check the `add` order is correct assert np.allclose( buf4.obs.reshape(-1), diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 08483bb9b..d03a54df7 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -218,11 +218,11 @@ def test_collector() -> None: c_dummy_venv_4_envs.collect(n_episode=4, random=True) # test corner case - with pytest.raises(TypeError): + with pytest.raises(ValueError): Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) - with pytest.raises(TypeError): + with pytest.raises(ValueError): Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) - with pytest.raises(TypeError): + with pytest.raises(ValueError): c_dummy_venv_4_envs.collect() def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]: @@ -260,7 +260,7 @@ def test_collect_without_argument_gives_error( async_collector_and_env_lens: tuple[AsyncCollector, list[int]], ) -> None: c1, env_lens = async_collector_and_env_lens - with pytest.raises(TypeError): + with pytest.raises(ValueError): c1.collect() def test_collect_one_episode_async( diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 1d8e37c58..ac3b2fa4d 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -134,7 +134,7 @@ def test_lr_schedulers() -> None: ) -def test_in_eval_mode(): +def test_in_eval_mode() -> None: module = nn.Linear(3, 4) module.train() with in_eval_mode(module): @@ -142,9 +142,9 @@ def test_in_eval_mode(): assert module.training -def test_in_train_mode(): +def test_in_train_mode() -> None: module = nn.Linear(3, 4) module.eval() with in_train_mode(module): assert module.training - assert not module.training \ No newline at end of file + assert not module.training diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 589809985..6e8e8784a 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -8,7 +8,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteBCQPolicy from tianshou.trainer import OfflineTrainer @@ -96,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: imitation_logits_penalty=args.imitation_logits_penalty, ) # buffer + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index b62fdcc26..f2a60e00c 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -8,7 +8,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, 
DiscreteCQLPolicy from tianshou.trainer import OfflineTrainer @@ -85,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: min_q_weight=args.min_q_weight, ).to(args.device) # buffer + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index eee872556..bc54dd9d0 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -8,7 +8,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteCRRPolicy from tianshou.trainer import OfflineTrainer @@ -89,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 1676b524d..2fb70dad2 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -1,10 +1,11 @@ +from collections.abc import Iterator from contextlib import contextmanager from torch import nn @contextmanager -def in_eval_mode(module: nn.Module) -> None: +def in_eval_mode(module: nn.Module) -> Iterator[None]: """Temporarily switch to evaluation mode.""" train = module.training try: @@ -15,7 +16,7 @@ def in_eval_mode(module: nn.Module) -> None: @contextmanager -def in_train_mode(module: nn.Module) -> None: +def in_train_mode(module: nn.Module) -> Iterator[None]: """Temporarily switch to training mode.""" train = module.training try: From 2eaf1f37c216553b1347240e3a3fe01e1f9d0623 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 17:53:27 +0200 Subject: [PATCH 09/32] Use the new BaseCollector interface for annotations --- tianshou/data/__init__.py | 3 ++- tianshou/highlevel/agent.py | 3 ++- tianshou/highlevel/world.py | 6 +++--- tianshou/trainer/base.py | 7 +++---- tianshou/trainer/utils.py | 8 ++++---- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 623079890..d8c5410ea 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -24,7 +24,7 @@ SequenceSummaryStats, TimingStats, ) -from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase +from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector __all__ = [ "Batch", @@ -50,4 +50,5 @@ "InfoStats", "SequenceSummaryStats", "TimingStats", + "BaseCollector", ] diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index f71a7f981..c1313262e 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -6,6 +6,7 @@ import gymnasium from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data.collector import BaseCollector from tianshou.highlevel.config import SamplingConfig from 
tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ( @@ -94,7 +95,7 @@ def create_train_test_collector( policy: BasePolicy, envs: Environments, reset_collectors: bool = True, - ) -> tuple[Collector, Collector]: + ) -> tuple[BaseCollector, BaseCollector]: """:param policy: :param envs: :param reset_collectors: Whether to reset the collectors before returning them. diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 1a8d64872..c32ef9cbc 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from tianshou.data import Collector + from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.policy import BasePolicy @@ -16,8 +16,8 @@ class World: envs: "Environments" policy: "BasePolicy" - train_collector: "Collector" - test_collector: "Collector" + train_collector: "BaseCollector" + test_collector: "BaseCollector" logger: "TLogger" persist_directory: str restore_directory: str | None diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 3c8a27e21..f657f633a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -10,14 +10,13 @@ from tianshou.data import ( AsyncCollector, - Collector, CollectStats, EpochStats, InfoStats, ReplayBuffer, SequenceSummaryStats, ) -from tianshou.data.collector import CollectStatsBase +from tianshou.data.collector import BaseCollector, CollectStatsBase from tianshou.policy import BasePolicy from tianshou.policy.base import TrainingStats from tianshou.trainer.utils import gather_info, test_episode @@ -152,8 +151,8 @@ def __init__( policy: BasePolicy, max_epoch: int, batch_size: int | None, - train_collector: Collector | None = None, - test_collector: Collector | None = None, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index b790fddac..0c2bf1896 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -5,17 +5,17 @@ import numpy as np from tianshou.data import ( - Collector, CollectStats, InfoStats, SequenceSummaryStats, TimingStats, ) +from tianshou.data.collector import BaseCollector from tianshou.utils import BaseLogger def test_episode( - collector: Collector, + collector: BaseCollector, test_fn: Callable[[int, int | None], None] | None, epoch: int, n_episode: int, @@ -44,8 +44,8 @@ def gather_info( gradient_step: int, best_reward: float, best_reward_std: float, - train_collector: Collector | None = None, - test_collector: Collector | None = None, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, ) -> InfoStats: """A simple wrapper of gathering information from collectors. From c28508b3be6bac7881f9769b167c31f58d2edec6 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 17:53:34 +0200 Subject: [PATCH 10/32] Changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ce216139..2b4cc551d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ - New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). Launchers for parallelization currently in alpha state. 
#1074 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074 +- Base class for collectors: `BaseCollector` #1122 +- Collectors can now explicitly specify whether to use the policy in training or evaluation mode. #1122 +- New util context managers `in_eval_mode` and `in_train_mode` for torch modules. #1122 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 @@ -35,6 +38,8 @@ instead of just `nn.Module`. #1032 - tests and examples are covered by `mypy`. #1077 - `NetBase` is more used, stricter typing by making it generic. #1077 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102 +- Removed all `if __name__ == "__main__":` blocks from tests. #1123 +- Improved typing issues in tests with buffer and collector. #1123 ### Breaking Changes From 6aa33b1bfe4fc99b1c623f86a12d2ef0ebf1320b Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 17:54:14 +0200 Subject: [PATCH 11/32] Formatting --- tianshou/data/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index d8c5410ea..c84c2ec7d 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -24,7 +24,13 @@ SequenceSummaryStats, TimingStats, ) -from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector +from tianshou.data.collector import ( + Collector, + AsyncCollector, + CollectStats, + CollectStatsBase, + BaseCollector, +) __all__ = [ "Batch", From e2e8a699eaf5de4d96b5b1b0866192e2c9a2641a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 18:11:23 +0200 Subject: [PATCH 12/32] Changelog [skip-ci] --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b4cc551d..9a8be12c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Launchers for parallelization currently in alpha state. #1074 - Base class for collectors: `BaseCollector` #1122 - Collectors can now explicitly specify whether to use the policy in training or evaluation mode. #1122 - New util context managers `in_eval_mode` and `in_train_mode` for torch modules. #1122 +- `reset` of `Collectors` now returns `obs` and `info`. #1122 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 From 45922712d94f2a39ce1e1becf90728e10b688d13 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 18:14:20 +0200 Subject: [PATCH 13/32] Dosctring add return [skip-ci] --- tianshou/data/collector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0b9fd6dd7..10cf663d2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -215,6 +215,7 @@ def reset( :param reset_stats: if true, reset the statistics attached to the collector. :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) + :return: The initial observation and info from the environment. """ obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: @@ -731,6 +732,7 @@ def reset( :param reset_stats: if true, reset the statistics attached to the collector. 
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) + :return: The initial observation and info from the environment. """ # This sets the _pre_collect attrs result = super().reset( From a2b9d7c7d839cb5172a58a9e91a040da8f4d4dc0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 18:31:02 +0200 Subject: [PATCH 14/32] Changelog [skip-ci] --- CHANGELOG.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a8be12c2..9dd845951 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,8 +39,8 @@ instead of just `nn.Module`. #1032 - tests and examples are covered by `mypy`. #1077 - `NetBase` is more used, stricter typing by making it generic. #1077 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102 -- Removed all `if __name__ == "__main__":` blocks from tests. #1123 -- Improved typing issues in tests with buffer and collector. #1123 +- Removed all `if __name__ == "__main__":` blocks from tests. #1122 +- Improved typing issues in tests with buffer and collector. #1122 ### Breaking Changes @@ -57,6 +57,7 @@ continuous and discrete cases. #1032 - `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074 - `BaseLogger.prepare_dict_for_logging` is now abstract. #1074 - Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074 +- Removed deprecations of `0.5.1` (will likely not affect anyone) and the unused `warnings` module. #1122 ### Tests From 4f16494609a33de67bb6454aaf3e64e523d0de60 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 2 May 2024 11:51:08 +0200 Subject: [PATCH 15/32] Set torch train mode in BasePolicy.update instead of in each .learn implementation, as this is less prone to errors --- tianshou/policy/base.py | 4 +++- tianshou/policy/modelfree/a2c.py | 3 --- tianshou/policy/modelfree/bdq.py | 3 --- tianshou/policy/modelfree/c51.py | 2 -- tianshou/policy/modelfree/discrete_sac.py | 3 --- tianshou/policy/modelfree/dqn.py | 2 -- tianshou/policy/modelfree/fqf.py | 2 -- tianshou/policy/modelfree/iqn.py | 2 -- tianshou/policy/modelfree/pg.py | 3 --- tianshou/policy/modelfree/ppo.py | 3 --- tianshou/policy/modelfree/qrdqn.py | 2 -- tianshou/policy/modelfree/sac.py | 3 --- 12 files changed, 3 insertions(+), 29 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 450bf5228..e566d0c33 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -25,6 +25,7 @@ ) from tianshou.utils import MultipleLRSchedulers from tianshou.utils.print import DataclassPPrintMixin +from tianshou.utils.torch_utils import in_train_mode logger = logging.getLogger(__name__) @@ -513,7 +514,8 @@ def update( batch, indices = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indices) - training_stat = self.learn(batch, **kwargs) + with in_train_mode(self): + training_stat = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: self.lr_scheduler.step() diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index f055c4122..d41ccb463 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -165,9 +165,6 @@ def learn( # type: ignore *args: Any, **kwargs: Any, ) -> TA2CTrainingStats: - # set policy in train mode - self.train() - losses, actor_losses, vf_losses, ent_losses = [], [], [], [] 
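A minimal, self-contained sketch of the mode-handling pattern this change relies on (illustrative only; ToyPolicy and its methods are hypothetical stand-ins, and the context manager simply mirrors the in_train_mode helper from tianshou/utils/torch_utils.py added earlier in this series). It shows why wrapping the learn() call once inside update() makes the per-algorithm self.train() calls redundant: learn() always runs with train-mode semantics, and the previous mode is restored afterwards.

from collections.abc import Iterator
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch `module` to training mode, restoring the previous mode on exit."""
    was_training = module.training
    try:
        module.train()
        yield
    finally:
        module.train(was_training)


class ToyPolicy(nn.Module):
    """Hypothetical stand-in for BasePolicy, reduced to the train/eval mode question."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 2))

    def learn(self, batch: torch.Tensor) -> torch.Tensor:
        assert self.training  # guaranteed by the caller, no self.train() needed here
        return self.net(batch).mean()

    def update(self, batch: torch.Tensor) -> torch.Tensor:
        with in_train_mode(self):  # analogous to the wrapping in BasePolicy.update
            return self.learn(batch)


policy = ToyPolicy()
policy.eval()                            # e.g. a collector left the policy in eval mode
loss = policy.update(torch.randn(8, 4))  # learn() still sees train-mode semantics (dropout active)
assert not policy.training               # the previous (eval) mode is restored afterwards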
split_batch_size = batch_size or -1 for _ in range(repeat): diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index 80c17bef7..d7196a92b 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -163,9 +163,6 @@ def forward( return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: - # set policy in train mode - self.train() - if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index d406dda56..5bfdba0c1 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -117,8 +117,6 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: return target_dist.sum(-1) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: - # set policy in train mode - self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index d236e442e..7e731b1c3 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -124,9 +124,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q.sum(dim=-1) + self.alpha * dist.entropy() def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore - # set policy in train mode - self.train() - weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index d2e7910bd..e0ada0733 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -210,8 +210,6 @@ def forward( return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: - # set policy in train mode - self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index c4a1a2d01..9c87f9cac 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -153,8 +153,6 @@ def forward( # type: ignore return cast(FQFBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: - # set policy in train mode - self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() weight = batch.pop("weight", 1.0) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index f868ce41e..75d76a2dd 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -131,8 +131,6 @@ def forward( return cast(QuantileRegressionBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: - # set policy in train mode - self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index b86540da4..4792db8f5 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -211,9 +211,6 @@ def learn( # 
type: ignore *args: Any, **kwargs: Any, ) -> TPGTrainingStats: - # set policy in train mode - self.train() - losses = [] split_batch_size = batch_size or -1 for _ in range(repeat): diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 298711475..196cd72e4 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -151,9 +151,6 @@ def learn( # type: ignore *args: Any, **kwargs: Any, ) -> TPPOTrainingStats: - # set policy in train mode - self.train() - losses, clip_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for step in range(repeat): diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 9f3e1626c..71c36de0c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -105,8 +105,6 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc return super().compute_q_value(logits.mean(2), mask) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats: - # set policy in train mode - self.train() if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index f9ddf7cbc..3dbea7508 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -213,9 +213,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: ) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore - # set policy in train mode - self.train() - # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) From ca4dad113980e96c8632b1cce8153f874c8950e4 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 2 May 2024 18:06:01 +0200 Subject: [PATCH 16/32] BaseTrainer: Refactoring New method training_step, which * collects training data (method _collect_training_data) * performs "test in train" (method _test_in_train) * performs policy update The old method named train_step performed only the first two points and was now split into two separate methods --- tianshou/trainer/base.py | 90 +++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 33 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index f657f633a..825c80ceb 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -4,6 +4,7 @@ from collections import defaultdict, deque from collections.abc import Callable from dataclasses import asdict +from typing import Optional, Tuple import numpy as np import tqdm @@ -303,8 +304,10 @@ def __next__(self) -> EpochStats: with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: train_stat: CollectStatsBase while t.n < t.total and not self.stop_fn_flag: - if self.train_collector is not None: - train_stat, self.stop_fn_flag = self.train_step() + + train_stat, update_stat, self.stop_fn_flag = self.training_step() + + if isinstance(train_stat, CollectStats): pbar_data_dict = { "env_step": str(self.env_step), "rew": f"{self.last_rew:.2f}", @@ -313,23 +316,17 @@ def __next__(self) -> EpochStats: "n/st": str(train_stat.n_collected_steps), } t.update(train_stat.n_collected_steps) - if self.stop_fn_flag: - t.set_postfix(**pbar_data_dict) - break else: pbar_data_dict = {} - assert self.buffer, "No 
train_collector or buffer specified" - train_stat = CollectStatsBase( - n_collected_episodes=len(self.buffer), - ) t.update() - update_stat = self.policy_update_fn(train_stat) pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) pbar_data_dict["gradient_step"] = str(self._gradient_step) - t.set_postfix(**pbar_data_dict) + if self.stop_fn_flag: + break + if t.n <= t.total and not self.stop_fn_flag: t.update() @@ -410,45 +407,71 @@ def test_step(self) -> tuple[CollectStats, bool]: return test_stat, stop_fn_flag - def train_step(self) -> tuple[CollectStats, bool]: - """Perform one training step. + def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]: + should_stop_training = False - If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. - Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return - on it. - Finally, if the latter is also True, will set should_stop_training to True. + if self.train_collector is not None: + collect_stats = self._collect_training_data() + should_stop_training = self._test_in_train(collect_stats) + else: + collect_stats = CollectStatsBase( + n_collected_episodes=len(self.buffer), + ) + + if not should_stop_training: + training_stats = self.policy_update_fn(collect_stats) + else: + training_stats = None + + return collect_stats, training_stats, should_stop_training - :return: A tuple of the training stats and a boolean indicating whether to stop training. + def _collect_training_data(self) -> CollectStats: + """Performs training data collection + + :return: the data collection stats """ assert self.episode_per_test is not None assert self.train_collector is not None - should_stop_training = False if self.train_fn: self.train_fn(self.epoch, self.env_step) - result = self.train_collector.collect( + collect_stats = self.train_collector.collect( n_step=self.step_per_collect, n_episode=self.episode_per_collect, ) - self.env_step += result.n_collected_steps + self.env_step += collect_stats.n_collected_steps - if result.n_collected_episodes > 0: - assert result.returns_stat is not None # for mypy - assert result.lens_stat is not None # for mypy - self.last_rew = result.returns_stat.mean - self.last_len = result.lens_stat.mean + if collect_stats.n_collected_episodes > 0: + assert collect_stats.returns_stat is not None # for mypy + assert collect_stats.lens_stat is not None # for mypy + self.last_rew = collect_stats.returns_stat.mean + self.last_len = collect_stats.lens_stat.mean if self.reward_metric: # TODO: move inside collector - rew = self.reward_metric(result.returns) - result.returns = rew - result.returns_stat = SequenceSummaryStats.from_sequence(rew) + rew = self.reward_metric(collect_stats.returns) + collect_stats.returns = rew + collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) - self.logger.log_train_data(asdict(result), self.env_step) + self.logger.log_train_data(asdict(collect_stats), self.env_step) + + return collect_stats + + def _test_in_train(self, collect_stats: CollectStats) -> bool: + """ + If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. + Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return + on it. + Finally, if the latter is also True, will set should_stop_training to True. 
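As a toy rendition of the gate just described (the helper and callback names below are made up for illustration and are not tianshou APIs): a test rollout is only worth running once the training returns already satisfy stop_fn, and training only stops early, skipping the policy update, if the test returns satisfy it as well.

from collections.abc import Callable


def should_stop_training(
    train_return_mean: float,
    collect_test_return_mean: Callable[[], float],
    stop_fn: Callable[[float], bool],
) -> bool:
    if not stop_fn(train_return_mean):
        return False  # threshold not reached on training data: keep training, no test rollout
    return stop_fn(collect_test_return_mean())  # confirm on held-out test episodes


def stop_fn(mean_return: float) -> bool:
    return mean_return >= 200.0


print(should_stop_training(205.0, lambda: 198.0, stop_fn))  # False: test episodes still below threshold
print(should_stop_training(205.0, lambda: 210.0, stop_fn))  # True: stop before the next policy update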
+ + :param collect_stats: the data collection stats + :return: flag indicating whether to stop training + """ + should_stop_training = False if ( - result.n_collected_episodes > 0 + collect_stats.n_collected_episodes > 0 and self.test_in_train and self.stop_fn - and self.stop_fn(result.returns_stat.mean) # type: ignore + and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore ): assert self.test_collector is not None test_result = test_episode( @@ -464,7 +487,8 @@ def train_step(self) -> tuple[CollectStats, bool]: should_stop_training = True self.best_reward = test_result.returns_stat.mean self.best_reward_std = test_result.returns_stat.std - return result, should_stop_training + + return should_stop_training # TODO: move moving average computation and logging into its own logger # TODO: maybe think about a command line logger instead of always printing data dict From 18f236167f56faa9302082bb2a82274888196b1f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 2 May 2024 18:14:26 +0200 Subject: [PATCH 17/32] Fix invalid kwarg --- examples/discrete/discrete_dqn_hl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 35e359770..eacf4c78f 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -18,7 +18,6 @@ def main() -> None: DQNExperimentBuilder( EnvFactoryRegistered( task="CartPole-v1", - seed=0, venv_type=VectorEnvType.DUMMY, train_seed=0, test_seed=10, From ca69e79b4a5edbad6d93462d86bda5027a896528 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 2 May 2024 18:31:03 +0200 Subject: [PATCH 18/32] Change the way in which deterministic evaluation is controlled: * Remove flag `eval_mode` from Collector.collect * Replace flag `is_eval` in BasePolicy with `is_within_training_step` (negating usages) and set it appropriately in BaseTrainer --- examples/atari/atari_c51.py | 8 +-- examples/atari/atari_dqn.py | 8 +-- examples/atari/atari_fqf.py | 8 +-- examples/atari/atari_iqn.py | 8 +-- examples/atari/atari_ppo.py | 8 +-- examples/atari/atari_qrdqn.py | 8 +-- examples/atari/atari_rainbow.py | 8 +-- examples/atari/atari_sac.py | 8 +-- examples/box2d/acrobot_dualdqn.py | 6 +- examples/box2d/bipedal_bdq.py | 6 +- examples/box2d/bipedal_hardcore_sac.py | 6 +- examples/box2d/lunarlander_dqn.py | 6 +- examples/box2d/mcc_sac.py | 6 +- examples/inverse/irl_gail.py | 6 +- examples/mujoco/fetch_her_ddpg.py | 6 +- examples/mujoco/mujoco_a2c.py | 6 +- examples/mujoco/mujoco_ddpg.py | 6 +- examples/mujoco/mujoco_npg.py | 6 +- examples/mujoco/mujoco_ppo.py | 6 +- examples/mujoco/mujoco_redq.py | 6 +- examples/mujoco/mujoco_reinforce.py | 6 +- examples/mujoco/mujoco_sac.py | 6 +- examples/mujoco/mujoco_td3.py | 6 +- examples/mujoco/mujoco_trpo.py | 6 +- examples/offline/atari_bcq.py | 2 +- examples/offline/atari_cql.py | 2 +- examples/offline/atari_crr.py | 2 +- examples/offline/atari_il.py | 2 +- examples/offline/d4rl_bcq.py | 8 +-- examples/offline/d4rl_cql.py | 8 +-- examples/offline/d4rl_il.py | 8 +-- examples/offline/d4rl_td3_bc.py | 8 +-- examples/vizdoom/vizdoom_c51.py | 8 +-- examples/vizdoom/vizdoom_ppo.py | 8 +-- test/base/test_env_finite.py | 4 +- test/base/test_policy.py | 2 +- test/offline/gather_cartpole_data.py | 2 +- test/offline/test_bcq.py | 2 +- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/data/collector.py | 19 +---- tianshou/highlevel/agent.py | 5 +- 
tianshou/highlevel/experiment.py | 2 +- tianshou/policy/base.py | 14 +++- tianshou/policy/modelfree/discrete_sac.py | 2 +- tianshou/policy/modelfree/pg.py | 2 +- tianshou/policy/modelfree/redq.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/trainer/base.py | 85 +++++++++++++---------- tianshou/trainer/utils.py | 2 +- 51 files changed, 126 insertions(+), 241 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 98d1d0860..d611ab196 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -174,18 +174,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 7d60654ac..eeb9bccce 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -216,18 +216,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 185cff145..58aff46ac 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -187,18 +187,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 5216d7c4e..c6090523d 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -184,18 +184,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer 
size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 969d00aaf..dd75de7fb 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -240,18 +240,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 5b0258108..b9731316e 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -178,18 +178,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 5bb69a38b..952d35f07 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -214,18 +214,14 @@ def watch() -> None: beta=args.beta, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index be13884e9..4d01a88aa 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -227,18 +227,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: 
print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index ac2c800d6..365c073fa 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -147,11 +147,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 7c89cd276..c817831b1 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -165,11 +165,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 214721356..66e5f316d 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -209,11 +209,7 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 8280e236f..f9bbd6fa6 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -144,11 +144,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 97f714983..7617b7b43 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -155,11 +155,7 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index e30b3f651..42e5bc2c9 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -266,11 +266,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 8b3439a8f..bbf68c2fa 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -240,11 +240,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 834e0ced5..194d9b5de 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -221,11 +221,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 066c7d4c9..db90babb0 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -170,11 +170,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 40cc443df..4d8530a53 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -218,11 +218,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index eb6817ef9..7c3f268c8 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -226,11 +226,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 9101f8634..8951b03ac 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -198,11 +198,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index cbd4d3e1d..ff7e34099 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -198,11 +198,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 8d9d61ec1..af1398380 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -192,11 +192,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index d1b5e6921..6cc8bb212 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -190,11 +190,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b5a91c0b5..eefdfcc65 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -223,11 +223,7 @@ def save_best_fn(policy: BasePolicy) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index d4e985ea8..3af40cc7f 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -191,7 +191,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 69ae433ff..b2c0c8705 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -175,7 +175,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 4e0771c0a..8b6320a79 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -191,7 +191,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 928112c77..39aee31d5 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -148,7 +148,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 308a872eb..9ed18262a 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -207,7 +207,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, eval_mode=True) + collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -230,11 +230,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 9d95d29ca..90d6b159c 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -345,7 +345,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, eval_mode=True) + collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -368,11 +368,7 @@ def watch() -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index d64cfe9da..e03deed80 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -143,7 +143,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, eval_mode=True) + collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -166,11 +166,7 @@ def watch() -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 5719eba16..6b448b320 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -192,7 +192,7 @@ def watch() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, eval_mode=True) + collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -218,11 +218,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 8a52e5f5d..25ad80487 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -180,18 +180,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index adbfb0584..7fc09f690 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -242,18 +242,14 @@ def watch() -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, eval_mode=True) + result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect( - n_episode=args.test_num, - render=args.render, - eval_mode=True, - ) + result = test_collector.collect(n_episode=args.test_num, render=args.render) result.pprint_asdict() if args.watch: diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index ce8a93640..35bf3e245 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None: envs.tracker = MetricTracker() try: # TODO: why on earth 10**18? 
- test_collector.collect(n_step=10**18) + test_collector.collect(n_step=10 ** 18) except StopIteration: envs.tracker.validate() @@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None: for _ in range(3): envs.tracker = MetricTracker() try: - test_collector.collect(n_step=10**18) + test_collector.collect(n_step=10 ** 18) except StopIteration: envs.tracker.validate() diff --git a/test/base/test_policy.py b/test/base/test_policy.py index f286156ed..4d26905c3 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -64,7 +64,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: class TestPolicyBasics: def test_get_action(self, policy: PPOPolicy) -> None: - policy.is_eval = True + policy.is_within_training_step = False sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 7d6aba1b2..19ba653e5 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -167,7 +167,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) collector.reset() - collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True) + collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 20e4dd6c6..8b31c1969 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -185,7 +185,7 @@ def watch() -> None: torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, eval_mode=True) + collector.collect(n_episode=1, render=1 / 35) # trainer result = OfflineTrainer( diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 27392f740..c57522df0 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non policy, _, _ = get_agents(args) [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render, eval_mode=True) + result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index ed085225a..38de81173 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non ) policy, _, _ = get_agents(args) collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True) + collector_result = collector.collect(n_episode=1, render=args.render) collector_result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index fc46e2a8f..966c9e04c 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -230,5 +230,5 @@ def watch( policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) policy.policies[agents[args.agent_id - 
1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render, eval_mode=True) + result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 10cf663d2..cf897a59c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -260,17 +260,8 @@ def _collect( ) -> CollectStats: pass - def collect( - self, - n_step: int | None = None, - n_episode: int | None = None, - random: bool = False, - render: float | None = None, - no_grad: bool = True, - reset_before_collect: bool = False, - gym_reset_kwargs: dict[str, Any] | None = None, - eval_mode: bool = False, - ) -> CollectStats: + def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, + no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats: """Collect a specified number of steps or episodes. To ensure an unbiased sampling result with the n_episode option, this function will @@ -286,9 +277,6 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. - :param eval_mode: whether to collect data in evaluation mode. Will - set the policy to training mode otherwise. - .. note:: One and only one collection number specification is permitted, either @@ -302,8 +290,7 @@ def collect( if reset_before_collect: self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - policy_mode_context = in_eval_mode if eval_mode else in_train_mode - with policy_mode_context(self.policy): + with in_eval_mode(self.policy): # safety precaution only return self._collect( n_step=n_step, n_episode=n_episode, diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index c1313262e..fdfc4c08f 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -130,10 +130,7 @@ def create_train_test_collector( log.info( f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", ) - train_collector.collect( - n_step=self.sampling_config.start_timesteps, - random=self.sampling_config.start_timesteps_random, - ) + train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random) return train_collector, test_collector def set_policy_wrapper_factory( diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index df87aca34..e7ccc9f8a 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -337,7 +337,7 @@ def _watch_agent( ) -> None: collector = Collector(policy, env) collector.reset() - result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True) + result = collector.collect(n_episode=num_episodes, render=render) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info( diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index e566d0c33..498a33352 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -226,8 +226,18 @@ def __init__( self.action_scaling = action_scaling self.action_bound_method = action_bound_method self.lr_scheduler = lr_scheduler - # whether the policy 
is in evaluation mode - self.is_eval = False # TODO: remove in favor of kwarg in compute_action/forward? + self.is_within_training_step = False + """ + flag indicating whether we are currently within a training step, which encompasses data collection + for training and the policy update (gradient steps). + + It can be used, for example, to control whether a flag controlling deterministic evaluation should + indeed be applied, because within a training step, we typically always want to apply stochastic evaluation + (even if such a flag is enabled). + + This flag should normally remain False and should be set to True only by the algorithm which performs + training steps. + """ self._compile() @property diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 7e731b1c3..8c7942858 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -107,7 +107,7 @@ def forward( # type: ignore ) -> Batch: logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Categorical(logits=logits_BA) - act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample() + act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample() return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 4792db8f5..3ef82be56 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -197,7 +197,7 @@ def forward( # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked dist = self.dist_fn(action_dist_input_BD) - act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample() + act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample() # act is of dimension BA in continuous case and of dimension B in discrete result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) return cast(DistBatchProtocol, result) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index a216cf9fe..25f299733 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -153,7 +153,7 @@ def forward( # type: ignore ) -> Batch: (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Independent(Normal(loc_B, scale_B), 1) - if self.deterministic_eval and self.is_eval: + if self.deterministic_eval and not self.is_within_training_step: act_B = dist.mode else: act_B = dist.rsample() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 3dbea7508..a5a05c0fd 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -175,7 +175,7 @@ def forward( # type: ignore ) -> DistLogProbBatchProtocol: (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) - if self.deterministic_eval and self.is_eval: + if self.deterministic_eval and not self.is_within_training_step: act_B = dist.mode else: act_B = dist.rsample() diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 825c80ceb..213da7b04 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections import defaultdict, 
deque from collections.abc import Callable +from contextlib import contextmanager from dataclasses import asdict from typing import Optional, Tuple @@ -407,23 +408,34 @@ def test_step(self) -> tuple[CollectStats, bool]: return test_stat, stop_fn_flag + @contextmanager + def _is_within_training_step_enabled(self, is_within_training_step: bool): + old_value = self.policy.is_within_training_step + try: + self.policy.is_within_training_step = is_within_training_step + yield + finally: + self.policy.is_within_training_step = old_value + def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]: - should_stop_training = False + with self._is_within_training_step_enabled(True): - if self.train_collector is not None: - collect_stats = self._collect_training_data() - should_stop_training = self._test_in_train(collect_stats) - else: - collect_stats = CollectStatsBase( - n_collected_episodes=len(self.buffer), - ) + should_stop_training = False + + if self.train_collector is not None: + collect_stats = self._collect_training_data() + should_stop_training = self._test_in_train(collect_stats) + else: + collect_stats = CollectStatsBase( + n_collected_episodes=len(self.buffer), + ) - if not should_stop_training: - training_stats = self.policy_update_fn(collect_stats) - else: - training_stats = None + if not should_stop_training: + training_stats = self.policy_update_fn(collect_stats) + else: + training_stats = None - return collect_stats, training_stats, should_stop_training + return collect_stats, training_stats, should_stop_training def _collect_training_data(self) -> CollectStats: """Performs training data collection @@ -434,10 +446,7 @@ def _collect_training_data(self) -> CollectStats: assert self.train_collector is not None if self.train_fn: self.train_fn(self.epoch, self.env_step) - collect_stats = self.train_collector.collect( - n_step=self.step_per_collect, - n_episode=self.episode_per_collect, - ) + collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect) self.env_step += collect_stats.n_collected_steps @@ -467,26 +476,28 @@ def _test_in_train(self, collect_stats: CollectStats) -> bool: """ should_stop_training = False - if ( - collect_stats.n_collected_episodes > 0 - and self.test_in_train - and self.stop_fn - and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore - ): - assert self.test_collector is not None - test_result = test_episode( - self.test_collector, - self.test_fn, - self.epoch, - self.episode_per_test, - self.logger, - self.env_step, - ) - assert test_result.returns_stat is not None # for mypy - if self.stop_fn(test_result.returns_stat.mean): - should_stop_training = True - self.best_reward = test_result.returns_stat.mean - self.best_reward_std = test_result.returns_stat.std + # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics + with self._is_within_training_step_enabled(False): + if ( + collect_stats.n_collected_episodes > 0 + and self.test_in_train + and self.stop_fn + and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore + ): + assert self.test_collector is not None + test_result = test_episode( + self.test_collector, + self.test_fn, + self.epoch, + self.episode_per_test, + self.logger, + self.env_step, + ) + assert test_result.returns_stat is not None # for mypy + if self.stop_fn(test_result.returns_stat.mean): + should_stop_training = True + self.best_reward = test_result.returns_stat.mean + self.best_reward_std = 
test_result.returns_stat.std return should_stop_training diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 0c2bf1896..767e76dab 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -27,7 +27,7 @@ def test_episode( collector.reset(reset_stats=False) if test_fn: test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode, eval_mode=True) + result = collector.collect(n_episode=n_episode) if reward_metric: # TODO: move into collector rew = reward_metric(result.returns) result.returns = rew From c35be8d07ef738c37e6c246f70ffc950339a4afe Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 2 May 2024 18:47:42 +0200 Subject: [PATCH 19/32] Establish backward compatibility by implementing __setstate__ --- tianshou/policy/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 498a33352..2703de7ea 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict import gymnasium as gym import numpy as np @@ -240,6 +240,12 @@ def __init__( """ self._compile() + def __setstate__(self, state: Dict[str, Any]) -> None: + # TODO Use setstate function once merged + if "is_within_training_step" not in state: + state["is_within_training_step"] = False + self.__dict__ = state + @property def action_type(self) -> Literal["discrete", "continuous"]: return self._action_type From 6927eadaa79fa89948cca164c1e8a4add09ff36d Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 15:14:59 +0200 Subject: [PATCH 20/32] BatchPolicy: check that `self.is_within_training_step` is True on update --- tianshou/policy/base.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 2703de7ea..8753af66e 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast import gymnasium as gym import numpy as np @@ -214,13 +214,13 @@ def __init__( super().__init__() self.observation_space = observation_space self.action_space = action_space - self._action_type: Literal["discrete", "continuous"] if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary): - self._action_type = "discrete" + action_type = "discrete" elif isinstance(action_space, Box): - self._action_type = "continuous" + action_type = "continuous" else: raise ValueError(f"Unsupported action space: {action_space}.") + self._action_type = cast(Literal["discrete", "continuous"], action_type) self.agent_id = 0 self.updating = False self.action_scaling = action_scaling @@ -228,19 +228,22 @@ def __init__( self.lr_scheduler = lr_scheduler self.is_within_training_step = False """ - flag indicating whether we are currently within a training step, which encompasses data collection - for training and the policy update (gradient steps). 
- - It can be used, for example, to control whether a flag controlling deterministic evaluation should + flag indicating whether we are currently within a training step, + which encompasses data collection for training (in online RL algorithms) + and the policy update (gradient steps). + + It can be used, for example, to control whether a flag controlling deterministic evaluation should indeed be applied, because within a training step, we typically always want to apply stochastic evaluation - (even if such a flag is enabled). - + (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC + based algorithms). + This flag should normally remain False and should be set to True only by the algorithm which performs - training steps. + training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer, + the user should ensure that this flag is set correctly before calling update or learn. """ self._compile() - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: # TODO Use setstate function once merged if "is_within_training_step" not in state: state["is_within_training_step"] = False @@ -524,6 +527,14 @@ def update( """ # TODO: when does this happen? # -> this happens never in practice as update is either called with a collector buffer or an assert before + + if not self.is_within_training_step: + raise RuntimeError( + f"update() was called outside of a training step as signalled by {self.is_within_training_step=} " + f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " + f"flag yourself.", + ) + if buffer is None: return TrainingStats() # type: ignore[return-value] start_time = time.time() From f876198870f8c8554bd15f1f5c928155eaaac584 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 15:16:16 +0200 Subject: [PATCH 21/32] Formatting --- test/base/test_env_finite.py | 4 ++-- tianshou/data/collector.py | 14 +++++++++++--- tianshou/highlevel/agent.py | 5 ++++- tianshou/policy/modelfree/discrete_sac.py | 6 +++++- tianshou/policy/modelfree/pg.py | 6 +++++- tianshou/trainer/base.py | 15 +++++++-------- 6 files changed, 34 insertions(+), 16 deletions(-) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 35bf3e245..ce8a93640 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None: envs.tracker = MetricTracker() try: # TODO: why on earth 10**18? 
- test_collector.collect(n_step=10 ** 18) + test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() @@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None: for _ in range(3): envs.tracker = MetricTracker() try: - test_collector.collect(n_step=10 ** 18) + test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index cf897a59c..133667f82 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -27,7 +27,7 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin -from tianshou.utils.torch_utils import in_eval_mode, in_train_mode +from tianshou.utils.torch_utils import in_eval_mode log = logging.getLogger(__name__) @@ -260,8 +260,16 @@ def _collect( ) -> CollectStats: pass - def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats: + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: """Collect a specified number of steps or episodes. To ensure an unbiased sampling result with the n_episode option, this function will diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index fdfc4c08f..c1313262e 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -130,7 +130,10 @@ def create_train_test_collector( log.info( f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", ) - train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random) + train_collector.collect( + n_step=self.sampling_config.start_timesteps, + random=self.sampling_config.start_timesteps_random, + ) return train_collector, test_collector def set_policy_wrapper_factory( diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 8c7942858..d1ce28da9 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -107,7 +107,11 @@ def forward( # type: ignore ) -> Batch: logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) dist = Categorical(logits=logits_BA) - act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample() + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 3ef82be56..80bcff672 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -197,7 +197,11 @@ def forward( # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked dist = self.dist_fn(action_dist_input_BD) - act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample() + act_B = ( + 
dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) # act is of dimension BA in continuous case and of dimension B in discrete result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) return cast(DistBatchProtocol, result) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 213da7b04..6e463741e 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -5,7 +5,6 @@ from collections.abc import Callable from contextlib import contextmanager from dataclasses import asdict -from typing import Optional, Tuple import numpy as np import tqdm @@ -305,7 +304,6 @@ def __next__(self) -> EpochStats: with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: train_stat: CollectStatsBase while t.n < t.total and not self.stop_fn_flag: - train_stat, update_stat, self.stop_fn_flag = self.training_step() if isinstance(train_stat, CollectStats): @@ -417,9 +415,8 @@ def _is_within_training_step_enabled(self, is_within_training_step: bool): finally: self.policy.is_within_training_step = old_value - def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]: + def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: with self._is_within_training_step_enabled(True): - should_stop_training = False if self.train_collector is not None: @@ -438,7 +435,7 @@ def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool return collect_stats, training_stats, should_stop_training def _collect_training_data(self) -> CollectStats: - """Performs training data collection + """Performs training data collection. :return: the data collection stats """ @@ -446,7 +443,10 @@ def _collect_training_data(self) -> CollectStats: assert self.train_collector is not None if self.train_fn: self.train_fn(self.epoch, self.env_step) - collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect) + collect_stats = self.train_collector.collect( + n_step=self.step_per_collect, + n_episode=self.episode_per_collect, + ) self.env_step += collect_stats.n_collected_steps @@ -465,8 +465,7 @@ def _collect_training_data(self) -> CollectStats: return collect_stats def _test_in_train(self, collect_stats: CollectStats) -> bool: - """ - If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. + """If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return on it. Finally, if the latter is also True, will set should_stop_training to True. From c5d0e169b5f22336ef3b9679e356f27a3b370b58 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 15:41:20 +0200 Subject: [PATCH 22/32] Collector: removed unnecessary no-grad flag from interfaces. 
Breaking --- tianshou/data/collector.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 133667f82..cfc4b3db6 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -255,18 +255,17 @@ def _collect( n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: pass + @torch.no_grad() def collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: @@ -304,7 +303,6 @@ def collect( n_episode=n_episode, random=random, render=render, - no_grad=no_grad, gym_reset_kwargs=gym_reset_kwargs, ) @@ -398,7 +396,6 @@ def _compute_action_policy_hidden( self, random: bool, ready_env_ids_R: np.ndarray, - use_grad: bool, last_obs_RO: np.ndarray, last_info_R: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, @@ -420,11 +417,10 @@ def _compute_action_policy_hidden( info_batch = _HACKY_create_info_batch(last_info_R) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - with torch.set_grad_enabled(use_grad): - act_batch_RA = self.policy( - obs_batch_R, - last_hidden_state_RH, - ) + act_batch_RA = self.policy( + obs_batch_R, + last_hidden_state_RH, + ) act_RA = to_numpy(act_batch_RA.act) if self.exploration_noise: @@ -454,7 +450,6 @@ def _collect( n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: # TODO: can't do it init since AsyncCollector is currently a subclass of Collector @@ -469,8 +464,6 @@ def _collect( elif n_episode is not None: ready_env_ids_R = np.arange(min(self.env_num, n_episode)) - use_grad = not no_grad - start_time = time.time() if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: raise ValueError( @@ -513,7 +506,6 @@ def _collect( ) = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, - use_grad=use_grad, last_obs_RO=last_obs_RO, last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, @@ -762,10 +754,8 @@ def _collect( n_episode: int | None = None, random: bool = False, render: float | None = None, - no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - use_grad = not no_grad start_time = time.time() step_count = 0 @@ -823,7 +813,6 @@ def _collect( ) = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, - use_grad=use_grad, last_obs_RO=last_obs_RO, last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, From 26a6cca76e8c10353b898bc60ba9500df6613d5f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 15:56:06 +0200 Subject: [PATCH 23/32] Improved docstrings, added asserts to make mypy happy --- tianshou/data/collector.py | 6 ------ tianshou/trainer/base.py | 14 +++++++++----- tianshou/trainer/utils.py | 20 +++----------------- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index cfc4b3db6..b498f45a7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -279,7 +279,6 @@ def collect( :param n_episode: how many episodes you want to collect. 
:param random: whether to use random policy for collecting data. :param render: the sleep time between rendering consecutive frames. - :param no_grad: whether to retain gradient in policy.forward(). :param reset_before_collect: whether to reset the environment before collecting data. (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's @@ -343,11 +342,6 @@ class Collector(BaseCollector): # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. - - # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy - # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on - # policy.deterministic_eval) - def __init__( self, policy: BasePolicy, diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 6e463741e..ef9a154b6 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -2,7 +2,7 @@ import time from abc import ABC, abstractmethod from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Iterator from contextlib import contextmanager from dataclasses import asdict @@ -360,7 +360,10 @@ def __next__(self) -> EpochStats: self.logger.log_info_data(asdict(info_stat), self.epoch) # in case trainer is used with run(), epoch_stat will not be returned - epoch_stat: EpochStats = EpochStats( + assert ( + update_stat is not None + ), "Defined in the loop above, this shouldn't have happened and is likely a bug!" + return EpochStats( epoch=self.epoch, train_collect_stat=train_stat, test_collect_stat=test_stat, @@ -368,8 +371,6 @@ def __next__(self) -> EpochStats: info_stat=info_stat, ) - return epoch_stat - def test_step(self) -> tuple[CollectStats, bool]: """Perform one testing step.""" assert self.episode_per_test is not None @@ -407,7 +408,7 @@ def test_step(self) -> tuple[CollectStats, bool]: return test_stat, stop_fn_flag @contextmanager - def _is_within_training_step_enabled(self, is_within_training_step: bool): + def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]: old_value = self.policy.is_within_training_step try: self.policy.is_within_training_step = is_within_training_step @@ -419,10 +420,12 @@ def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: with self._is_within_training_step_enabled(True): should_stop_training = False + collect_stats: CollectStatsBase | CollectStats if self.train_collector is not None: collect_stats = self._collect_training_data() should_stop_training = self._test_in_train(collect_stats) else: + assert self.buffer is not None, "Either train_collector or buffer must be provided." 
collect_stats = CollectStatsBase( n_collected_episodes=len(self.buffer), ) @@ -484,6 +487,7 @@ def _test_in_train(self, collect_stats: CollectStats) -> bool: and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore ): assert self.test_collector is not None + assert self.episode_per_test is not None and self.episode_per_test > 0 test_result = test_episode( self.test_collector, self.test_fn, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 767e76dab..de730cee2 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -49,23 +49,9 @@ def gather_info( ) -> InfoStats: """A simple wrapper of gathering information from collectors. - :return: A dataclass object with the following members (depending on available collectors): - - * ``gradient_step`` the total number of gradient steps; - * ``best_reward`` the best reward over the test results; - * ``best_reward_std`` the standard deviation of best reward over the test results; - * ``train_step`` the total collected step of training collector; - * ``train_episode`` the total collected episode of training collector; - * ``test_step`` the total collected step of test collector; - * ``test_episode`` the total collected episode of test collector; - * ``timing`` the timing statistics, with the following members: - * ``total_time`` the total time elapsed; - * ``train_time`` the total time elapsed for learning training (collecting samples plus model update); - * ``train_time_collect`` the time for collecting transitions in the \ - training collector; - * ``train_time_update`` the time for training models; - * ``test_time`` the time for testing; - * ``update_speed`` the speed of updating (env_step per second). + :return: InfoStats object with times computed based on the `start_time` and + episode/step counts read off the collectors. No computation of + expensive statistics is done here. """ duration = max(0.0, time.time() - start_time) test_time = 0.0 From 82f425e9feb64856e4e4cb667e14e06936b93a5a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 16:01:52 +0200 Subject: [PATCH 24/32] Collector: move @override, removed docstrings from overridden methods --- tianshou/data/collector.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b498f45a7..d2ca2178c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -237,7 +237,10 @@ def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: - """Reset the environments and the initial obs, info, and hidden state of the collector.""" + """Reset the environments and the initial obs, info, and hidden state of the collector. + + :return: The initial observation and info from the (vectorized) environment. 
+ """ gym_reset_kwargs = gym_reset_kwargs or {} obs_NO, info_N = self.env.reset(**gym_reset_kwargs) # TODO: hack, wrap envpool envs such that they don't return a dict @@ -368,16 +371,17 @@ def __init__( self._is_closed = False self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + @override def close(self) -> None: super().close() self._pre_collect_obs_RO = None self._pre_collect_info_R = None + @override def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: - """Reset the environments and the initial obs, info, and hidden state of the collector.""" obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) # We assume that R = N when reset is called. # TODO: there is currently no mechanism that ensures this and it's a public method! @@ -457,6 +461,8 @@ def _collect( ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: ready_env_ids_R = np.arange(min(self.env_num, n_episode)) + else: + raise ValueError("Either n_step or n_episode should be set.") start_time = time.time() if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: @@ -645,8 +651,8 @@ def _collect( collect_speed=step_count / collect_time, ) + @staticmethod def _reset_hidden_state_based_on_type( - self, env_ind_local_D: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, ) -> None: @@ -700,21 +706,13 @@ def __init__( self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) self._current_policy_in_all_envs_E: Batch | None = None + @override def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: - """Reset the environment, statistics, and data needed to start the collection. - - :param reset_buffer: if true, reset the replay buffer attached - to the collector. - :param reset_stats: if true, reset the statistics attached to the collector. - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) - :return: The initial observation and info from the environment. - """ # This sets the _pre_collect attrs result = super().reset( reset_buffer=reset_buffer, From a8e9df31f7e0f81ca7024d4aa7be8f69d4992e92 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 22:08:22 +0200 Subject: [PATCH 25/32] Bugfix: allow for training_stat to be None instead of asserting not-None --- tianshou/data/stats.py | 5 +++-- tianshou/trainer/base.py | 29 +++++++++++++++++++++-------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index b1ce8362e..b77318602 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin): """The statistics of the last call to the training collector.""" test_collect_stat: Optional["CollectStats"] """The statistics of the last call to the test collector.""" - training_stat: "TrainingStats" - """The statistics of the last model update step.""" + training_stat: Optional["TrainingStats"] + """The statistics of the last model update step. 
+ Can be None if no model update is performed, typically in the last training iteration.""" info_stat: InfoStats """The information of the collector.""" diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index ef9a154b6..3ad80c8c8 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -360,9 +360,6 @@ def __next__(self) -> EpochStats: self.logger.log_info_data(asdict(info_stat), self.epoch) # in case trainer is used with run(), epoch_stat will not be returned - assert ( - update_stat is not None - ), "Defined in the loop above, this shouldn't have happened and is likely a bug!" return EpochStats( epoch=self.epoch, train_collect_stat=train_stat, @@ -417,13 +414,23 @@ def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Ite self.policy.is_within_training_step = old_value def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: + """Perform one training iteration. + + A training iteration includes collecting data (for online RL), determining whether to stop training, + and peforming a policy update if the training iteration should continue. + + :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. + If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. + """ with self._is_within_training_step_enabled(True): should_stop_training = False collect_stats: CollectStatsBase | CollectStats if self.train_collector is not None: collect_stats = self._collect_training_data() - should_stop_training = self._test_in_train(collect_stats) + should_stop_training = self._update_best_reward_and_return_should_stop_training( + collect_stats, + ) else: assert self.buffer is not None, "Either train_collector or buffer must be provided." collect_stats = CollectStatsBase( @@ -467,11 +474,17 @@ def _collect_training_data(self) -> CollectStats: return collect_stats - def _test_in_train(self, collect_stats: CollectStats) -> bool: - """If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. - Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return + def _update_best_reward_and_return_should_stop_training( + self, + collect_stats: CollectStats, + ) -> bool: + """If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data. + Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return on it. - Finally, if the latter is also True, will set should_stop_training to True. + Finally, if the latter is also True, will return True. + + **NOTE:** has a side effect of updating the best reward and corresponding std. 
+ :param collect_stats: the data collection stats :return: flag indicating whether to stop training From 35779696ee860684d1e0dfa229cc8f367ef5118f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 30 Apr 2024 16:12:43 +0200 Subject: [PATCH 26/32] Clean up handling of an Experiment's name (and, by extension, a run's name) --- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_iqn_hl.py | 2 +- examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_sac_hl.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- test/highlevel/test_experiment_builder.py | 4 +- tianshou/highlevel/experiment.py | 67 ++++++++++++----------- 15 files changed, 51 insertions(+), 46 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 289363e1e..aa76983be 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -104,7 +104,7 @@ def main( ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 850c0ffa4..23df1cd25 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -96,7 +96,7 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index ea45df556..10dcd0a7e 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -115,7 +115,7 @@ def main( ), ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 8b1bf2825..cf09b40ea 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -103,7 +103,7 @@ def main( ), ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 96ad8c584..bce02e9c0 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index a476245ab..db9c4e3e2 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -74,7 +74,7 @@ def main( .with_critic_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 2e437caca..ab265a87a 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -85,7 +85,7 @@ def main( 
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 601b08413..27a701b12 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -95,7 +95,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 9b4bca75b..f52372906 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_ensemble_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index a5ec65f9a..46eb64fa2 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -72,7 +72,7 @@ def main( .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 9ffa0f43c..5ca731868 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -80,7 +80,7 @@ def main( .with_common_critic_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 6adc73d26..3a32c7f42 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -85,7 +85,7 @@ def main( .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 3af69bd45..f54d4c312 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -89,7 +89,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 725d7f7b5..0ba8a7bac 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime sampling_config=sampling_config, ) experiment = builder.build() - experiment.run(override_experiment_name="test") + experiment.run(run_name="test") print(experiment) @@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment sampling_config=sampling_config, ) experiment = builder.build() - experiment.run(override_experiment_name="test") + experiment.run(run_name="test") print(experiment) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index e7ccc9f8a..5ae300509 100644 --- 
a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -6,7 +6,7 @@ from copy import copy from dataclasses import dataclass from pprint import pformat -from typing import Literal, Self +from typing import Self, Dict, Any import numpy as np import torch @@ -80,7 +80,7 @@ ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.utils import LazyLogger, logging +from tianshou.utils import LazyLogger, deprecation, logging from tianshou.utils.logging import datetime_tag from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin @@ -145,8 +145,8 @@ def __init__( env_factory: EnvFactory, agent_factory: AgentFactory, sampling_config: SamplingConfig, + name: str, logger_factory: LoggerFactory | None = None, - name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", ): if logger_factory is None: logger_factory = LoggerFactoryDefault() @@ -155,8 +155,6 @@ def __init__( self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory - if name == "DATETIME_TAG": - name = datetime_tag() self.name = name def get_seeding_info_as_str(self) -> str: @@ -205,33 +203,41 @@ def save(self, directory: str) -> None: def run( self, - override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, + run_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, + **kwargs: Dict[str, Any], ) -> ExperimentResult: """Run the experiment and return the results. - :param override_experiment_name: if not None, will adjust the current instance's `name` name attribute. - The name corresponds to the directory (within the logging - directory) where all results associated with the experiment will be saved. + :param run_name: Defines a name for this run of the experiment, which determines + the subdirectory (within the persistence base directory) where all results will be saved. + If None, the experiment's name will be used. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. - If "DATETIME_TAG" is passed, use a name containing the current date and time. This option - is useful for preventing file-name collisions if a single experiment is executed repeatedly. :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when using wandb, in particular). :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed experiment with the same name. + :param kwargs: for backward compatibility with old parameter names only :return: """ - if override_experiment_name is not None: - if override_experiment_name == "DATETIME_TAG": - override_experiment_name = datetime_tag() - self.name = override_experiment_name + # backward compatibility + _experiment_name = kwargs.pop("experiment_name", None) + if _experiment_name is not None: + run_name = _experiment_name + deprecation( + "Parameter run_name should now be used instead of experiment_name. 
" + "Support for experiment_name will be removed in the future.", + ) + assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" + + if run_name is None: + run_name = self.name # initialize persistence directory use_persistence = self.config.persistence_enabled - persistence_dir = os.path.join(self.config.persistence_base_dir, self.name) + persistence_dir = os.path.join(self.config.persistence_base_dir, run_name) if use_persistence: os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) @@ -240,7 +246,7 @@ def run( enabled=use_persistence and self.config.log_file_enabled, ): # log initial information - log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}") + log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() @@ -271,7 +277,7 @@ def run( if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, - experiment_name=self.name, + experiment_name=run_name, run_id=logger_run_id, config_dict=full_config, ) @@ -363,7 +369,7 @@ def __init__( self._optim_factory: OptimizerFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() - self._experiment_name: str = "" + self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() @contextmanager def temp_config_mutation(self) -> Iterator[Self]: @@ -466,18 +472,17 @@ def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self: self._trainer_callbacks.epoch_stop_callback = callback return self - def with_experiment_name( + def with_name( self, - experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", + name: str, ) -> Self: """Sets the name of the experiment. - :param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used. 
+ :param name: the name to use for this experiment, which, when the experiment is run, + will determine the storage sub-folder by default :return: the builder """ - if experiment_name == "DATETIME_TAG": - experiment_name = datetime_tag() - self._experiment_name = experiment_name + self._name = name return self @abstractmethod @@ -503,12 +508,12 @@ def build(self, add_seeding_info_to_name: bool = False) -> Experiment: if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) experiment: Experiment = Experiment( - self._config, - self._env_factory, - agent_factory, - self._sampling_config, - self._logger_factory, - name=self._experiment_name, + config=self._config, + env_factory=self._env_factory, + agent_factory=agent_factory, + sampling_config=self._sampling_config, + name=self._name, + logger_factory=self._logger_factory, ) if add_seeding_info_to_name: if not experiment.name: From 024b80e79ccdf1d1986c2b10e7761a2b522c6abd Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 30 Apr 2024 17:22:11 +0200 Subject: [PATCH 27/32] Improve creation of multiple seeded experiments: * Add class ExperimentCollection to improve usability * Remove parameters from ExperimentBuilder.build * Renamed ExperimentBuilder.build_default_seeded_experiments to build_seeded_collection, changing the return type to ExperimentCollection * Replace temp_config_mutation (which was not appropriate for the public API) with method copy (which performs a safe deep copy) --- examples/mujoco/mujoco_ppo_hl_multi.py | 135 ++++++---------------- test/highlevel/test_experiment_builder.py | 27 ----- tianshou/highlevel/experiment.py | 92 ++++++++------- 3 files changed, 83 insertions(+), 171 deletions(-) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 4408a132c..319375f12 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -14,8 +14,6 @@ import os import sys -from collections.abc import Sequence -from typing import Literal import torch @@ -41,86 +39,30 @@ def main( - experiment_config: ExperimentConfig, - task: str = "Ant-v4", - num_experiments: int = 5, - buffer_size: int = 4096, - hidden_sizes: Sequence[int] = (64, 64), - lr: float = 3e-4, - gamma: float = 0.99, - epoch: int = 3, - step_per_epoch: int = 30000, - step_per_collect: int = 2048, - repeat_per_collect: int = 10, - batch_size: int = 64, - training_num: int = 10, - test_num: int = 10, - rew_norm: bool = True, - vf_coef: float = 0.25, - ent_coef: float = 0.0, - gae_lambda: float = 0.95, - bound_action_method: Literal["clip", "tanh"] | None = "clip", - lr_decay: bool = True, - max_grad_norm: float = 0.5, - eps_clip: float = 0.2, - dual_clip: float | None = None, - value_clip: bool = False, - norm_adv: bool = False, - recompute_adv: bool = True, + num_experiments: int = 2, run_experiments_sequentially: bool = True, -) -> str: - """Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for - a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained, - the results are evaluated using the rliable API. 
- - :param experiment_config: - :param task: a mujoco task name - :param num_experiments: how many experiments to run with different seeds - :param buffer_size: - :param hidden_sizes: - :param lr: - :param gamma: - :param epoch: - :param step_per_epoch: - :param step_per_collect: - :param repeat_per_collect: - :param batch_size: - :param training_num: - :param test_num: - :param rew_norm: - :param vf_coef: - :param ent_coef: - :param gae_lambda: - :param bound_action_method: - :param lr_decay: - :param max_grad_norm: - :param eps_clip: - :param dual_clip: - :param value_clip: - :param norm_adv: - :param recompute_adv: - :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel. +) -> RLiableExperimentResult: + """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel. LIMITATIONS: currently, the parallel execution does not seem to work properly on linux. It might generally be undesired to run multiple experiments in parallel on the same machine, as a single experiment already uses all available CPU cores by default. :return: the directory where the results are stored """ + task = "Ant-v4" persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag())) - experiment_config.persistence_base_dir = persistence_dir - log.info(f"Will save all experiment results to {persistence_dir}.") - experiment_config.watch = False + experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False) sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, - batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, - num_test_episodes=test_num, - buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + num_epochs=1, + step_per_epoch=5000, + batch_size=64, + num_train_envs=10, + num_test_envs=10, + num_test_episodes=10, + buffer_size=4096, + step_per_collect=2048, + repeat_per_collect=10, ) env_factory = MujocoEnvFactory( @@ -133,52 +75,45 @@ def main( else VectorEnvType.SUBPROC_SHARED_MEM, ) - experiments = ( + hidden_sizes = (64, 64) + + experiment_collection = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) .with_ppo_params( PPOParams( - discount_factor=gamma, - gae_lambda=gae_lambda, - action_bound_method=bound_action_method, - reward_normalization=rew_norm, - ent_coef=ent_coef, - vf_coef=vf_coef, - max_grad_norm=max_grad_norm, - value_clip=value_clip, - advantage_normalization=norm_adv, - eps_clip=eps_clip, - dual_clip=dual_clip, - recompute_advantage=recompute_adv, - lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + discount_factor=0.99, + gae_lambda=0.95, + action_bound_method="clip", + reward_normalization=True, + ent_coef=0.0, + vf_coef=0.25, + max_grad_norm=0.5, + value_clip=False, + advantage_normalization=False, + eps_clip=0.2, + dual_clip=None, + recompute_advantage=True, + lr=3e-4, + lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config), dist_fn=DistributionFunctionFactoryIndependentGaussians(), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_logger_factory(LoggerFactoryDefault("tensorboard")) - .build_default_seeded_experiments(num_experiments) + .build_seeded_collection(num_experiments) ) if run_experiments_sequentially: launcher = 
RegisteredExpLauncher.sequential.create_launcher() else: launcher = RegisteredExpLauncher.joblib.create_launcher() - launcher.launch(experiments) - - return persistence_dir - + experiment_collection.run(launcher) -def eval_experiments(log_dir: str) -> RLiableExperimentResult: - """Evaluate the experiments in the given log directory using the rliable API.""" - rliable_result = RLiableExperimentResult.load_from_disk(log_dir) + rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir) rliable_result.eval_results(show_plots=True, save_plots=True) return rliable_result if __name__ == "__main__": - log_dir = logging.run_cli(main, level=logging.INFO) - assert isinstance(log_dir, str) # for mypy - evaluation_result = eval_experiments(log_dir) + result = logging.run_cli(main, level=logging.INFO) diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 0ba8a7bac..cb52c5ae3 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -79,30 +79,3 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment experiment = builder.build() experiment.run(run_name="test") print(experiment) - - -def test_temp_builder_modification() -> None: - env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig( - num_epochs=1, - step_per_epoch=100, - num_train_envs=2, - num_test_envs=2, - ) - builder = PPOExperimentBuilder( - experiment_config=ExperimentConfig(persistence_enabled=False), - env_factory=env_factory, - sampling_config=sampling_config, - ) - original_seed = builder.experiment_config.seed - original_train_seed = builder.sampling_config.train_seed - - with builder.temp_config_mutation(): - builder.experiment_config.seed += 12345 - builder.sampling_config.train_seed += 456 - exp = builder.build() - - assert builder.experiment_config.seed == original_seed - assert builder.sampling_config.train_seed == original_train_seed - assert exp.config.seed == original_seed + 12345 - assert exp.sampling_config.train_seed == original_train_seed + 456 diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 5ae300509..99aadc23f 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,12 +1,11 @@ import os import pickle from abc import abstractmethod -from collections.abc import Iterator, Sequence -from contextlib import contextmanager -from copy import copy +from collections.abc import Sequence +from copy import deepcopy from dataclasses import dataclass from pprint import pformat -from typing import Self, Dict, Any +from typing import TYPE_CHECKING, Any, Self, Union, cast import numpy as np import torch @@ -85,6 +84,10 @@ from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin +if TYPE_CHECKING: + from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher + + log = logging.getLogger(__name__) @@ -157,19 +160,6 @@ def __init__( self.logger_factory = logger_factory self.name = name - def get_seeding_info_as_str(self) -> str: - """Useful for creating unique experiment names based on seeds. - - A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. 
- """ - return "_".join( - [ - f"exp_seed={self.config.seed}", - f"train_seed={self.sampling_config.train_seed}", - f"test_seed={self.sampling_config.test_seed}", - ], - ) - @classmethod def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": """Restores an experiment from a previously stored pickle. @@ -184,6 +174,20 @@ def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experim experiment.config.policy_restore_directory = directory return experiment + def get_seeding_info_as_str(self) -> str: + """Returns information on the seeds used in the experiment as a string. + + This can be useful for creating unique experiment names based on seeds, e.g. + A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. + """ + return "_".join( + [ + f"exp_seed={self.config.seed}", + f"train_seed={self.sampling_config.train_seed}", + f"test_seed={self.sampling_config.test_seed}", + ], + ) + def _set_seed(self) -> None: seed = self.config.seed log.info(f"Setting random seed {seed}") @@ -206,7 +210,7 @@ def run( run_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> ExperimentResult: """Run the experiment and return the results. @@ -225,7 +229,7 @@ def run( # backward compatibility _experiment_name = kwargs.pop("experiment_name", None) if _experiment_name is not None: - run_name = _experiment_name + run_name = cast(str, _experiment_name) deprecation( "Parameter run_name should now be used instead of experiment_name. " "Support for experiment_name will be removed in the future.", @@ -351,6 +355,18 @@ def _watch_agent( ) +class ExperimentCollection: + def __init__(self, experiments: list[Experiment]): + self.experiments = experiments + + def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None: + from tianshou.evaluation.launcher import RegisteredExpLauncher + + if isinstance(launcher, RegisteredExpLauncher): + launcher = launcher.create_launcher() + launcher.launch(experiments=self.experiments) + + class ExperimentBuilder: def __init__( self, @@ -371,14 +387,8 @@ def __init__( self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() - @contextmanager - def temp_config_mutation(self) -> Iterator[Self]: - """Returns the builder instance where the configs can be modified without affecting the current instance.""" - original_sampling_config = copy(self.sampling_config) - original_experiment_config = copy(self.experiment_config) - yield self - self.sampling_config = original_sampling_config - self.experiment_config = original_experiment_config + def copy(self) -> Self: + return deepcopy(self) @property def experiment_config(self) -> ExperimentConfig: @@ -495,12 +505,9 @@ def _get_optim_factory(self) -> OptimizerFactory: else: return self._optim_factory - def build(self, add_seeding_info_to_name: bool = False) -> Experiment: + def build(self) -> Experiment: """Creates the experiment based on the options specified via this builder. - :param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains - info about the training seeds. Useful for creating multiple experiments that only differ - by seeds. 
:return: the experiment """ agent_factory = self._create_agent_factory() @@ -515,27 +522,24 @@ def build(self, add_seeding_info_to_name: bool = False) -> Experiment: name=self._name, logger_factory=self._logger_factory, ) - if add_seeding_info_to_name: - if not experiment.name: - experiment.name = experiment.get_seeding_info_as_str() - else: - experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}" return experiment - def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]: - """Creates a list of experiments with non-overlapping seeds, starting from the configured seed. + def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: + """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed. - Each experiment will have a unique name that is created from the original experiment name and the seeds used. + Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used. """ num_train_envs = self.sampling_config.num_train_envs seeded_experiments = [] for i in range(num_experiments): - with self.temp_config_mutation(): - self.experiment_config.seed += i - self.sampling_config.train_seed += i * num_train_envs - seeded_experiments.append(self.build(add_seeding_info_to_name=True)) - return seeded_experiments + builder = self.copy() + builder.experiment_config.seed += i + builder.sampling_config.train_seed += i * num_train_envs + experiment = builder.build() + experiment.name += f"_{experiment.get_seeding_info_as_str()}" + seeded_experiments.append(experiment) + return ExperimentCollection(seeded_experiments) class _BuilderMixinActorFactory(ActorFutureProviderProtocol): From 2abb4dac24c41cb58c6c73023ecc6dfeb4c694cd Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 22:23:13 +0200 Subject: [PATCH 28/32] Reinstated warning module --- tianshou/utils/__init__.py | 2 ++ tianshou/utils/warning.py | 8 ++++++++ 2 files changed, 10 insertions(+) create mode 100644 tianshou/utils/warning.py diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 2f46f9ca5..47a3c4497 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -6,11 +6,13 @@ from tianshou.utils.lr_scheduler import MultipleLRSchedulers from tianshou.utils.progress_bar import DummyTqdm, tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd +from tianshou.utils.warning import deprecation __all__ = [ "MovAvg", "RunningMeanStd", "tqdm_config", + "deprecation", "DummyTqdm", "BaseLogger", "TensorboardLogger", diff --git a/tianshou/utils/warning.py b/tianshou/utils/warning.py new file mode 100644 index 000000000..93c5ccec3 --- /dev/null +++ b/tianshou/utils/warning.py @@ -0,0 +1,8 @@ +import warnings + +warnings.simplefilter("once", DeprecationWarning) + + +def deprecation(msg: str) -> None: + """Deprecation warning wrapper.""" + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) From d8e5631567f8f007735b58e5297554a338c69e8f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 22:26:49 +0200 Subject: [PATCH 29/32] Extended changelog, slightly improved structure --- CHANGELOG.md | 90 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 55 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33622ce0f..d74e9900b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,31 +3,48 @@ ## Release 1.1.0 ### Api Extensions -- Batch received two 
new methods: `to_dict` and `to_list_of_dicts`. #1063 -- `Collector`s can now be closed, and their reset is more granular. #1063 -- Trainers can control whether collectors should be reset prior to training. #1063 -- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 -- `SamplingConfig` supports `batch_size=None`. #1077 -- Batch received new methods: `to_numpy_` and `to_torch_`. #1098, #1117 -- `to_dict` in Batch supports also non-recursive conversion. #1098 -- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098 -- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105. -- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074 -- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074 -- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074 -- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074 -- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). -Launchers for parallelization currently in alpha state. #1074 +- `data`: + - `Batch`: + - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098 + - Add methods `to_numpy_` and `to_torch_`. #1098, #1117 + - Add `__eq__` (semantic equality check). #1098 + - `keys()` deprecated in favor of `get_keys()` (needed to make iteration consistent with naming) #1105. + - `data.collector`: + - `Collector`: + - Introduced `BaseCollector` as a base class for all collectors. #1123 + - Add method `close` #1063 + - Method `reset` is now more granular (new flags controlling behavior). #1063 + - `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063 +- `trainer`: + - Trainers can now control whether collectors should be reset prior to training. #1063 +- policy: + - introduced attribute `in_training_step` that is controlled by the trainer. #1123 + - policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123 +- `highlevel`: + - `SamplingConfig`: + - Add support for `batch_size=None`. #1077 + - Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074 + - `highlevel.experiment`: + - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and + which determines the default run name and therefore the persistence subdirectory. + It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than + `experiment_name` (although the latter will still be interpreted correctly). #1074 #1131 + - Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131 + - `ExperimentBuilder`: + - Add method `build_seeded_collection` for the sound creation of multiple + experiments with varying random seeds #1131 + - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131 +- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 + - The module `evaluation.launchers` for parallelization is currently in alpha state. - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. 
#1074 -- `continuous.Critic`: - - Add flag `apply_preprocess_net_to_obs_only` to allow the - preprocessing network to be applied to the observations only (without - the actions concatenated), which is essential for the case where we want - to reuse the actor's preprocessing network #1128 -- Base class for collectors: `BaseCollector` #1122 -- Collectors can now explicitly specify whether to use the policy in training or evaluation mode. #1122 -- New util context managers `in_eval_mode` and `in_train_mode` for torch modules. #1122 -- `reset` of `Collectors` now returns `obs` and `info`. #1122 +- `utils`: + - `net.continuous.Critic`: + - Add flag `apply_preprocess_net_to_obs_only` to allow the + preprocessing network to be applied to the observations only (without + the actions concatenated), which is essential for the case where we want + to reuse the actor's preprocessing network #1128 + - `torch_utils` (new module) + - Added contextmanagers `in` ### Fixes - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics, @@ -52,25 +69,28 @@ instead of just `nn.Module`. #1032 - tests and examples are covered by `mypy`. #1077 - `NetBase` is more used, stricter typing by making it generic. #1077 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102 -- Removed all `if __name__ == "__main__":` blocks from tests. #1122 -- Improved typing issues in tests with buffer and collector. #1122 ### Breaking Changes - -- Removed `.data` attribute from `Collector` and its child classes. #1063 -- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` -expicitly or pass `reset_before_collect=True` . #1063 +- `data`: + - `Collector`: + - Removed `.data` attribute. #1063 + - Collectors no longer reset the environment on initialization. + Instead, the user might have to call `reset` expicitly or pass `reset_before_collect=True` . #1063 + - Removed `no_grad` argument from `collect` method (was unused in tianshou). #1123 + - `Batch`: + - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. + Can be considered a bugfix. #1063 + - The methods `to_numpy` and `to_torch` in are not in-place anymore + (use `to_numpy_` or `to_torch_` instead). #1098, #1117 +- Logging: + - `BaseLogger.prepare_dict_for_logging` is now abstract. #1074 + - Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 -- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 -- The methods `to_numpy` and `to_torch` in `Batch` is not in-place anymore (use `to_numpy_` or `to_torch_` instead). #1098, #1117 - `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074 - `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074 -- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074 -- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074 -- Removed deprecations of `0.5.1` (will likely not affect anyone) and the unused `warnings` module. 
#1122 ### Tests From 6a5b3c837ad16bbb7efbfbbf3c9b7e2115278db0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 23:31:20 +0200 Subject: [PATCH 30/32] Docstrings, skip hidden files in autogen_rst --- docs/autogen_rst.py | 4 +++- tianshou/data/collector.py | 1 + tianshou/trainer/base.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index d3d4f080c..b1a8b18d9 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -74,7 +74,9 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" subdir_refs = [ f"{f}/index" for f in files_in_dir - if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_") + if os.path.isdir(os.path.join(src_root, f)) + and not f.startswith("_") + and not f.startswith(".") ] package_index_rst_path = os.path.join( rst_root, diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d2ca2178c..5bce6c0a7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -286,6 +286,7 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. + .. note:: One and only one collection number specification is permitted, either diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 3ad80c8c8..b738df551 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -417,7 +417,7 @@ def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: """Perform one training iteration. A training iteration includes collecting data (for online RL), determining whether to stop training, - and peforming a policy update if the training iteration should continue. + and performing a policy update if the training iteration should continue. :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. From 78ea0139568005138c95e21aa21e6b55cb4decb6 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 6 May 2024 16:16:20 +0200 Subject: [PATCH 31/32] Tests: fixed test_psrl.py: use args.reward_threshold instead of spec For some reason now env.spec.reward_treshold is None - some change in upstream code Also added better pytest skip message --- test/modelbased/test_psrl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 41849992d..995aef698 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -44,7 +44,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +@pytest.mark.skipif( + envpool is None, + reason="EnvPool is not installed. If on linux, please install it (e.g. 
as poetry extra)", +) def test_psrl(args: argparse.Namespace = get_args()) -> None: # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) @@ -115,4 +118,4 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, test_in_train=False, ).run() - assert result.best_reward >= env.spec.reward_threshold + assert result.best_reward >= args.reward_threshold From e94a5c04cf93085c3df18e2d623f7ba465d20489 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 6 May 2024 16:50:48 +0200 Subject: [PATCH 32/32] New context manager: policy_within_training_step Adjusted notebooks, log messages and docs accordingly. Removed now obsolete in_eval_mode and the private context manager in Trainer --- CHANGELOG.md | 2 +- docs/02_notebooks/L0_overview.ipynb | 4 +-- docs/02_notebooks/L4_Policy.ipynb | 29 ++++++++++--------- docs/02_notebooks/L6_Trainer.ipynb | 44 ++++++++++++++++++++--------- test/base/test_utils.py | 6 ++-- tianshou/data/collector.py | 4 +-- tianshou/policy/base.py | 6 ++-- tianshou/trainer/base.py | 18 ++++-------- tianshou/utils/torch_utils.py | 33 +++++++++++++++------- 9 files changed, 85 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d74e9900b..807a9dad6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,7 +44,7 @@ the actions concatenated), which is essential for the case where we want to reuse the actor's preprocessing network #1128 - `torch_utils` (new module) - - Added contextmanagers `in` + - Added context managers `torch_train_mode` and `policy_within_training_step` #1123 ### Fixes - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics, diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 59d6fd207..0ce6df154 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -18,9 +18,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "# !pip install tianshou gym" diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 00f7f27b9..eed8ea344 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -74,7 +74,8 @@ ")\n", "from tianshou.utils import RunningMeanStd\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" ] }, { @@ -644,7 +645,10 @@ "source": [ "obs, info = env.reset()\n", "for i in range(3, 10):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", + " # For retrieving actions to be used for training, we set the policy to training mode,\n", + " # but the wrapped torch module should be in eval mode.\n", + " with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n", + " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", " obs_next, rew, _, truncated, info = env.step(act)\n", " # pretend this episode never end\n", " terminated = False\n", @@ -695,7 +699,11 @@ }, "source": [ "#### Updates\n", - "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train." + "Now we have got a replay buffer with 10 data steps in it. 
We can call `Policy.update()` to train.\n", + "\n", + "However, we need to manually set the torch module to training mode prior to that, \n", + "and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n", + "but users need to consider it when calling `.update` outside of the trainer." ] }, { @@ -711,16 +719,11 @@ "outputs": [], "source": [ "# 0 means sample all data from the buffer\n", - "policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "enqlFQLSJrQl" - }, - "source": [ - "Not that difficult, right?" + "\n", + "# For updating the policy, the policy should be in training mode\n", + "# and the wrapped torch module should also be in training mode (unlike when collecting data).\n", + "with policy_within_training_step(policy), torch_train_mode(policy):\n", + " policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" ] }, { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index 75aea471c..d5423bd01 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -54,7 +54,6 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": { "editable": true, "id": "do-xZ-8B7nVH", @@ -64,9 +63,12 @@ "tags": [ "hide-cell", "remove-output" - ] + ], + "ExecuteTime": { + "end_time": "2024-05-06T15:34:02.969675Z", + "start_time": "2024-05-06T15:34:00.747309Z" + } }, - "outputs": [], "source": [ "%%capture\n", "\n", @@ -78,14 +80,20 @@ "from tianshou.policy import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" - ] + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-06T15:34:07.536452Z", + "start_time": "2024-05-06T15:34:03.636670Z" + } + }, "source": [ "train_env_num = 4\n", "buffer_size = (\n", @@ -123,7 +131,9 @@ "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", "test_collector = Collector(policy, test_envs)\n", "train_collector = Collector(policy, train_envs, replayBuffer)" - ] + ], + "outputs": [], + "execution_count": 2 }, { "cell_type": "markdown", @@ -154,11 +164,19 @@ "\n", "n_episode = 10\n", "for _i in range(n_episode):\n", - " evaluation_result = test_collector.collect(n_episode=n_episode)\n", + " # for test collector, we set the wrapped torch module to evaluation mode\n", + " # by default, the policy object itself is not within the training step\n", + " with torch_train_mode(policy, enabled=False):\n", + " evaluation_result = test_collector.collect(n_episode=n_episode)\n", " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", + " # for collecting data for training, the policy object should be within the training step\n", + " # (affecting e.g. 
whether the policy is stochastic or deterministic)\n", + " with policy_within_training_step(policy):\n", + " train_collector.collect(n_step=2000)\n", + " # 0 means taking all data stored in train_collector.buffer\n", + " # for updating the policy, the wrapped torch module should be in training mode\n", + " with torch_train_mode(policy):\n", + " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", " train_collector.reset_buffer(keep_statistics=True)" ] }, diff --git a/test/base/test_utils.py b/test/base/test_utils.py index ac3b2fa4d..f8e5938cb 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -6,7 +6,7 @@ from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic -from tianshou.utils.torch_utils import in_eval_mode, in_train_mode +from tianshou.utils.torch_utils import torch_train_mode def test_noise() -> None: @@ -137,7 +137,7 @@ def test_lr_schedulers() -> None: def test_in_eval_mode() -> None: module = nn.Linear(3, 4) module.train() - with in_eval_mode(module): + with torch_train_mode(module, False): assert not module.training assert module.training @@ -145,6 +145,6 @@ def test_in_eval_mode() -> None: def test_in_train_mode() -> None: module = nn.Linear(3, 4) module.eval() - with in_train_mode(module): + with torch_train_mode(module): assert module.training assert not module.training diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 5bce6c0a7..6773a6383 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -27,7 +27,7 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin -from tianshou.utils.torch_utils import in_eval_mode +from tianshou.utils.torch_utils import torch_train_mode log = logging.getLogger(__name__) @@ -300,7 +300,7 @@ def collect( if reset_before_collect: self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - with in_eval_mode(self.policy): # safety precaution only + with torch_train_mode(self.policy, False): return self._collect( n_step=n_step, n_episode=n_episode, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index bee9f9bdf..b7ae5f23d 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -25,7 +25,7 @@ ) from tianshou.utils import MultipleLRSchedulers from tianshou.utils.print import DataclassPPrintMixin -from tianshou.utils.torch_utils import in_train_mode +from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode logger = logging.getLogger(__name__) @@ -532,7 +532,7 @@ def update( raise RuntimeError( f"update() was called outside of a training step as signalled by {self.is_within_training_step=} " f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " - f"flag yourself.", + f"flag yourself. 
You can to this e.g., by using the contextmanager {policy_within_training_step.__name__}.", ) if buffer is None: @@ -541,7 +541,7 @@ def update( batch, indices = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indices) - with in_train_mode(self): + with torch_train_mode(self): training_stat = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index b738df551..242f2b028 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -2,8 +2,7 @@ import time from abc import ABC, abstractmethod from collections import defaultdict, deque -from collections.abc import Callable, Iterator -from contextlib import contextmanager +from collections.abc import Callable from dataclasses import asdict import numpy as np @@ -29,6 +28,7 @@ tqdm_config, ) from tianshou.utils.logging import set_numerical_fields_to_precision +from tianshou.utils.torch_utils import policy_within_training_step log = logging.getLogger(__name__) @@ -404,15 +404,6 @@ def test_step(self) -> tuple[CollectStats, bool]: return test_stat, stop_fn_flag - @contextmanager - def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]: - old_value = self.policy.is_within_training_step - try: - self.policy.is_within_training_step = is_within_training_step - yield - finally: - self.policy.is_within_training_step = old_value - def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: """Perform one training iteration. @@ -422,7 +413,7 @@ def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. """ - with self._is_within_training_step_enabled(True): + with policy_within_training_step(self.policy): should_stop_training = False collect_stats: CollectStatsBase | CollectStats @@ -474,6 +465,7 @@ def _collect_training_data(self) -> CollectStats: return collect_stats + # TODO (maybe): separate out side effect, simplify name? 
def _update_best_reward_and_return_should_stop_training( self, collect_stats: CollectStats, @@ -492,7 +484,7 @@ def _update_best_reward_and_return_should_stop_training( should_stop_training = False # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics - with self._is_within_training_step_enabled(False): + with policy_within_training_step(self.policy, enabled=False): if ( collect_stats.n_collected_episodes > 0 and self.test_in_train diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 2fb70dad2..430d174e7 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -1,26 +1,39 @@ from collections.abc import Iterator from contextlib import contextmanager +from typing import TYPE_CHECKING from torch import nn +if TYPE_CHECKING: + from tianshou.policy import BasePolicy + @contextmanager -def in_eval_mode(module: nn.Module) -> Iterator[None]: - """Temporarily switch to evaluation mode.""" - train = module.training +def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: + """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`.""" + original_mode = module.training try: - module.eval() + module.train(enabled) yield finally: - module.train(train) + module.train(original_mode) @contextmanager -def in_train_mode(module: nn.Module) -> Iterator[None]: - """Temporarily switch to training mode.""" - train = module.training +def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]: + """Temporarily switch to `policy.is_within_training_step=enabled`. + + Enabling this ensures that the policy is able to adapt its behavior, + allowing it to differentiate between training and inference/evaluation, + e.g., to sample actions instead of using the most probable action (where applicable) + Note that for rollout, which also happens within a training step, one would usually want + the wrapped torch module to be in evaluation mode, which can be achieved using + `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both + within training step and in torch train mode. + """ + original_mode = policy.is_within_training_step try: - module.train() + policy.is_within_training_step = enabled yield finally: - module.train(train) + policy.is_within_training_step = original_mode
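
For reference, a minimal usage sketch of the two context managers introduced above, assuming `policy`, `train_collector` and `test_collector` have already been constructed as in the L6_Trainer notebook changes earlier in this series:

    from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

    # Evaluation rollout: the policy is not within a training step; the Collector itself
    # switches the wrapped torch modules to eval mode via torch_train_mode(policy, False).
    test_collector.collect(n_episode=10)

    # Collecting training data: declare the training step so that the policy may adapt its
    # behavior (e.g. sample stochastic actions); the Collector still keeps the wrapped
    # torch modules in eval mode internally.
    with policy_within_training_step(policy):
        train_collector.collect(n_step=2000)

        # Gradient update: additionally switch the wrapped torch modules to train mode
        # (affects e.g. dropout and batch normalization layers).
        with torch_train_mode(policy):
            policy.update(
                sample_size=None,
                buffer=train_collector.buffer,
                batch_size=512,
                repeat=1,
            )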
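
Similarly, a minimal sketch of the seeded multi-experiment workflow from this series, assuming `env_factory`, `experiment_config` and `sampling_config` are configured as in examples/mujoco/mujoco_ppo_hl_multi.py (the algorithm's default actor/critic factories and parameters are used here for brevity):

    from tianshou.evaluation.launcher import RegisteredExpLauncher
    from tianshou.highlevel.experiment import PPOExperimentBuilder

    experiment_collection = (
        PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_name("ppo_multi_seed")  # determines the persistence sub-folder of the runs
        .build_seeded_collection(num_experiments=5)
    )
    # Runs all seeded experiments one after another; RegisteredExpLauncher.joblib would
    # instead launch them in parallel.
    experiment_collection.run(RegisteredExpLauncher.sequential)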