From 863011f1a7cd7111376410c9387a88bb78990261 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 7 Jan 2022 05:08:10 +0800 Subject: [PATCH 01/14] implement Intrinsic Curiosity Module --- README.md | 1 + docs/api/tianshou.policy.rst | 5 + docs/index.rst | 1 + examples/atari/atari_dqn.py | 44 ++++++- test/modelbased/test_dqn_icm.py | 204 ++++++++++++++++++++++++++++++ test/modelbased/test_ppo_icm.py | 192 ++++++++++++++++++++++++++++ tianshou/policy/__init__.py | 2 + tianshou/policy/modelbased/icm.py | 141 +++++++++++++++++++++ tianshou/utils/net/discrete.py | 59 ++++++++- 9 files changed, 646 insertions(+), 3 deletions(-) create mode 100644 test/modelbased/test_dqn_icm.py create mode 100644 test/modelbased/test_ppo_icm.py create mode 100644 tianshou/policy/modelbased/icm.py diff --git a/README.md b/README.md index 13cfc191f..345947826 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) +- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf) Here is Tianshou's other features: diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index 7292afdcc..b44598a2f 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -137,6 +137,11 @@ Model-based :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.ICMPolicy + :members: + :undoc-members: + :show-inheritance: + Multi-agent ----------- diff --git a/docs/index.rst b/docs/index.rst index a7fa0da26..7341ec0f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression `_ * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ +* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2cd26f824..eb32fb9ed 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -11,8 +11,11 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DQNPolicy +from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.net.common import MLP +from tianshou.utils.net.discrete import IntrinsicCuriosityModule def get_args(): @@ -55,6 +58,24 @@ def get_args(): help='watch the play of pre-trained policy only' ) parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument( + '--icm-lr-scale', + type=float, + default=0., + help='use intrinsic curiosity module with this lr scale' + ) + parser.add_argument( + '--icm-reward-scale', + type=float, + default=0.01, + help='scaling factor for intrinsic curiosity reward' + ) + parser.add_argument( + '--icm-forward-loss-weight', + type=float, + default=0.2, + help='weight for the forward model loss in ICM' + ) return parser.parse_args() @@ -101,6 +122,24 @@ def test_dqn(args=get_args()): args.n_step, target_update_freq=args.target_update_freq ) + if args.icm_lr_scale > 0: + feature_net = DQN( + *args.state_shape, args.action_shape, args.device, features_only=True + ) + action_dim = np.prod(args.action_shape) + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.net, + feature_dim, + action_dim, + hidden_sizes=[512], + device=args.device + ) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy = ICMPolicy( + policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale, + args.icm_forward_loss_weight + ).to(args.device) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) @@ -118,7 +157,8 @@ def test_dqn(args=get_args()): train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # log - log_path = os.path.join(args.logdir, args.task, 'dqn') + log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn' + log_path = os.path.join(args.logdir, args.task, log_name) if args.logger == "tensorboard": writer = SummaryWriter(log_path) writer.add_text("args", str(args)) @@ -127,7 +167,7 @@ def test_dqn(args=get_args()): logger = WandbLogger( save_interval=1, project=args.task, - name='dqn', + name=log_name, run_id=args.resume_id, config=args, ) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py new file mode 100644 index 000000000..8909ce5f7 --- /dev/null +++ b/test/modelbased/test_dqn_icm.py @@ -0,0 +1,204 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, 
VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import DQNPolicy, ICMPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net, MLP +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument( + '--lr-scale', + type=float, + default=1., + help='use intrinsic curiosity module with this lr scale' + ) + parser.add_argument( + '--reward-scale', + type=float, + default=0.01, + help='scaling factor for intrinsic curiosity reward' + ) + parser.add_argument( + '--forward-loss-weight', + type=float, + default=0.2, + help='weight for the forward model loss in ICM' + ) + args = parser.parse_known_args()[0] + return args + + +def test_dqn_icm(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # Q_param = V_param = {"hidden_sizes": [128]} + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + # dueling=(Q_param, V_param), + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = DQNPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq, + ) + feature_dim = args.hidden_sizes[-1] + feature_net = MLP( + np.prod(args.state_shape), + output_dim=feature_dim, + hidden_sizes=args.hidden_sizes[:-1], + device=args.device + ) + 
action_dim = np.prod(args.action_shape) + icm_net = IntrinsicCuriosityModule( + feature_net, + feature_dim, + action_dim, + hidden_sizes=args.hidden_sizes[-1:], + device=args.device + ).to(args.device) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy = ICMPolicy( + policy, icm_net, icm_optim, args.lr_scale, args.reward_scale, + args.forward_loss_weight + ) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta, + ) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'dqn_icm') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annnealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_pdqn_icm(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + args.seed = 1 + test_dqn_icm(args) + + +if __name__ == '__main__': + test_dqn_icm(get_args()) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py new file mode 100644 index 000000000..cc05cdc31 --- /dev/null +++ b/test/modelbased/test_ppo_icm.py @@ -0,0 +1,192 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import PPOPolicy, ICMPolicy +from tianshou.trainer import onpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net, MLP +from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--step-per-collect', type=int, default=2000) + parser.add_argument('--repeat-per-collect', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=20) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + # ppo special + parser.add_argument('--vf-coef', type=float, default=0.5) + parser.add_argument('--ent-coef', type=float, default=0.0) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--gae-lambda', type=float, default=0.95) + parser.add_argument('--rew-norm', type=int, default=0) + parser.add_argument('--norm-adv', type=int, default=0) + parser.add_argument('--recompute-adv', type=int, default=0) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=0) + parser.add_argument( + '--lr-scale', + type=float, + default=1., + help='use intrinsic curiosity module with this lr scale' + ) + parser.add_argument( + '--reward-scale', + type=float, + default=0.01, + help='scaling factor for intrinsic curiosity reward' + ) + parser.add_argument( + '--forward-loss-weight', + type=float, + default=0.2, + help='weight for the forward model loss in ICM' + ) + args = parser.parse_known_args()[0] + return args + + +def test_ppo(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + if torch.cuda.is_available(): + actor = DataParallelNet( + Actor(net, args.action_shape, device=None).to(args.device) + ) + critic = DataParallelNet(Critic(net, device=None).to(args.device)) + else: + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) + actor_critic = ActorCritic(actor, critic) + # orthogonal initialization + for m in actor_critic.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + dist = torch.distributions.Categorical + policy = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + dual_clip=args.dual_clip, + value_clip=args.value_clip, + action_space=env.action_space, + deterministic_eval=True, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv + ) + feature_dim = args.hidden_sizes[-1] + feature_net = MLP( + np.prod(args.state_shape), + output_dim=feature_dim, + hidden_sizes=args.hidden_sizes[:-1], + device=args.device + ) + action_dim = np.prod(args.action_shape) + icm_net = IntrinsicCuriosityModule( + feature_net, + feature_dim, + action_dim, + hidden_sizes=args.hidden_sizes[-1:], + device=args.device + ).to(args.device) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy = ICMPolicy( + policy, icm_net, 
icm_optim, args.lr_scale, args.reward_scale, + args.forward_loss_weight + ) + # collector + train_collector = Collector( + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) + test_collector = Collector(policy, test_envs) + # log + log_path = os.path.join(args.logdir, args.task, 'ppo_icm') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + # trainer + result = onpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_ppo() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 174762e25..f8c84416d 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -24,6 +24,7 @@ from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy from tianshou.policy.modelbased.psrl import PSRLPolicy +from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager __all__ = [ @@ -50,5 +51,6 @@ "DiscreteCQLPolicy", "DiscreteCRRPolicy", "PSRLPolicy", + "ICMPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py new file mode 100644 index 000000000..d1aed89a6 --- /dev/null +++ b/tianshou/policy/modelbased/icm.py @@ -0,0 +1,141 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch +from tianshou.policy import BasePolicy +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + + +class ICMPolicy(BasePolicy): + """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. + + :param BasePolicy policy: a base policy to add ICM. + :param IntrinsicCuriosityModule model: the ICM model. + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float lr_scale: the scaling factor for ICM learning. + :param float forward_loss_weight: the weight for forward model loss. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. 
+ """ + + def __init__( + self, + policy: BasePolicy, + model: IntrinsicCuriosityModule, + optim: torch.optim.Optimizer, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.policy = policy + self.model = model + self.optim = optim + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + + def train(self, mode: bool = True) -> BasePolicy: + """Set the module in training mode.""" + self.policy.train(mode) + self.training = mode + self.model.train(mode) + return self + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data. + + If you need to mask the action, please add a "mask" into batch.obs, for + example, if we have an environment that has "0/1/2" three actions: + :: + + batch == Batch( + obs=Batch( + obs="original obs, with batch_size=1 for demonstration", + mask=np.array([[False, True, False]]), + # action 1 is available + # action 0 and 2 are unavailable + ), + ... + ) + + :param float eps: in [0, 1], for epsilon-greedy exploration method. + + :return: A :class:`~tianshou.data.Batch` which has 3 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``state`` the hidden state. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + return self.policy.forward(batch, state, **kwargs) + + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: + return self.policy.exploration_noise(act, batch) + + def set_eps(self, eps: float) -> None: + """Set the eps for epsilon-greedy exploration.""" + if hasattr(self.policy, "set_eps"): + self.policy.set_eps(eps) + else: + raise NotImplementedError() + + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + ) -> Batch: + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more information. + """ + mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) + batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) + batch.rew += to_numpy(mse_loss * self.reward_scale) + return self.policy.process_fn(batch, buffer, indices) + + def post_process_fn( + self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + ) -> None: + """Post-process the data from the provided replay buffer. + + Typical usage is to update the sampling weight in prioritized + experience replay. Used in :meth:`update`. 
+ """ + self.policy.post_process_fn(batch, buffer, indices) + batch.rew = batch.policy.orig_rew # restore original reward + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + res = self.policy.learn(batch, **kwargs) + self.optim.zero_grad() + act_hat = batch.policy.act_hat + act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + inverse_loss = F.cross_entropy(act_hat, act).mean() + forward_loss = batch.policy.mse_loss.mean() + loss = ( + (1 - self.forward_loss_weight) * inverse_loss + + self.forward_loss_weight * forward_loss + ) * self.lr_scale + loss.backward() + self.optim.step() + return { + "loss/policy": res["loss"], + "loss/icm": loss.item(), + "loss/icm/forward": forward_loss.item(), + "loss/icm/inverse": inverse_loss.item() + } diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index bcc6531e3..d6e3d6182 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch import nn -from tianshou.data import Batch +from tianshou.data import Batch, to_torch from tianshou.utils.net.common import MLP @@ -380,3 +380,60 @@ def sample_noise(model: nn.Module) -> bool: m.sample() done = True return done + + +class IntrinsicCuriosityModule(nn.Module): + """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. + + :param torch.nn.Module feature_net: a self-defined feature_net which output a + flattened hidden state. + :param int feature_dim: input dimension of the feature net. + :param int action_dim: dimension of the action space. + :param hidden_sizes: hidden layer sizes for forward and inverse models. + :param device: device for the module. + """ + + def __init__( + self, + feature_net: nn.Module, + feature_dim: int, + action_dim: int, + hidden_sizes: Sequence[int] = (), + device: Union[str, torch.device] = "cpu" + ) -> None: + super().__init__() + self.feature_net = feature_net + self.forward_model = MLP( + feature_dim + action_dim, + output_dim=feature_dim, + hidden_sizes=hidden_sizes, + device=device + ) + self.inverse_model = MLP( + feature_dim * 2, + output_dim=action_dim, + hidden_sizes=hidden_sizes, + device=device + ) + self.feature_dim = feature_dim + self.action_dim = action_dim + self.device = device + + def forward( # type: ignore + self, + s1: Union[np.ndarray, torch.Tensor], + act: Union[np.ndarray, torch.Tensor], + s2: Union[np.ndarray, torch.Tensor], + **kwargs: Any + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Mapping: s1, act, s2 -> mse_loss, act_hat.""" + s1 = to_torch(s1, dtype=torch.float32, device=self.device) + s2 = to_torch(s2, dtype=torch.float32, device=self.device) + phi1, phi2 = self.feature_net(s1), self.feature_net(s2) + act = to_torch(act, dtype=torch.long, device=self.device) + phi2_hat = self.forward_model( + torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1) + ) + mse_loss = 0.5 * ((phi2_hat - phi2)**2).sum(1) + act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1)) + return mse_loss, act_hat From b10856844119bdedabda4eea1d3a9164f763404c Mon Sep 17 00:00:00 2001 From: Yi Su Date: Wed, 12 Jan 2022 04:00:41 +0800 Subject: [PATCH 02/14] make linter happy --- examples/atari/atari_dqn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index eb32fb9ed..685ebb40b 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -14,7 +14,6 @@ from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.trainer import 
offpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
-from tianshou.utils.net.common import MLP
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule

From 8261b85885790127d89e6f0cffbba1095425f318 Mon Sep 17 00:00:00 2001
From: Yi Su
Date: Tue, 11 Jan 2022 12:52:47 -0800
Subject: [PATCH 03/14] update by isort

---
 test/modelbased/test_dqn_icm.py | 2 +-
 test/modelbased/test_ppo_icm.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index 8909ce5f7..1fcd60f09 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -12,7 +12,7 @@
 from tianshou.policy import DQNPolicy, ICMPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net, MLP
+from tianshou.utils.net.common import MLP, Net
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule


diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py
index cc05cdc31..1c202f907 100644
--- a/test/modelbased/test_ppo_icm.py
+++ b/test/modelbased/test_ppo_icm.py
@@ -9,10 +9,10 @@

 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import SubprocVectorEnv
-from tianshou.policy import PPOPolicy, ICMPolicy
+from tianshou.policy import ICMPolicy, PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net, MLP
+from tianshou.utils.net.common import MLP, ActorCritic, DataParallelNet, Net
 from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


From 27889743eab3dd6a822e3e11a38a3b34d9d6594a Mon Sep 17 00:00:00 2001
From: Yi Su
Date: Fri, 14 Jan 2022 09:25:08 +0800
Subject: [PATCH 04/14] add vizdoom a2c+icm example

---
 examples/vizdoom/vizdoom_a2c_icm.py | 230 ++++++++++++++++++++++++++++
 1 file changed, 230 insertions(+)
 create mode 100644 examples/vizdoom/vizdoom_a2c_icm.py

diff --git a/examples/vizdoom/vizdoom_a2c_icm.py b/examples/vizdoom/vizdoom_a2c_icm.py
new file mode 100644
index 000000000..a2737fc05
--- /dev/null
+++ b/examples/vizdoom/vizdoom_a2c_icm.py
@@ -0,0 +1,230 @@
+import argparse
+import os
+import pprint
+
+import numpy as np
+import torch
+from env import Env
+from network import DQN
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import ShmemVectorEnv
+from tianshou.policy import A2CPolicy, ICMPolicy
+from tianshou.trainer import onpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
+from tianshou.utils.net.common import ActorCritic
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='D2_navigation')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--buffer-size', type=int, default=2000000)
+    parser.add_argument('--lr', type=float, default=0.0001)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=300)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--episode-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
+    parser.add_argument('--repeat-per-collect', type=int, default=1)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument('--training-num', type=int, default=10)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+    )
+    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument('--skip-num', type=int, default=4)
+    parser.add_argument('--resume-path', type=str, default=None)
+    parser.add_argument(
+        '--watch',
+        default=False,
+        action='store_true',
+        help='watch the play of pre-trained policy only'
+    )
+    parser.add_argument(
+        '--save-lmp',
+        default=False,
+        action='store_true',
+        help='save lmp file for replay whole episode'
+    )
+    parser.add_argument('--save-buffer-name', type=str, default=None)
+    parser.add_argument(
+        '--icm-lr-scale',
+        type=float,
+        default=0.,
+        help='use intrinsic curiosity module with this lr scale'
+    )
+    parser.add_argument(
+        '--icm-reward-scale',
+        type=float,
+        default=0.01,
+        help='scaling factor for intrinsic curiosity reward'
+    )
+    parser.add_argument(
+        '--icm-forward-loss-weight',
+        type=float,
+        default=0.2,
+        help='weight for the forward model loss in ICM'
+    )
+    return parser.parse_args()
+
+
+def test_a2c(args=get_args()):
+    args.cfg_path = f"maps/{args.task}.cfg"
+    args.wad_path = f"maps/{args.task}.wad"
+    args.res = (args.skip_num, 84, 84)
+    env = Env(args.cfg_path, args.frames_stack, args.res)
+    args.state_shape = args.res
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # should be N_FRAMES x H x W
+    print("Observations shape:", args.state_shape)
+    print("Actions shape:", args.action_shape)
+    # make environments
+    train_envs = ShmemVectorEnv(
+        [
+            lambda: Env(args.cfg_path, args.frames_stack, args.res)
+            for _ in range(args.training_num)
+        ]
+    )
+    test_envs = ShmemVectorEnv(
+        [
+            lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
+            for _ in range(min(os.cpu_count() - 1, args.test_num))
+        ]
+    )
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # define model
+    net = DQN(
+        *args.state_shape, args.action_shape, device=args.device, features_only=True
+    )
+    actor = Actor(
+        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
+    )
+    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
+    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+    # define policy
+    dist = torch.distributions.Categorical
+    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
+    if args.icm_lr_scale > 0:
+        feature_net = DQN(
+            *args.state_shape,
+            args.action_shape,
+            device=args.device,
+            features_only=True
+        )
+        action_dim = np.prod(args.action_shape)
+        feature_dim = feature_net.output_dim
+        icm_net = IntrinsicCuriosityModule(
+            feature_net.net,
+            feature_dim,
+            action_dim,
+            hidden_sizes=args.hidden_sizes,
+            device=args.device
+        )
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
+        policy = ICMPolicy(
+            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
+            args.icm_forward_loss_weight
+        ).to(args.device)
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
+        print("Loaded agent from: ", args.resume_path)
+    # replay buffer: `save_last_obs` and
`stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # log + log_path = os.path.join(args.logdir, args.task, 'a2c') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) + rew = result["rews"].mean() + lens = result["lens"].mean() * args.skip_num + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f'Mean length (over {result["n/ep"]} episodes): {lens}') + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = onpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_a2c(get_args()) From 627ac7b35e7fb01c66eb1db1c5db82ef8512200e Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 14 Jan 2022 09:25:22 +0800 Subject: [PATCH 05/14] address review comments --- tianshou/policy/modelbased/icm.py | 34 +++++++++---------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index d1aed89a6..d41a53570 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -12,7 +12,7 @@ class ICMPolicy(BasePolicy): """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. - :param BasePolicy policy: a base policy to add ICM. + :param BasePolicy policy: a base policy to add ICM to. :param IntrinsicCuriosityModule model: the ICM model. :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float lr_scale: the scaling factor for ICM learning. @@ -57,23 +57,7 @@ def forward( ) -> Batch: """Compute action over the given batch data. 
- If you need to mask the action, please add a "mask" into batch.obs, for - example, if we have an environment that has "0/1/2" three actions: - :: - - batch == Batch( - obs=Batch( - obs="original obs, with batch_size=1 for demonstration", - mask=np.array([[False, True, False]]), - # action 1 is available - # action 0 and 2 are unavailable - ), - ... - ) - - :param float eps: in [0, 1], for epsilon-greedy exploration method. - - :return: A :class:`~tianshou.data.Batch` which has 3 keys: + :return: A :class:`~tianshou.data.Batch` which has at least 3 keys: * ``act`` the action. * ``logits`` the network's raw output. @@ -133,9 +117,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: ) * self.lr_scale loss.backward() self.optim.step() - return { - "loss/policy": res["loss"], - "loss/icm": loss.item(), - "loss/icm/forward": forward_loss.item(), - "loss/icm/inverse": inverse_loss.item() - } + res.update( + { + "loss/icm": loss.item(), + "loss/icm/forward": forward_loss.item(), + "loss/icm/inverse": inverse_loss.item() + } + ) + return res From 34f7a2883cd471bc17d43098d81b244749dd0fc9 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 14 Jan 2022 10:03:24 +0800 Subject: [PATCH 06/14] minor change --- tianshou/utils/net/discrete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index d6e3d6182..7d490bbd3 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -434,6 +434,6 @@ def forward( # type: ignore phi2_hat = self.forward_model( torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1) ) - mse_loss = 0.5 * ((phi2_hat - phi2)**2).sum(1) + mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1) act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1)) return mse_loss, act_hat From 0aea23abff5ed40437d5463380120f9ac43daf82 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Thu, 13 Jan 2022 18:05:08 -0800 Subject: [PATCH 07/14] update by isort --- examples/vizdoom/vizdoom_a2c_icm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vizdoom/vizdoom_a2c_icm.py b/examples/vizdoom/vizdoom_a2c_icm.py index a2737fc05..fc065aec9 100644 --- a/examples/vizdoom/vizdoom_a2c_icm.py +++ b/examples/vizdoom/vizdoom_a2c_icm.py @@ -13,8 +13,8 @@ from tianshou.policy import A2CPolicy, ICMPolicy from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule def get_args(): From 4cceb16494eb9a738f904bdd06ef047fc285d526 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Thu, 13 Jan 2022 19:18:56 -0800 Subject: [PATCH 08/14] make mypy happy --- tianshou/policy/modelbased/icm.py | 8 +++++--- tianshou/utils/net/discrete.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index d41a53570..96f8b11af 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -42,7 +42,7 @@ def __init__( self.reward_scale = reward_scale self.forward_loss_weight = forward_loss_weight - def train(self, mode: bool = True) -> BasePolicy: + def train(self, mode: bool = True) -> "ICMPolicy": """Set the module in training mode.""" self.policy.train(mode) self.training = mode @@ -77,7 +77,7 @@ def exploration_noise(self, act: 
Union[np.ndarray, Batch], def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" if hasattr(self.policy, "set_eps"): - self.policy.set_eps(eps) + self.policy.set_eps(eps) # type: ignore else: raise NotImplementedError() @@ -108,7 +108,9 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: res = self.policy.learn(batch, **kwargs) self.optim.zero_grad() act_hat = batch.policy.act_hat - act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + act: torch.Tensor = to_torch( + batch.act, dtype=torch.long, device=act_hat.device + ) inverse_loss = F.cross_entropy(act_hat, act).mean() forward_loss = batch.policy.mse_loss.mean() loss = ( diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 2a60a50f3..db2749ae0 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -431,7 +431,7 @@ def __init__( self.action_dim = action_dim self.device = device - def forward( # type: ignore + def forward( self, s1: Union[np.ndarray, torch.Tensor], act: Union[np.ndarray, torch.Tensor], From 6b721f84fa3354244c013bd927088782b0bb831e Mon Sep 17 00:00:00 2001 From: Yi Su Date: Thu, 13 Jan 2022 19:59:45 -0800 Subject: [PATCH 09/14] reformat code --- tianshou/utils/net/discrete.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index db2749ae0..f79a9d1b0 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -432,11 +432,9 @@ def __init__( self.device = device def forward( - self, - s1: Union[np.ndarray, torch.Tensor], - act: Union[np.ndarray, torch.Tensor], - s2: Union[np.ndarray, torch.Tensor], - **kwargs: Any + self, s1: Union[np.ndarray, torch.Tensor], + act: Union[np.ndarray, torch.Tensor], s2: Union[np.ndarray, + torch.Tensor], **kwargs: Any ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Mapping: s1, act, s2 -> mse_loss, act_hat.""" s1 = to_torch(s1, dtype=torch.float32, device=self.device) From 286a8592cf711dbc9a8ef089f5066c7ec4e5d575 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Thu, 13 Jan 2022 21:19:24 -0800 Subject: [PATCH 10/14] make mypy happy --- tianshou/policy/modelbased/icm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 96f8b11af..ba3095857 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -108,9 +108,9 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: res = self.policy.learn(batch, **kwargs) self.optim.zero_grad() act_hat = batch.policy.act_hat - act: torch.Tensor = to_torch( + act = to_torch( batch.act, dtype=torch.long, device=act_hat.device - ) + ) # type: ignore inverse_loss = F.cross_entropy(act_hat, act).mean() forward_loss = batch.policy.mse_loss.mean() loss = ( From 3deb0766bbd4c772899b5834677ca22579a85cb0 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Thu, 13 Jan 2022 21:29:26 -0800 Subject: [PATCH 11/14] make mypy happy --- tianshou/policy/modelbased/icm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index ba3095857..96ebbe9f1 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -108,10 +108,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: res = self.policy.learn(batch, **kwargs) self.optim.zero_grad() act_hat = batch.policy.act_hat - act = to_torch( - 
batch.act, dtype=torch.long, device=act_hat.device - ) # type: ignore - inverse_loss = F.cross_entropy(act_hat, act).mean() + act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + inverse_loss = F.cross_entropy(act_hat, act).mean() # type: ignore forward_loss = batch.policy.mse_loss.mean() loss = ( (1 - self.forward_loss_weight) * inverse_loss + From 9db54b5efcf9ee99a9729ae186845cb179624b47 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Fri, 14 Jan 2022 05:17:03 -0800 Subject: [PATCH 12/14] Apply suggestions from code review --- examples/vizdoom/vizdoom_a2c_icm.py | 8 +------- tianshou/policy/modelbased/icm.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/examples/vizdoom/vizdoom_a2c_icm.py b/examples/vizdoom/vizdoom_a2c_icm.py index fc065aec9..1d84148aa 100644 --- a/examples/vizdoom/vizdoom_a2c_icm.py +++ b/examples/vizdoom/vizdoom_a2c_icm.py @@ -162,13 +162,7 @@ def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) def stop_fn(mean_rewards): - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: - return mean_rewards >= 20 - else: - return False - + return False # watch agent's performance def watch(): print("Setup test envs ...") diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 96ebbe9f1..5a723c10b 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -55,13 +55,7 @@ def forward( state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs: Any, ) -> Batch: - """Compute action over the given batch data. - - :return: A :class:`~tianshou.data.Batch` which has at least 3 keys: - - * ``act`` the action. - * ``logits`` the network's raw output. - * ``state`` the hidden state. + """Compute action over the given batch data by inner policy. .. seealso:: From 821a0cd7e4234022ef3473016d8266d2e0b659ae Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Fri, 14 Jan 2022 08:18:42 -0500 Subject: [PATCH 13/14] fix ci --- docs/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 8f6526e8d..7a4a3f7d8 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -135,3 +135,4 @@ Huayu Strens Ornstein Uhlenbeck +mse From e035013a1b2a57a1aad45faff9a613ce9505181f Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Fri, 14 Jan 2022 08:22:38 -0500 Subject: [PATCH 14/14] fix ci --- examples/vizdoom/vizdoom_a2c_icm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/vizdoom/vizdoom_a2c_icm.py b/examples/vizdoom/vizdoom_a2c_icm.py index 1d84148aa..99668ce16 100644 --- a/examples/vizdoom/vizdoom_a2c_icm.py +++ b/examples/vizdoom/vizdoom_a2c_icm.py @@ -163,8 +163,9 @@ def save_fn(policy): def stop_fn(mean_rewards): return False - # watch agent's performance + def watch(): + # watch agent's performance print("Setup test envs ...") policy.eval() test_envs.seed(args.seed)
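Taken together, the pieces above are meant to be used as follows: an IntrinsicCuriosityModule is built around a feature network phi(s), and ICMPolicy wraps an existing policy so that process_fn adds reward_scale times the forward-model error to the environment reward while learn() trains the forward and inverse models alongside the base policy. The sketch below is illustrative only and is not part of the patches; the CartPole task, layer sizes, and hyper-parameters are arbitrary choices, and DQNPolicy stands in for any Tianshou policy the wrapper could take.

import gym
import numpy as np
import torch

from tianshou.policy import DQNPolicy, ICMPolicy
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# base policy: any Tianshou policy can be wrapped, DQN is only an example
net = Net(state_shape, action_shape, hidden_sizes=[128, 128])
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, 0.99, 3, target_update_freq=320)

# the ICM heads share a feature net; feature_dim is the size of its output
feature_dim = 128
feature_net = MLP(np.prod(state_shape), output_dim=feature_dim, hidden_sizes=[128])
action_dim = int(np.prod(action_shape))
icm_net = IntrinsicCuriosityModule(
    feature_net, feature_dim, action_dim, hidden_sizes=[128]
)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=1e-3)

# intrinsic reward = reward_scale * forward-model MSE, added in process_fn;
# learn() minimizes (1 - w) * inverse loss + w * forward loss, scaled by lr_scale
policy = ICMPolicy(
    policy, icm_net, icm_optim,
    lr_scale=1., reward_scale=0.01, forward_loss_weight=0.2
)

The wrapped policy then drops into the usual Collector and trainer loop exactly as in the test scripts added above.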