From a59d96d04164d1442de56de3819a8eaff5d62bdd Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 14 Jan 2022 10:43:48 -0800 Subject: [PATCH] Add Intrinsic Curiosity Module (#503) --- README.md | 1 + docs/api/tianshou.policy.rst | 5 + docs/index.rst | 1 + docs/spelling_wordlist.txt | 1 + examples/atari/atari_dqn.py | 43 +++++- examples/vizdoom/vizdoom_a2c_icm.py | 225 ++++++++++++++++++++++++++++ test/modelbased/test_dqn_icm.py | 204 +++++++++++++++++++++++++ test/modelbased/test_ppo_icm.py | 192 ++++++++++++++++++++++++ tianshou/policy/__init__.py | 2 + tianshou/policy/modelbased/icm.py | 121 +++++++++++++++ tianshou/utils/net/discrete.py | 57 ++++++- 11 files changed, 849 insertions(+), 3 deletions(-) create mode 100644 examples/vizdoom/vizdoom_a2c_icm.py create mode 100644 test/modelbased/test_dqn_icm.py create mode 100644 test/modelbased/test_ppo_icm.py create mode 100644 tianshou/policy/modelbased/icm.py diff --git a/README.md b/README.md index 13cfc191f..345947826 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) +- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf) Here is Tianshou's other features: diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index 7292afdcc..b44598a2f 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -137,6 +137,11 @@ Model-based :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.ICMPolicy + :members: + :undoc-members: + :show-inheritance: + Multi-agent ----------- diff --git a/docs/index.rst b/docs/index.rst index a7fa0da26..7341ec0f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression `_ * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ +* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 8f6526e8d..7a4a3f7d8 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -135,3 +135,4 @@ Huayu Strens Ornstein Uhlenbeck +mse diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2cd26f824..685ebb40b 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -11,8 +11,10 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DQNPolicy +from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.net.discrete import IntrinsicCuriosityModule def get_args(): @@ -55,6 +57,24 @@ def get_args(): help='watch the play of pre-trained policy only' ) parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument( + '--icm-lr-scale', + type=float, + default=0., + help='use intrinsic curiosity module with this lr scale' + ) + parser.add_argument( + '--icm-reward-scale', + type=float, + default=0.01, + help='scaling factor for intrinsic curiosity reward' + ) + parser.add_argument( + '--icm-forward-loss-weight', + type=float, + default=0.2, + help='weight for the forward model loss in ICM' + ) return parser.parse_args() @@ -101,6 +121,24 @@ def test_dqn(args=get_args()): args.n_step, target_update_freq=args.target_update_freq ) + if args.icm_lr_scale > 0: + feature_net = DQN( + *args.state_shape, args.action_shape, args.device, features_only=True + ) + action_dim = np.prod(args.action_shape) + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.net, + feature_dim, + action_dim, + hidden_sizes=[512], + device=args.device + ) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy = ICMPolicy( + policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale, + args.icm_forward_loss_weight + ).to(args.device) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) @@ -118,7 +156,8 @@ def test_dqn(args=get_args()): train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # log - log_path = os.path.join(args.logdir, args.task, 'dqn') + log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn' + log_path = os.path.join(args.logdir, args.task, log_name) if args.logger == "tensorboard": writer = SummaryWriter(log_path) writer.add_text("args", str(args)) @@ -127,7 +166,7 @@ def test_dqn(args=get_args()): logger = WandbLogger( save_interval=1, project=args.task, - name='dqn', + name=log_name, run_id=args.resume_id, config=args, ) diff --git a/examples/vizdoom/vizdoom_a2c_icm.py b/examples/vizdoom/vizdoom_a2c_icm.py new file mode 100644 index 000000000..99668ce16 --- /dev/null +++ b/examples/vizdoom/vizdoom_a2c_icm.py @@ -0,0 +1,225 @@ +import argparse 
+import os
+import pprint
+
+import numpy as np
+import torch
+from env import Env
+from network import DQN
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import ShmemVectorEnv
+from tianshou.policy import A2CPolicy, ICMPolicy
+from tianshou.trainer import onpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import ActorCritic
+from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='D2_navigation')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--buffer-size', type=int, default=2000000)
+    parser.add_argument('--lr', type=float, default=0.0001)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=300)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--episode-per-collect', type=int, default=10)
+    parser.add_argument('--repeat-per-collect', type=int, default=1)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument('--training-num', type=int, default=10)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+    )
+    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument('--skip-num', type=int, default=4)
+    parser.add_argument('--resume-path', type=str, default=None)
+    parser.add_argument(
+        '--watch',
+        default=False,
+        action='store_true',
+        help='watch the play of pre-trained policy only'
+    )
+    parser.add_argument(
+        '--save-lmp',
+        default=False,
+        action='store_true',
+        help='save lmp file for replaying the whole episode'
+    )
+    parser.add_argument('--save-buffer-name', type=str, default=None)
+    parser.add_argument(
+        '--icm-lr-scale',
+        type=float,
+        default=0.,
+        help='use intrinsic curiosity module with this lr scale'
+    )
+    parser.add_argument(
+        '--icm-reward-scale',
+        type=float,
+        default=0.01,
+        help='scaling factor for intrinsic curiosity reward'
+    )
+    parser.add_argument(
+        '--icm-forward-loss-weight',
+        type=float,
+        default=0.2,
+        help='weight for the forward model loss in ICM'
+    )
+    return parser.parse_args()
+
+
+def test_a2c(args=get_args()):
+    args.cfg_path = f"maps/{args.task}.cfg"
+    args.wad_path = f"maps/{args.task}.wad"
+    args.res = (args.skip_num, 84, 84)
+    env = Env(args.cfg_path, args.frames_stack, args.res)
+    args.state_shape = args.res
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # should be N_FRAMES x H x W
+    print("Observations shape:", args.state_shape)
+    print("Actions shape:", args.action_shape)
+    # make environments
+    train_envs = ShmemVectorEnv(
+        [
+            lambda: Env(args.cfg_path, args.frames_stack, args.res)
+            for _ in range(args.training_num)
+        ]
+    )
+    test_envs = ShmemVectorEnv(
+        [
+            lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
+            for _ in range(min(os.cpu_count() - 1, args.test_num))
+        ]
+    )
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # define model
+    net = DQN(
+        *args.state_shape, args.action_shape, device=args.device, features_only=True
+    )
+    actor = Actor(
+        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
+    )
+    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
+    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+    # define policy
+    dist = torch.distributions.Categorical
+    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
+    if args.icm_lr_scale > 0:
+        feature_net = DQN(
+            *args.state_shape,
+            args.action_shape,
+            device=args.device,
+            features_only=True
+        )
+        action_dim = np.prod(args.action_shape)
+        feature_dim = feature_net.output_dim
+        icm_net = IntrinsicCuriosityModule(
+            feature_net.net,
+            feature_dim,
+            action_dim,
+            hidden_sizes=args.hidden_sizes,
+            device=args.device
+        )
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
+        policy = ICMPolicy(
+            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
+            args.icm_forward_loss_weight
+        ).to(args.device)
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
+        print("Loaded agent from: ", args.resume_path)
+    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
+    # when you have enough RAM
+    buffer = VectorReplayBuffer(
+        args.buffer_size,
+        buffer_num=len(train_envs),
+        ignore_obs_next=True,
+        save_only_last_obs=True,
+        stack_num=args.frames_stack
+    )
+    # collector
+    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+    test_collector = Collector(policy, test_envs, exploration_noise=True)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = TensorboardLogger(writer)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return False
+
+    def watch():
+        # watch agent's performance
+        print("Setup test envs ...")
+        policy.eval()
+        test_envs.seed(args.seed)
+        if args.save_buffer_name:
+            print(f"Generate buffer with size {args.buffer_size}")
+            buffer = VectorReplayBuffer(
+                args.buffer_size,
+                buffer_num=len(test_envs),
+                ignore_obs_next=True,
+                save_only_last_obs=True,
+                stack_num=args.frames_stack
+            )
+            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
+            result = collector.collect(n_step=args.buffer_size)
+            print(f"Save buffer into {args.save_buffer_name}")
+            # Unfortunately, pickle will cause oom with 1M buffer size
+            buffer.save_hdf5(args.save_buffer_name)
+        else:
+            print("Testing agent ...")
+            test_collector.reset()
+            result = test_collector.collect(
+                n_episode=args.test_num, render=args.render
+            )
+        rew = result["rews"].mean()
+        lens = result["lens"].mean() * args.skip_num
+        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
+        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')
+
+    if args.watch:
+        watch()
+        exit(0)
+
+    # test train_collector and start filling replay buffer
+    train_collector.collect(n_step=args.batch_size * args.training_num)
+    # trainer
+    result = onpolicy_trainer(
+        policy,
+        train_collector,
+        test_collector,
+        args.epoch,
+        args.step_per_epoch,
+        args.repeat_per_collect,
+        args.test_num,
+        args.batch_size,
+        episode_per_collect=args.episode_per_collect,
+        stop_fn=stop_fn,
+        save_fn=save_fn,
+        logger=logger,
+        test_in_train=False
+    )
+
+    pprint.pprint(result)
+    watch()
+
+
+if __name__ == '__main__':
+    
test_a2c(get_args()) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py new file mode 100644 index 000000000..1fcd60f09 --- /dev/null +++ b/test/modelbased/test_dqn_icm.py @@ -0,0 +1,204 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import DQNPolicy, ICMPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+    parser.add_argument('--prioritized-replay', action="store_true", default=False)
+    parser.add_argument('--alpha', type=float, default=0.6)
+    parser.add_argument('--beta', type=float, default=0.4)
+    parser.add_argument(
+        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+    )
+    parser.add_argument(
+        '--lr-scale',
+        type=float,
+        default=1.,
+        help='use intrinsic curiosity module with this lr scale'
+    )
+    parser.add_argument(
+        '--reward-scale',
+        type=float,
+        default=0.01,
+        help='scaling factor for intrinsic curiosity reward'
+    )
+    parser.add_argument(
+        '--forward-loss-weight',
+        type=float,
+        default=0.2,
+        help='weight for the forward model loss in ICM'
+    )
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_dqn_icm(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    # you can also use tianshou.env.SubprocVectorEnv
+    train_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)]
+    )
+    # test_envs = gym.make(args.task)
+    test_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)]
+    )
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # Q_param = V_param = {"hidden_sizes": [128]}
+    # model
+    net = Net(
+        args.state_shape,
+        args.action_shape,
+        hidden_sizes=args.hidden_sizes,
+        device=args.device,
+        # dueling=(Q_param, V_param),
+    ).to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    policy = DQNPolicy(
+        net,
+        optim,
+        args.gamma,
+        args.n_step,
+        target_update_freq=args.target_update_freq,
+    )
+    feature_dim = args.hidden_sizes[-1]
+    feature_net = MLP(
+        np.prod(args.state_shape),
+        output_dim=feature_dim,
+        hidden_sizes=args.hidden_sizes[:-1],
+        device=args.device
+    )
+    action_dim = np.prod(args.action_shape)
+    icm_net = IntrinsicCuriosityModule(
+        feature_net,
+        feature_dim,
+        action_dim,
+        hidden_sizes=args.hidden_sizes[-1:],
+        device=args.device
+    ).to(args.device)
+    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
+    policy = ICMPolicy(
+        policy, icm_net, icm_optim, args.lr_scale, args.reward_scale,
+        args.forward_loss_weight
+    )
+    # buffer
+    if args.prioritized_replay:
+        buf = PrioritizedVectorReplayBuffer(
+            args.buffer_size,
+            buffer_num=len(train_envs),
+            alpha=args.alpha,
+            beta=args.beta,
+        )
+    else:
+        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+    # collector
+    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
+    test_collector = Collector(policy, test_envs, exploration_noise=True)
+    # policy.set_eps(1)
+    train_collector.collect(n_step=args.batch_size * args.training_num)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'dqn_icm')
+    writer = SummaryWriter(log_path)
+    logger = TensorboardLogger(writer)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold
+
+    def train_fn(epoch, env_step):
+        # eps annealing, just a demo
+        if env_step <= 10000:
+            policy.set_eps(args.eps_train)
+        elif env_step <= 50000:
+            eps = args.eps_train - (env_step - 10000) / \
+                40000 * (0.9 * args.eps_train)
+            policy.set_eps(eps)
+        else:
+            policy.set_eps(0.1 * args.eps_train)
+
+    def test_fn(epoch, env_step):
+        
policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_pdqn_icm(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + args.seed = 1 + test_dqn_icm(args) + + +if __name__ == '__main__': + test_dqn_icm(get_args()) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py new file mode 100644 index 000000000..1c202f907 --- /dev/null +++ b/test/modelbased/test_ppo_icm.py @@ -0,0 +1,192 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.trainer import onpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import MLP, ActorCritic, DataParallelNet, Net +from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--step-per-collect', type=int, default=2000) + parser.add_argument('--repeat-per-collect', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=20) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + # ppo special + parser.add_argument('--vf-coef', type=float, default=0.5) + parser.add_argument('--ent-coef', type=float, default=0.0) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--gae-lambda', type=float, default=0.95) + parser.add_argument('--rew-norm', type=int, default=0) + parser.add_argument('--norm-adv', type=int, default=0) + parser.add_argument('--recompute-adv', type=int, default=0) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=0) + parser.add_argument( + '--lr-scale', + type=float, + default=1., + help='use intrinsic curiosity module with this lr scale' + ) + parser.add_argument( + '--reward-scale', + type=float, + default=0.01, + help='scaling factor for intrinsic curiosity reward' + ) + parser.add_argument( + '--forward-loss-weight', + type=float, + default=0.2, + help='weight for the forward model loss in ICM' + ) + args = parser.parse_known_args()[0] + return args + + +def test_ppo(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + if torch.cuda.is_available(): + actor = DataParallelNet( + Actor(net, args.action_shape, device=None).to(args.device) + ) + critic = DataParallelNet(Critic(net, device=None).to(args.device)) + else: + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) + actor_critic = ActorCritic(actor, critic) + # orthogonal initialization + for m in actor_critic.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + dist = torch.distributions.Categorical + policy = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + dual_clip=args.dual_clip, + value_clip=args.value_clip, + action_space=env.action_space, + deterministic_eval=True, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv + ) + feature_dim = args.hidden_sizes[-1] + feature_net = MLP( + np.prod(args.state_shape), + output_dim=feature_dim, + hidden_sizes=args.hidden_sizes[:-1], + device=args.device + ) + action_dim = np.prod(args.action_shape) + icm_net = IntrinsicCuriosityModule( + feature_net, + feature_dim, + action_dim, + hidden_sizes=args.hidden_sizes[-1:], + device=args.device + ).to(args.device) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy = ICMPolicy( + policy, icm_net, 
icm_optim, args.lr_scale, args.reward_scale,
+        args.forward_loss_weight
+    )
+    # collector
+    train_collector = Collector(
+        policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+    )
+    test_collector = Collector(policy, test_envs)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'ppo_icm')
+    writer = SummaryWriter(log_path)
+    logger = TensorboardLogger(writer)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold
+
+    # trainer
+    result = onpolicy_trainer(
+        policy,
+        train_collector,
+        test_collector,
+        args.epoch,
+        args.step_per_epoch,
+        args.repeat_per_collect,
+        args.test_num,
+        args.batch_size,
+        step_per_collect=args.step_per_collect,
+        stop_fn=stop_fn,
+        save_fn=save_fn,
+        logger=logger
+    )
+    assert stop_fn(result['best_reward'])
+
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        policy.eval()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        rews, lens = result["rews"], result["lens"]
+        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
+
+
+if __name__ == '__main__':
+    test_ppo()
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index 174762e25..f8c84416d 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -24,6 +24,7 @@
 from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
 from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
 from tianshou.policy.modelbased.psrl import PSRLPolicy
+from tianshou.policy.modelbased.icm import ICMPolicy
 from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
 
 __all__ = [
@@ -50,5 +51,6 @@
     "DiscreteCQLPolicy",
     "DiscreteCRRPolicy",
     "PSRLPolicy",
+    "ICMPolicy",
     "MultiAgentPolicyManager",
 ]
diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py
new file mode 100644
index 000000000..5a723c10b
--- /dev/null
+++ b/tianshou/policy/modelbased/icm.py
@@ -0,0 +1,121 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
+from tianshou.policy import BasePolicy
+from tianshou.utils.net.discrete import IntrinsicCuriosityModule
+
+
+class ICMPolicy(BasePolicy):
+    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.
+
+    :param BasePolicy policy: a base policy to add ICM to.
+    :param IntrinsicCuriosityModule model: the ICM model.
+    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+    :param float lr_scale: the scaling factor for ICM learning.
+    :param float reward_scale: the scaling factor for intrinsic curiosity reward.
+    :param float forward_loss_weight: the weight for forward model loss.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
+ """ + + def __init__( + self, + policy: BasePolicy, + model: IntrinsicCuriosityModule, + optim: torch.optim.Optimizer, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.policy = policy + self.model = model + self.optim = optim + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + + def train(self, mode: bool = True) -> "ICMPolicy": + """Set the module in training mode.""" + self.policy.train(mode) + self.training = mode + self.model.train(mode) + return self + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data by inner policy. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + return self.policy.forward(batch, state, **kwargs) + + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: + return self.policy.exploration_noise(act, batch) + + def set_eps(self, eps: float) -> None: + """Set the eps for epsilon-greedy exploration.""" + if hasattr(self.policy, "set_eps"): + self.policy.set_eps(eps) # type: ignore + else: + raise NotImplementedError() + + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + ) -> Batch: + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more information. + """ + mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) + batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) + batch.rew += to_numpy(mse_loss * self.reward_scale) + return self.policy.process_fn(batch, buffer, indices) + + def post_process_fn( + self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + ) -> None: + """Post-process the data from the provided replay buffer. + + Typical usage is to update the sampling weight in prioritized + experience replay. Used in :meth:`update`. + """ + self.policy.post_process_fn(batch, buffer, indices) + batch.rew = batch.policy.orig_rew # restore original reward + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + res = self.policy.learn(batch, **kwargs) + self.optim.zero_grad() + act_hat = batch.policy.act_hat + act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + inverse_loss = F.cross_entropy(act_hat, act).mean() # type: ignore + forward_loss = batch.policy.mse_loss.mean() + loss = ( + (1 - self.forward_loss_weight) * inverse_loss + + self.forward_loss_weight * forward_loss + ) * self.lr_scale + loss.backward() + self.optim.step() + res.update( + { + "loss/icm": loss.item(), + "loss/icm/forward": forward_loss.item(), + "loss/icm/inverse": inverse_loss.item() + } + ) + return res diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 844691c09..f79a9d1b0 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch import nn -from tianshou.data import Batch +from tianshou.data import Batch, to_torch from tianshou.utils.net.common import MLP @@ -392,3 +392,58 @@ def sample_noise(model: nn.Module) -> bool: m.sample() done = True return done + + +class IntrinsicCuriosityModule(nn.Module): + """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. 
+
+    :param torch.nn.Module feature_net: a self-defined feature_net which outputs a
+        flattened hidden state.
+    :param int feature_dim: output dimension of the feature net.
+    :param int action_dim: dimension of the action space.
+    :param hidden_sizes: hidden layer sizes for forward and inverse models.
+    :param device: device for the module.
+    """
+
+    def __init__(
+        self,
+        feature_net: nn.Module,
+        feature_dim: int,
+        action_dim: int,
+        hidden_sizes: Sequence[int] = (),
+        device: Union[str, torch.device] = "cpu"
+    ) -> None:
+        super().__init__()
+        self.feature_net = feature_net
+        self.forward_model = MLP(
+            feature_dim + action_dim,
+            output_dim=feature_dim,
+            hidden_sizes=hidden_sizes,
+            device=device
+        )
+        self.inverse_model = MLP(
+            feature_dim * 2,
+            output_dim=action_dim,
+            hidden_sizes=hidden_sizes,
+            device=device
+        )
+        self.feature_dim = feature_dim
+        self.action_dim = action_dim
+        self.device = device
+
+    def forward(
+        self, s1: Union[np.ndarray, torch.Tensor],
+        act: Union[np.ndarray, torch.Tensor], s2: Union[np.ndarray,
+                                                        torch.Tensor], **kwargs: Any
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Mapping: s1, act, s2 -> mse_loss, act_hat."""
+        s1 = to_torch(s1, dtype=torch.float32, device=self.device)
+        s2 = to_torch(s2, dtype=torch.float32, device=self.device)
+        phi1, phi2 = self.feature_net(s1), self.feature_net(s2)
+        act = to_torch(act, dtype=torch.long, device=self.device)
+        phi2_hat = self.forward_model(
+            torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1)
+        )
+        mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1)
+        act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1))
+        return mse_loss, act_hat
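
---
A minimal sketch of how the new ICMPolicy wraps an existing policy, condensed from test/modelbased/test_dqn_icm.py above; the CartPole-v0 task, hidden sizes, and learning rate are illustrative assumptions, not prescribed by the patch.

    import gym
    import numpy as np
    import torch

    from tianshou.policy import DQNPolicy, ICMPolicy
    from tianshou.utils.net.common import MLP, Net
    from tianshou.utils.net.discrete import IntrinsicCuriosityModule

    # illustrative environment and hyper-parameters (assumptions)
    env = gym.make("CartPole-v0")
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    hidden_sizes = [128, 128]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # base policy: any BasePolicy works; DQN mirrors the unit test
    net = Net(
        state_shape, action_shape, hidden_sizes=hidden_sizes, device=device
    ).to(device)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(
        net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320
    )

    # ICM: feature_net maps flattened observations to feature_dim-sized features
    feature_dim = hidden_sizes[-1]
    feature_net = MLP(
        np.prod(state_shape), output_dim=feature_dim,
        hidden_sizes=hidden_sizes[:-1], device=device
    )
    icm_net = IntrinsicCuriosityModule(
        feature_net, feature_dim, np.prod(action_shape),
        hidden_sizes=hidden_sizes[-1:], device=device
    ).to(device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=1e-3)

    # wrap the base policy; positional arguments are lr_scale, reward_scale,
    # forward_loss_weight, matching ICMPolicy.__init__ in this patch
    policy = ICMPolicy(policy, icm_net, icm_optim, 1.0, 0.01, 0.2)
    # `policy` is now passed to Collector / trainers exactly like the base policy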