td3_bc.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""TD3+BC Example.
This is a self-contained example of an offline RL TD3+BC training script.
The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
    log_metrics,
    make_environment,
    make_loss_module,
    make_offline_replay_buffer,
    make_optimizer,
    make_td3_agent,
)

@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"):  # noqa: F821
    set_gym_backend(cfg.env.library).set()

    # Create logger
    exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="td3bc_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )
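    # (The wandb_kwargs above are only consumed when logger.backend is "wandb";
    # other logger backends ignore them.)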

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    device = cfg.network.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
    device = torch.device(device)

    # Create env
    eval_env = make_environment(
        cfg,
        logger=logger,
    )

    # Create replay buffer
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
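    # The buffer is filled from a pre-collected dataset (a D4RL dataset in the
    # stock config), so training below never interacts with the environment;
    # eval_env is only used for evaluation rollouts.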

    # Create agent
    model, _ = make_td3_agent(cfg, eval_env, device)

    # Create loss
    loss_module, target_net_updater = make_loss_module(cfg.optim, model)

    # Create optimizer
    optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

    gradient_steps = cfg.optim.gradient_steps
    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps
    delayed_updates = cfg.optim.policy_update_delay
    update_counter = 0
    pbar = tqdm.tqdm(range(gradient_steps))

    # Training loop
    start_time = time.time()
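    # Each step: sample a batch from the offline buffer and update the critic;
    # the actor and the target networks are refreshed only every
    # `delayed_updates` steps (TD3's delayed policy update).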
    for i in pbar:
        # Update actor every `delayed_updates` steps
        update_counter += 1
        update_actor = update_counter % delayed_updates == 0

        # Sample from replay buffer
        sampled_tensordict = replay_buffer.sample()
        if sampled_tensordict.device != device:
            sampled_tensordict = sampled_tensordict.to(device)
        else:
            sampled_tensordict = sampled_tensordict.clone()
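        # Both branches yield a copy of the stored batch: `.to(device)` copies
        # across devices, `.clone()` copies on the same device. Presumably this
        # keeps any in-place writes by the loss module out of the replay
        # buffer's storage.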

        # Compute loss
        q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)

        # Update critic
        optimizer_critic.zero_grad()
        q_loss.backward()
        optimizer_critic.step()
        to_log = {"q_loss": q_loss.item()}

        # Update actor
        if update_actor:
            actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # Update target params
            target_net_updater.step()

            to_log["actor_loss"] = actor_loss.item()
            to_log.update(actorloss_metadata)

        # Evaluation
        if i % evaluation_interval == 0:
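            # The rollout uses the policy's deterministic action (no
            # exploration noise), and no_grad avoids building a graph.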
            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
                eval_td = eval_env.rollout(
                    max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                )
                eval_env.apply(dump_video)
            eval_reward = eval_td["next", "reward"].sum(1).mean().item()
            to_log["evaluation_reward"] = eval_reward

        if logger is not None:
            log_metrics(logger, to_log, i)

    if not eval_env.is_closed:
        eval_env.close()
    pbar.close()
    torchrl_logger.info(f"Training time: {time.time() - start_time}")

if __name__ == "__main__":
    main()
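
# Example launch (a sketch: both override keys below are read by main() above;
# the actual defaults live in the config.yaml next to this script):
#
#   python td3_bc.py optim.gradient_steps=100000 logger.backend=wandb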