Implement TD3+BC for offline RL (#660)
- implement TD3+BC for offline RL;
- fix a bug in the trainer where the test reward is not logged because self.env_step is not set in the offline setting;
nuance1979 authored Jun 6, 2022
1 parent 9ce0a55 commit df35718
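For context: TD3+BC (Fujimoto & Gu, 2021, arXiv:2106.06860) keeps TD3's twin-critic updates unchanged and only modifies the actor objective, scaling the Q term by a data-dependent factor and adding a behavior-cloning regularizer toward the dataset actions. The sketch below illustrates that objective only; it is not the `TD3BCPolicy` code added by this commit, and `actor`/`critic` are illustrative stand-ins for modules mapping batched tensors to batched tensors.

```python
import torch.nn.functional as F


def td3_bc_actor_loss(actor, critic, obs, act, alpha=2.5):
    pi = actor(obs)                          # policy actions for dataset states
    q = critic(obs, pi)                      # critic's value of those actions
    lmbda = alpha / q.abs().mean().detach()  # normalize away the Q-value scale
    # maximize Q while regressing toward the behavior (dataset) actions
    return -lmbda * q.mean() + F.mse_loss(pi, act)
```

The critic update is the standard clipped-double-Q TD error, which is why the example script below only needs `--alpha` on top of the usual TD3 hyperparameters.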
Showing 10 changed files with 599 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
- [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -129,6 +129,11 @@ Imitation
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.TD3BCPolicy
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.DiscreteBCQPolicy
    :members:
    :undoc-members:
1 change: 1 addition & 0 deletions docs/index.rst
@@ -31,6 +31,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning <https://arxiv.org/pdf/2106.06860.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
18 changes: 18 additions & 0 deletions examples/offline/README.md
@@ -35,6 +35,24 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |

### TD3+BC

| Environment | Dataset | TD3+BC | Parameters |
| --------------------- | --------------------- | --------------- | -------------------------------------------------------- |
| HalfCheetah-v2 | halfcheetah-expert-v2 | 11788.25 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
| HalfCheetah-v2 | halfcheetah-medium-v2 | 5741.13 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |

#### Observation normalization

Following the original paper, we use observation normalization by default. You can turn it off by setting `--norm-obs 0`. The differences are small but consistent; the normalization itself is sketched right after the table.

| Dataset | w/ norm-obs | w/o norm-obs |
| :--- | :--- | :--- |
| halfcheetah-medium-v2 | 5741.13 | 5724.41 |
| halfcheetah-expert-v2 | 11788.25 | 11665.77 |
| walker2d-medium-v2 | 4051.76 | 3985.59 |
| walker2d-expert-v2 | 5068.15 | 5027.75 |
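
For reference, here is a minimal sketch of the normalization, assuming the running mean and variance are estimated once over the entire offline dataset and then reused unchanged by the evaluation environments; the helper that actually does this is `normalize_all_obs_in_replay_buffer` in `examples/offline/utils.py`, and `normalize_obs` below is only illustrative.

```python
import numpy as np


def normalize_obs(obs, mean, var, eps=np.finfo(np.float32).eps.item()):
    # The same (mean, var) pair is applied to obs, obs_next, and the
    # observations produced by the test environments at evaluation time.
    return (obs - mean) / np.sqrt(var + eps)
```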

## Discrete control

For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
227 changes: 227 additions & 0 deletions examples/offline/d4rl_td3_bc.py
@@ -0,0 +1,227 @@
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3BCPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--expert-data-task", type=str, default="halfcheetah-expert-v2"
    )
    parser.add_argument("--buffer-size", type=int, default=1000000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=256)

    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--exploration-noise", type=float, default=0.1)
    parser.add_argument("--policy-noise", type=float, default=0.2)
    parser.add_argument("--noise-clip", type=float, default=0.5)
    parser.add_argument("--update-actor-freq", type=int, default=2)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--norm-obs", type=int, default=1)

    parser.add_argument("--eval-freq", type=int, default=1)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()


def test_td3_bc():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    if args.norm_obs:
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net_a,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic network
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    policy = TD3BCPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "td3_bc"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device("cpu"))
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        if args.norm_obs:
            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
            test_envs.set_obs_rms(obs_rms)
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")


if __name__ == "__main__":
    test_td3_bc()
19 changes: 19 additions & 0 deletions examples/offline/utils.py
@@ -1,8 +1,12 @@
from typing import Tuple

import d4rl
import gym
import h5py
import numpy as np

from tianshou.data import ReplayBuffer
from tianshou.utils import RunningMeanStd


def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
@@ -27,3 +31,18 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
            obs_next=dataset["next_observations"]
        )
    return buffer


def normalize_all_obs_in_replay_buffer(
    replay_buffer: ReplayBuffer
) -> Tuple[ReplayBuffer, RunningMeanStd]:
    # compute obs mean and var
    obs_rms = RunningMeanStd()
    obs_rms.update(replay_buffer.obs)
    _eps = np.finfo(np.float32).eps.item()
    # normalize obs
    replay_buffer._meta["obs"] = (replay_buffer.obs -
                                  obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
    replay_buffer._meta["obs_next"] = (replay_buffer.obs_next -
                                       obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
    return replay_buffer, obs_rms
