Skip to content

Commit

Permalink
Implement BCQPolicy and offline_bcq example (#480)
Browse files Browse the repository at this point in the history
This PR implements BCQPolicy, which could be used to train an offline agent in the environment of continuous action space. An experimental result 'halfcheetah-expert-v1' is provided, which is a d4rl environment (for Offline Reinforcement Learning).
Example usage is in the examples/offline/offline_bcq.py.
  • Loading branch information
thkkk authored Nov 22, 2021
1 parent 94d3b27 commit 5c5a3db
Show file tree
Hide file tree
Showing 14 changed files with 1,003 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
Expand Down
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ Imitation
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.BCQPolicy
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.DiscreteBCQPolicy
:members:
:undoc-members:
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
Expand Down
28 changes: 28 additions & 0 deletions examples/offline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Offline

In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore.

Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.

## Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.

To train an agent with BCQ algorithm:

```bash
python offline_bcq.py --task halfcheetah-expert-v1
```

After 1M steps:

![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)

`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the offpolicy algorithms in mujoco environment.

## Results

| Environment | BCQ |
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |

241 changes: 241 additions & 0 deletions examples/offline/offline_bcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint

import d4rl
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='halfcheetah-expert-v1')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=1000000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument("--start-timesteps", type=int, default=10000)
parser.add_argument('--epoch', type=int, default=200)
parser.add_argument('--step-per-epoch', type=int, default=5000)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1 / 35)

parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750])
# default to 2 * action_dim
parser.add_argument('--latent-dim', type=int)
parser.add_argument("--gamma", default=0.99)
parser.add_argument("--tau", default=0.005)
# Weighting for Clipped Double Q-learning in BCQ
parser.add_argument("--lmbda", default=0.75)
# Max perturbation hyper-parameter for BCQ
parser.add_argument("--phi", default=0.05)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only',
)
return parser.parse_args()


def test_bcq():
args = get_args()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
print("device:", args.device)
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
print("Max_action", args.max_action)

# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)

# model
# perturbation network
net_a = MLP(
input_dim=args.state_dim + args.action_dim,
output_dim=args.action_dim,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = Perturbation(
net_a, max_action=args.max_action, device=args.device, phi=args.phi
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

# vae
# output_dim = 0, so the last Module in the encoder is ReLU
vae_encoder = MLP(
input_dim=args.state_dim + args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
if not args.latent_dim:
args.latent_dim = args.action_dim * 2
vae_decoder = MLP(
input_dim=args.state_dim + args.latent_dim,
output_dim=args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
vae = VAE(
vae_encoder,
vae_decoder,
hidden_dim=args.vae_hidden_sizes[-1],
latent_dim=args.latent_dim,
max_action=args.max_action,
device=args.device,
).to(args.device)
vae_optim = torch.optim.Adam(vae.parameters())

policy = BCQPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
vae,
vae_optim,
device=args.device,
gamma=args.gamma,
tau=args.tau,
lmbda=args.lmbda,
)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def watch():
if args.resume_path is None:
args.resume_path = os.path.join(log_path, 'policy.pth')

policy.load_state_dict(
torch.load(args.resume_path, map_location=torch.device('cpu'))
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)

if not args.watch:
dataset = d4rl.qlearning_dataset(env)
dataset_size = dataset['rewards'].size

print("dataset_size", dataset_size)
replay_buffer = ReplayBuffer(dataset_size)

for i in range(dataset_size):
replay_buffer.add(
Batch(
obs=dataset['observations'][i],
act=dataset['actions'][i],
rew=dataset['rewards'][i],
done=dataset['terminals'][i],
obs_next=dataset['next_observations'][i],
)
)
print("dataset loaded")
# trainer
result = offline_trainer(
policy,
replay_buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_fn=save_fn,
logger=logger,
)
pprint.pprint(result)
else:
watch()

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
test_bcq()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion test/base/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
SubprocVectorEnv(env_fns),
ShmemVectorEnv(env_fns),
]
if has_ray():
if has_ray() and sys.platform == "linux":
venv += [RayVectorEnv(env_fns)]
for v in venv:
v.seed(0)
Expand Down
Empty file added test/offline/__init__.py
Empty file.
Loading

0 comments on commit 5c5a3db

Please sign in to comment.