Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Intrinsic Curiosity Module #503

Merged
merged 17 commits into from
Jan 14, 2022
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)

Here is Tianshou's other features:

Expand Down
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ Model-based
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.ICMPolicy
:members:
:undoc-members:
:show-inheritance:

Multi-agent
-----------

Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

Expand Down
43 changes: 41 additions & 2 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
Expand Down Expand Up @@ -55,6 +57,24 @@ def get_args():
help='watch the play of pre-trained policy only'
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
)
parser.add_argument(
'--icm-reward-scale',
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
)
parser.add_argument(
'--icm-forward-loss-weight',
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
)
return parser.parse_args()


Expand Down Expand Up @@ -101,6 +121,24 @@ def test_dqn(args=get_args()):
args.n_step,
target_update_freq=args.target_update_freq
)
if args.icm_lr_scale > 0:
feature_net = DQN(
*args.state_shape, args.action_shape, args.device, features_only=True
)
action_dim = np.prod(args.action_shape)
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.net,
feature_dim,
action_dim,
hidden_sizes=[512],
device=args.device
)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
policy = ICMPolicy(
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
args.icm_forward_loss_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
Expand All @@ -118,7 +156,8 @@ def test_dqn(args=get_args()):
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
log_path = os.path.join(args.logdir, args.task, log_name)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
Expand All @@ -127,7 +166,7 @@ def test_dqn(args=get_args()):
logger = WandbLogger(
save_interval=1,
project=args.task,
name='dqn',
name=log_name,
run_id=args.resume_id,
config=args,
)
Expand Down
204 changes: 204 additions & 0 deletions test/modelbased/test_dqn_icm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import argparse
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy, ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument(
'--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument(
'--lr-scale',
type=float,
default=1.,
help='use intrinsic curiosity module with this lr scale'
)
parser.add_argument(
'--reward-scale',
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
)
parser.add_argument(
'--forward-loss-weight',
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
)
args = parser.parse_known_args()[0]
return args


def test_dqn_icm(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# Q_param = V_param = {"hidden_sizes": [128]}
# model
net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
# dueling=(Q_param, V_param),
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net,
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq,
)
feature_dim = args.hidden_sizes[-1]
feature_net = MLP(
np.prod(args.state_shape),
output_dim=feature_dim,
hidden_sizes=args.hidden_sizes[:-1],
device=args.device
)
action_dim = np.prod(args.action_shape)
icm_net = IntrinsicCuriosityModule(
feature_net,
feature_dim,
action_dim,
hidden_sizes=args.hidden_sizes[-1:],
device=args.device
).to(args.device)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
policy = ICMPolicy(
policy, icm_net, icm_optim, args.lr_scale, args.reward_scale,
args.forward_loss_weight
)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
args.buffer_size,
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn_icm')
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def train_fn(epoch, env_step):
# eps annnealing, just a demo
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
update_per_step=args.update_per_step,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_pdqn_icm(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
args.seed = 1
test_dqn_icm(args)


if __name__ == '__main__':
test_dqn_icm(get_args())
Loading