Implement BCQPolicy and offline_bcq example (#480)
This PR implements BCQPolicy, which can be used to train an offline agent in environments with continuous action spaces. An experimental result on 'halfcheetah-expert-v1', a d4rl environment for offline reinforcement learning, is provided. Example usage is in examples/offline/offline_bcq.py.
Showing 14 changed files with 1,003 additions and 1 deletion.
@@ -0,0 +1,28 @@
# Offline

In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected beforehand by some behavior policy; the agent does not interact with the environment at all during training.

Once the dataset is collected, it is not changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent. Refer to [d4rl](https://github.com/rail-berkeley/d4rl) for how to install and use its datasets; a minimal loading example follows.
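The sketch below mirrors how `offline_bcq.py` (added in this commit) obtains the transitions; importing `d4rl` registers its environments with gym, and `qlearning_dataset()` returns the transitions as a dict of NumPy arrays:

```python
import d4rl  # noqa: F401  -- importing d4rl registers its environments with gym
import gym

env = gym.make('halfcheetah-expert-v1')
dataset = d4rl.qlearning_dataset(env)
# keys: 'observations', 'actions', 'rewards', 'terminals', 'next_observations'
print(dataset['rewards'].size, "transitions")
```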
## Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `buffer` parameter of `offline_trainer`, as sketched below. `offline_bcq.py` is a full example of offline RL on a d4rl dataset.
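A condensed sketch of that flow, following the example script; `policy`, `test_collector`, `save_fn`, and `logger` are assumed to be set up as in `offline_bcq.py`, and the numeric values are that script's defaults:

```python
from tianshou.data import Batch, ReplayBuffer
from tianshou.trainer import offline_trainer

# copy the d4rl transitions into a Tianshou ReplayBuffer, one step at a time
buffer = ReplayBuffer(dataset['rewards'].size)
for i in range(dataset['rewards'].size):
    buffer.add(
        Batch(
            obs=dataset['observations'][i],
            act=dataset['actions'][i],
            rew=dataset['rewards'][i],
            done=dataset['terminals'][i],
            obs_next=dataset['next_observations'][i],
        )
    )

# the fixed buffer replaces any environment interaction during training;
# policy, test_collector, save_fn and logger come from the full example below
result = offline_trainer(
    policy, buffer, test_collector,
    200, 5000, 10, 256,  # epoch, step_per_epoch, test_num, batch_size
    save_fn=save_fn, logger=logger,
)
```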
To train an agent with the BCQ algorithm:

```bash
python offline_bcq.py --task halfcheetah-expert-v1
```

After 1M steps:

![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)

`halfcheetah-expert-v1` is a MuJoCo environment. The hyperparameters are set similarly to those of the off-policy algorithms on the MuJoCo environments.

## Results

| Environment           | BCQ             |
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint

import d4rl
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='halfcheetah-expert-v1')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=1000000)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300])
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=1 / 35)

    parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750])
    # default to 2 * action_dim
    parser.add_argument('--latent-dim', type=int)
parser.add_argument("--gamma", default=0.99) | ||
parser.add_argument("--tau", default=0.005) | ||
# Weighting for Clipped Double Q-learning in BCQ | ||
parser.add_argument("--lmbda", default=0.75) | ||
# Max perturbation hyper-parameter for BCQ | ||
parser.add_argument("--phi", default=0.05) | ||
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only',
    )
    return parser.parse_args()


def test_bcq():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # perturbation network
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Perturbation(
        net_a, max_action=args.max_action, device=args.device, phi=args.phi
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    # vae
    # output_dim = 0, so the last Module in the encoder is ReLU
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())

    policy = BCQPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        vae,
        vae_optim,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
    log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, 'policy.pth')

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device('cpu'))
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        dataset = d4rl.qlearning_dataset(env)
        dataset_size = dataset['rewards'].size

        print("dataset_size", dataset_size)
        replay_buffer = ReplayBuffer(dataset_size)

        for i in range(dataset_size):
            replay_buffer.add(
                Batch(
                    obs=dataset['observations'][i],
                    act=dataset['actions'][i],
                    rew=dataset['rewards'][i],
                    done=dataset['terminals'][i],
                    obs_next=dataset['next_observations'][i],
                )
            )
        print("dataset loaded")
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_fn=save_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_bcq()