Add Intrinsic Curiosity Module (#503)
nuance1979 authored Jan 14, 2022
1 parent a2d76d1 commit a59d96d
Showing 11 changed files with 849 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -43,6 +43,7 @@
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)

Here are Tianshou's other features:

5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -137,6 +137,11 @@ Model-based
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.ICMPolicy
:members:
:undoc-members:
:show-inheritance:

Multi-agent
-----------

1 change: 1 addition & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -135,3 +135,4 @@ Huayu
Strens
Ornstein
Uhlenbeck
mse
43 changes: 41 additions & 2 deletions examples/atari/atari_dqn.py
@@ -11,8 +11,10 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
@@ -55,6 +57,24 @@ def get_args():
help='watch the play of pre-trained policy only'
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
)
parser.add_argument(
'--icm-reward-scale',
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
)
parser.add_argument(
'--icm-forward-loss-weight',
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
)
return parser.parse_args()


@@ -101,6 +121,24 @@ def test_dqn(args=get_args()):
args.n_step,
target_update_freq=args.target_update_freq
)
if args.icm_lr_scale > 0:
feature_net = DQN(
*args.state_shape, args.action_shape, args.device, features_only=True
)
action_dim = np.prod(args.action_shape)
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.net,
feature_dim,
action_dim,
hidden_sizes=[512],
device=args.device
)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
policy = ICMPolicy(
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
args.icm_forward_loss_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
@@ -118,7 +156,8 @@ def test_dqn(args=get_args()):
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
log_path = os.path.join(args.logdir, args.task, log_name)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
@@ -127,7 +166,7 @@
logger = WandbLogger(
save_interval=1,
project=args.task,
name='dqn',
name=log_name,
run_id=args.resume_id,
config=args,
)
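The diff above follows a simple integration pattern: the base policy is built first, and only when `--icm-lr-scale` is positive is it wrapped in an `ICMPolicy` that adds an intrinsic reward computed by an `IntrinsicCuriosityModule`. Below is a minimal sketch of that pattern, factored into a hypothetical helper (`wrap_with_icm` is illustrative, not part of this commit; the hidden sizes and learning rate mirror the defaults used above):

# Hypothetical helper illustrating the wrapping pattern from the diff above;
# not part of this commit.
import numpy as np
import torch

from tianshou.policy import ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def wrap_with_icm(policy, feature_net, action_shape, args):
    """Wrap a base policy with an Intrinsic Curiosity Module when enabled."""
    if args.icm_lr_scale <= 0:
        return policy  # ICM disabled: keep the base policy unchanged
    icm_net = IntrinsicCuriosityModule(
        feature_net.net,              # shared feature extractor (CNN torso)
        feature_net.output_dim,       # dimension of the extracted features
        int(np.prod(action_shape)),   # flattened action dimension
        hidden_sizes=[512],
        device=args.device,
    )
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    return ICMPolicy(
        policy, icm_net, icm_optim, args.icm_lr_scale,
        args.icm_reward_scale, args.icm_forward_loss_weight,
    ).to(args.device)

With such a helper, the scripts in this commit would reduce to `policy = wrap_with_icm(policy, feature_net, args.action_shape, args)` once the base policy and feature network are built.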
225 changes: 225 additions & 0 deletions examples/vizdoom/vizdoom_a2c_icm.py
@@ -0,0 +1,225 @@
import argparse
import os
import pprint

import numpy as np
import torch
from env import Env
from network import DQN
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import A2CPolicy, ICMPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='D2_navigation')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=2000000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=300)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--episode-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--skip-num', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
)
parser.add_argument(
'--save-lmp',
default=False,
action='store_true',
help='save lmp file for replay whole episode'
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
)
parser.add_argument(
'--icm-reward-scale',
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
)
parser.add_argument(
'--icm-forward-loss-weight',
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
)
return parser.parse_args()


def test_a2c(args=get_args()):
args.cfg_path = f"maps/{args.task}.cfg"
args.wad_path = f"maps/{args.task}.wad"
args.res = (args.skip_num, 84, 84)
env = Env(args.cfg_path, args.frames_stack, args.res)
args.state_shape = args.res
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res)
for _ in range(args.training_num)
]
)
test_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
for _ in range(min(os.cpu_count() - 1, args.test_num))
]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = DQN(
*args.state_shape, args.action_shape, device=args.device, features_only=True
)
actor = Actor(
net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
)
critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
# define policy
dist = torch.distributions.Categorical
policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
if args.icm_lr_scale > 0:
feature_net = DQN(
*args.state_shape,
args.action_shape,
device=args.device,
features_only=True
)
action_dim = np.prod(args.action_shape)
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.net,
feature_dim,
action_dim,
hidden_sizes=args.hidden_sizes,
device=args.device
)
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
policy = ICMPolicy(
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
args.icm_forward_loss_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
args.buffer_size,
buffer_num=len(train_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return False

def watch():
# watch agent's performance
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size,
buffer_num=len(test_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
lens = result["lens"].mean() * args.skip_num
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

if args.watch:
watch()
exit(0)

# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = onpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.repeat_per_collect,
args.test_num,
args.batch_size,
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
test_in_train=False
)

pprint.pprint(result)
watch()


if __name__ == '__main__':
test_a2c(get_args())
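For reference, the `--icm-reward-scale` and `--icm-forward-loss-weight` flags correspond to the trade-offs described in the ICM paper (Pathak et al., 2017): the forward model's prediction error serves as the intrinsic reward, and the training loss mixes the inverse and forward objectives. The sketch below is conceptual only; the exact code lives in `tianshou/policy/modelbased/icm.py` and may differ in details:

# Conceptual sketch of how ICM combines its objectives (Pathak et al., 2017);
# illustrative only, not the exact code added in this commit.
import torch
import torch.nn.functional as F


def icm_objectives(phi_next, pred_phi_next, pred_action_logits, action,
                   reward_scale=0.01, forward_loss_weight=0.2):
    """Return (intrinsic_reward, loss) for one batch of ICM predictions."""
    # Forward model: predict next-state features; the (detached) prediction
    # error is the curiosity bonus added to the environment reward.
    forward_error = 0.5 * (pred_phi_next - phi_next).pow(2).sum(dim=1)
    intrinsic_reward = reward_scale * forward_error.detach()
    # Inverse model: predict which action connected the two observations.
    inverse_loss = F.cross_entropy(pred_action_logits, action, reduction='none')
    # --icm-forward-loss-weight trades off the two terms (beta in the paper).
    loss = ((1.0 - forward_loss_weight) * inverse_loss
            + forward_loss_weight * forward_error).mean()
    return intrinsic_reward, loss

Setting `--icm-lr-scale` to a positive value enables this path in both example scripts; leaving it at its default of 0 trains the base policy without curiosity.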