Commit
Implement TD3+BC for offline RL (#660)
- Implement TD3+BC for offline RL.
- Fix a bug in the trainer where the test reward was not logged because self.env_step is not set in the offline setting.
1 parent 9ce0a55, commit df35718
Showing 10 changed files with 599 additions and 0 deletions.
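The core idea of TD3+BC is to add a behavior-cloning term to the TD3 actor loss so that the learned policy stays close to the actions in the offline dataset. Below is a minimal sketch of that actor objective, assuming generic actor and critic1 torch modules and batch tensors obs/act; it only illustrates the idea and is not the exact code inside TD3BCPolicy:

import torch.nn.functional as F


def td3_bc_actor_loss(actor, critic1, obs, act, alpha=2.5):
    # actions proposed by the current policy for the dataset states
    pi = actor(obs)
    # value of those actions under the first critic
    q = critic1(obs, pi)
    # adaptive weight lambda = alpha / mean|Q|, detached so it acts as a constant
    lmbda = alpha / q.abs().mean().detach()
    # maximize Q while staying close to the dataset actions (behavior cloning)
    return -lmbda * q.mean() + F.mse_loss(pi, act)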
@@ -0,0 +1,227 @@
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3BCPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--expert-data-task", type=str, default="halfcheetah-expert-v2"
    )
    parser.add_argument("--buffer-size", type=int, default=1000000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=256)

    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--exploration-noise", type=float, default=0.1)
    parser.add_argument("--policy-noise", type=float, default=0.2)
    parser.add_argument("--noise-clip", type=float, default=0.5)
    parser.add_argument("--update-actor-freq", type=int, default=2)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--norm-obs", type=int, default=1)

    parser.add_argument("--eval-freq", type=int, default=1)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()


def test_td3_bc():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    if args.norm_obs:
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net_a,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic network
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
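
    # TD3+BC = standard TD3 (twin critics, delayed actor updates controlled by
    # update_actor_freq, target policy smoothing via policy_noise / noise_clip)
    # plus a behavior-cloning term in the actor loss; alpha weights the Q-value
    # term against that BC term. The defaults above (alpha=2.5, policy noise 0.2,
    # noise clip 0.5, actor update every 2 steps) appear to match the values
    # reported in the TD3+BC paper.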
    policy = TD3BCPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "td3_bc"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
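
    # watch(): reload a saved checkpoint (policy.pth from the log directory
    # unless --resume-path is given) and render a single evaluation episode
    # in the raw, non-vectorized env.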
    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device("cpu"))
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)
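
    # Offline training: the policy is updated only from the static D4RL buffer
    # via offline_trainer; test_envs are used purely for evaluation, and the
    # observation-normalization statistics computed from the buffer are pushed
    # into the test envs so evaluation sees the same normalization.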
    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        if args.norm_obs:
            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
            test_envs.set_obs_rms(obs_rms)
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")


if __name__ == "__main__":
    test_td3_bc()
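A typical run pairs the evaluation task with the matching D4RL dataset; the script filename below is hypothetical (the committed path is not shown on this page), and the d4rl package is required for load_buffer_d4rl:

python td3_bc_d4rl_example.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2 --seed 0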