forked from tinkoff-ai/CORL
-
Notifications
You must be signed in to change notification settings - Fork 22
/
awac.py
504 lines (429 loc) · 16.9 KB
/
awac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional
import wandb
from tqdm import trange
# A sampled mini-batch of tensors, in the order ReplayBuffer.sample returns
# them: [states, actions, rewards, next_states, dones].
TensorBatch = List[torch.Tensor]
@dataclass
class TrainConfig:
    """Command-line configurable settings for an offline AWAC run.

    After construction, ``name`` is extended with the environment name and an
    8-character random hex id, and ``checkpoints_path`` (when given) is nested
    under that final run name.
    """

    # Weights & Biases project to log into
    project: str = "CORL"
    # Weights & Biases run group
    group: str = "AWAC-D4RL"
    # base run name; the env name and a random id are appended automatically
    name: str = "AWAC"
    # D4RL task used both as the training dataset and the evaluation env
    env_name: str = "halfcheetah-medium-expert-v2"
    # width of every hidden layer in the actor and the critics
    hidden_dim: int = 256
    # Adam learning rate shared by the actor and both critics
    learning_rate: float = 3e-4
    # discount factor for TD targets
    gamma: float = 0.99
    # Polyak coefficient for soft target-critic updates
    tau: float = 5e-3
    # temperature of the advantage weights in the actor loss; trades off
    # behaviour cloning against Q-value maximization
    awac_lambda: float = 1.0
    # total number of gradient updates performed during training
    num_train_ops: int = 1_000_000
    # mini-batch size for each update
    batch_size: int = 256
    # capacity of the replay buffer (must be able to hold the whole dataset)
    buffer_size: int = 2_000_000
    # rescale dataset rewards before training (as in IQL)
    normalize_reward: bool = False
    # run an evaluation every this many training steps
    eval_frequency: int = 1000
    # number of episodes per evaluation
    n_test_episodes: int = 10
    # directory for checkpoints; None disables checkpointing
    checkpoints_path: Optional[str] = None
    # ask PyTorch to use only deterministic algorithms
    deterministic_torch: bool = False
    # seed for the training RNGs
    seed: int = 42
    # seed for the evaluation environment
    test_seed: int = 69
    # torch device used for training
    device: str = "cuda"

    def __post_init__(self):
        # The first 8 chars of the hex form equal the first 8 of str(uuid).
        run_id = uuid.uuid4().hex[:8]
        self.name = f"{self.name}-{self.env_name}-{run_id}"
        if self.checkpoints_path is None:
            return
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
class ReplayBuffer:
    """Fixed-capacity transition storage, preallocated as float32 tensors on ``device``.

    Only bulk loading of a D4RL-style dataset is supported; ``add_transition``
    is a placeholder for future online fine-tuning.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a new float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Fill an empty buffer with a D4RL ``qlearning_dataset``-style dict.

        Reads the keys ``observations``, ``actions``, ``rewards``,
        ``next_observations`` and ``terminals``; rewards and terminals get a
        trailing axis so they match the (N, 1) storage.

        Raises:
            ValueError: if the buffer already holds data, or the dataset is
                larger than the buffer capacity.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            [states, actions, rewards, next_states, dones]; rewards and dones
            have shape (batch_size, 1).

        Raises:
            ValueError: if the buffer holds no transitions yet.
        """
        n_available = min(self._size, self._pointer)
        if n_available == 0:
            # np.random.randint(0, 0) would fail with an opaque "high <= 0".
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, n_available, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(self):
        # Use this method to add new data into the replay buffer during fine-tuning.
        # I left it unimplemented since now we do not do fine-tuning.
        raise NotImplementedError
class Actor(nn.Module):
    """Diagonal-Gaussian policy for continuous actions.

    The mean comes from a four-layer ReLU MLP over the state. The log standard
    deviation is one learned vector shared by all states, clamped to
    ``[min_log_std, max_log_std]`` before use.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int,
        min_log_std: float = -20.0,
        max_log_std: float = 2.0,
        min_action: float = -1.0,
        max_action: float = 1.0,
    ):
        super().__init__()
        self._mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        self._log_std = nn.Parameter(torch.zeros(action_dim, dtype=torch.float32))
        self._min_log_std = min_log_std
        self._max_log_std = max_log_std
        self._min_action = min_action
        self._max_action = max_action

    def _get_policy(self, state: torch.Tensor) -> torch.distributions.Distribution:
        """Build the Normal action distribution for a batch of states."""
        mean = self._mlp(state)
        log_std = self._log_std.clamp(self._min_log_std, self._max_log_std)
        policy = torch.distributions.Normal(mean, log_std.exp())
        return policy

    def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Log-density of ``action``, summed over the last (action) dim with keepdim."""
        policy = self._get_policy(state)
        log_prob = policy.log_prob(action).sum(-1, keepdim=True)
        return log_prob

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a reparameterized action clamped to the action bounds.

        Returns:
            (action, log_prob). The log-prob is evaluated at the clamped
            action, not the raw sample.
        """
        policy = self._get_policy(state)
        action = policy.rsample()
        action.clamp_(self._min_action, self._max_action)
        log_prob = policy.log_prob(action).sum(-1, keepdim=True)
        return action, log_prob

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str) -> np.ndarray:
        """Choose an action for a single unbatched state and return it as numpy.

        In training mode the action is sampled from the policy; in eval mode
        the policy mean is used. Autograd is disabled here because in eval mode
        ``policy.mean`` is a differentiable MLP output, and ``.numpy()`` on a
        tensor that requires grad raises.
        """
        state_t = torch.tensor(state[None], dtype=torch.float32, device=device)
        policy = self._get_policy(state_t)
        if self._mlp.training:
            action_t = policy.sample()
        else:
            action_t = policy.mean
        action = action_t[0].cpu().numpy()
        return action
class Critic(nn.Module):
    """State-action value network computing Q(s, a) from concatenated inputs.

    Three hidden ReLU layers of width ``hidden_dim`` followed by a scalar head.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int):
        super().__init__()
        input_dim = state_dim + action_dim
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(2):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, 1))
        # Same attribute name and layer order as before, so state_dict keys
        # (``_mlp.0.weight``, ...) and seeded initialization are unchanged.
        self._mlp = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return Q-values with a trailing dimension of size 1."""
        joint = torch.cat((state, action), dim=-1)
        return self._mlp(joint)
def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally, so both modules must share a layout.
    """
    keep = 1 - tau
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = keep * tgt.data + tau * src.data
        tgt.data.copy_(blended)
class AdvantageWeightedActorCritic:
    """Advantage-Weighted Actor-Critic (AWAC) learner.

    Holds an actor, two online critics and Polyak-averaged target copies of
    them. Each ``update`` call runs one critic TD step, one actor step, and a
    soft update of both target critics.

    Args:
        actor: policy called as ``actor(states) -> (actions, log_probs)`` and
            providing ``log_prob(states, actions)``.
        actor_optimizer: optimizer over the actor's parameters.
        critic_1, critic_2: Q-networks called as ``critic(states, actions)``.
        critic_1_optimizer, critic_2_optimizer: optimizers for the critics.
        gamma: discount factor for the TD target.
        tau: Polyak coefficient for the target-critic soft update.
        awac_lambda: temperature dividing the advantage before ``exp``.
        exp_adv_max: upper clamp on the exponentiated advantage weights.
    """

    def __init__(
        self,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        critic_1: nn.Module,
        critic_1_optimizer: torch.optim.Optimizer,
        critic_2: nn.Module,
        critic_2_optimizer: torch.optim.Optimizer,
        gamma: float = 0.99,
        tau: float = 5e-3,  # Polyak coefficient for the soft target update
        awac_lambda: float = 1.0,
        exp_adv_max: float = 100.0,
    ):
        self._actor = actor
        self._actor_optimizer = actor_optimizer

        self._critic_1 = critic_1
        self._critic_1_optimizer = critic_1_optimizer
        # Target critics start as exact copies of the online critics.
        self._target_critic_1 = deepcopy(critic_1)

        self._critic_2 = critic_2
        self._critic_2_optimizer = critic_2_optimizer
        self._target_critic_2 = deepcopy(critic_2)

        self._gamma = gamma
        self._tau = tau
        self._awac_lambda = awac_lambda
        self._exp_adv_max = exp_adv_max

    def _actor_loss(self, states, actions):
        """Advantage-weighted negative log-likelihood of the dataset actions."""
        with torch.no_grad():
            # V(s) is estimated as the min-Q at one action sampled from the
            # current policy; the online critics (not targets) are used here.
            pi_action, _ = self._actor(states)
            v = torch.min(
                self._critic_1(states, pi_action), self._critic_2(states, pi_action)
            )

            q = torch.min(
                self._critic_1(states, actions), self._critic_2(states, actions)
            )
            adv = q - v
            # exp() can overflow to inf for large advantages; clamp_max caps it.
            weights = torch.clamp_max(
                torch.exp(adv / self._awac_lambda), self._exp_adv_max
            )

        action_log_prob = self._actor.log_prob(states, actions)
        loss = (-action_log_prob * weights).mean()
        return loss

    def _critic_loss(self, states, actions, rewards, dones, next_states):
        """Sum of both critics' MSE against a clipped double-Q TD target."""
        with torch.no_grad():
            next_actions, _ = self._actor(next_states)

            q_next = torch.min(
                self._target_critic_1(next_states, next_actions),
                self._target_critic_2(next_states, next_actions),
            )
            # Terminal transitions (dones == 1) do not bootstrap.
            q_target = rewards + self._gamma * (1.0 - dones) * q_next

        q1 = self._critic_1(states, actions)
        q2 = self._critic_2(states, actions)

        q1_loss = nn.functional.mse_loss(q1, q_target)
        q2_loss = nn.functional.mse_loss(q2, q_target)
        loss = q1_loss + q2_loss
        return loss

    def _update_critic(self, states, actions, rewards, dones, next_states):
        """Take one optimizer step on both critics; return the summed loss."""
        loss = self._critic_loss(states, actions, rewards, dones, next_states)
        self._critic_1_optimizer.zero_grad()
        self._critic_2_optimizer.zero_grad()
        loss.backward()
        self._critic_1_optimizer.step()
        self._critic_2_optimizer.step()
        return loss.item()

    def _update_actor(self, states, actions):
        """Take one optimizer step on the actor; return its loss."""
        loss = self._actor_loss(states, actions)
        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()
        return loss.item()

    def update(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one full AWAC training step on ``batch``.

        Args:
            batch: [states, actions, rewards, next_states, dones].

        Returns:
            {"critic_loss": float, "actor_loss": float}.
        """
        states, actions, rewards, next_states, dones = batch
        critic_loss = self._update_critic(states, actions, rewards, dones, next_states)
        actor_loss = self._update_actor(states, actions)

        soft_update(self._target_critic_1, self._critic_1, self._tau)
        soft_update(self._target_critic_2, self._critic_2, self._tau)

        result = {"critic_loss": critic_loss, "actor_loss": actor_loss}
        return result

    def state_dict(self) -> Dict[str, Any]:
        """Return everything needed to resume training.

        This covers the online networks, the target critics and the optimizer
        states.
        """
        return {
            "actor": self._actor.state_dict(),
            "critic_1": self._critic_1.state_dict(),
            "critic_2": self._critic_2.state_dict(),
            "target_critic_1": self._target_critic_1.state_dict(),
            "target_critic_2": self._target_critic_2.state_dict(),
            "actor_optimizer": self._actor_optimizer.state_dict(),
            "critic_1_optimizer": self._critic_1_optimizer.state_dict(),
            "critic_2_optimizer": self._critic_2_optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore state saved by ``state_dict``.

        Older checkpoints that only contain ``actor``/``critic_1``/``critic_2``
        are still accepted: the target critics are then re-synced from the
        loaded critics (instead of keeping their stale initial weights), and
        optimizer states are left untouched.
        """
        self._actor.load_state_dict(state_dict["actor"])
        self._critic_1.load_state_dict(state_dict["critic_1"])
        self._critic_2.load_state_dict(state_dict["critic_2"])

        self._target_critic_1.load_state_dict(
            state_dict.get("target_critic_1", self._critic_1.state_dict())
        )
        self._target_critic_2.load_state_dict(
            state_dict.get("target_critic_2", self._critic_2.state_dict())
        )

        optimizers = (
            ("actor_optimizer", self._actor_optimizer),
            ("critic_1_optimizer", self._critic_1_optimizer),
            ("critic_2_optimizer", self._critic_2_optimizer),
        )
        for key, optimizer in optimizers:
            if key in state_dict:
                optimizer.load_state_dict(state_dict[key])
def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every RNG the training script relies on.

    Seeds ``env`` and its action space (when an env is given), sets
    ``PYTHONHASHSEED``, seeds NumPy's global RNG, Python's ``random`` and
    PyTorch, and toggles PyTorch deterministic-algorithm mode.

    NOTE(review): PYTHONHASHSEED is presumably only honoured by interpreters
    started after this call; confirm if hash determinism matters in-process.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)
def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and ``eps``-padded std over axis 0.

    Adding ``eps`` to the std keeps later division finite for constant
    features.
    """
    return states.mean(axis=0), states.std(axis=0) + eps
def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std``, with broadcasting."""
    centered = states - mean
    return centered / std
def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
) -> gym.Env:
    """Wrap ``env`` so each observation becomes ``(obs - state_mean) / state_std``."""

    def _standardize(state):
        shifted = state - state_mean
        return shifted / state_std

    wrapped = gym.wrappers.TransformObservation(env, _standardize)
    return wrapped
@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: Actor, device: str, n_episodes: int, seed: int
) -> np.ndarray:
    """Roll out ``n_episodes`` episodes with the actor in eval mode.

    The env is re-seeded with ``seed`` on every call. The actor's previous
    train/eval mode is restored afterwards, even if a rollout raises.

    Returns:
        Array of undiscounted episode returns, one per episode.
    """
    env.seed(seed)
    was_training = actor.training
    actor.eval()
    episode_rewards = []
    try:
        for _ in range(n_episodes):
            state, done = env.reset(), False
            episode_reward = 0.0
            while not done:
                action = actor.act(state, device)
                state, reward, done, _ = env.step(action)
                episode_reward += reward
            episode_rewards.append(episode_reward)
    finally:
        # Without this, an exception mid-rollout would leave the actor stuck
        # in eval mode for the rest of training.
        actor.train(was_training)
    return np.asarray(episode_rewards)
def return_reward_range(dataset, max_episode_steps):
    """Return (min, max) undiscounted episode return in a flat transition dataset.

    An episode ends at a truthy terminal flag or after ``max_episode_steps``
    steps. A trailing unfinished segment counts toward the length sanity
    check but not toward the returns. If no episode completes, ``min`` of an
    empty list raises ``ValueError``.
    """
    episode_returns = []
    step_counts = []
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        step_counts.append(running_len)
        running_return, running_len = 0.0, 0
    # Count the leftover partial episode so every step is accounted for.
    step_counts.append(running_len)
    assert sum(step_counts) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)
def modify_reward(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset["rewards"]`` in place, depending on the task family.

    For halfcheetah/hopper/walker2d, rewards are divided by the dataset's
    episode-return range and multiplied by ``max_episode_steps``. For antmaze,
    1 is subtracted from every reward. Other tasks are left unchanged.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        lowest, highest = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= highest - lowest
        dataset["rewards"] *= max_episode_steps
        return
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0
def wandb_init(config: dict) -> None:
    """Start a wandb run whose project, group and name are read from ``config``.

    The whole ``config`` dict is attached as the run config, and each call
    gets a fresh random run id.
    """
    init_kwargs = {
        "config": config,
        "project": config["project"],
        "group": config["group"],
        "name": config["name"],
        "id": str(uuid.uuid4()),
    }
    wandb.init(**init_kwargs)
    wandb.run.save()
@pyrallis.wrap()
def train(config: TrainConfig):
    """Train AWAC offline on a D4RL dataset, evaluating and checkpointing periodically.

    Steps:
      1. Build the env and load its D4RL dataset.
      2. Optionally rescale rewards, then normalize observations.
      3. Fill the replay buffer.
      4. Construct the actor, the twin critics and their Adam optimizers.
      5. Run ``config.num_train_ops`` updates, logging losses to wandb every
         step.
      6. Every ``config.eval_frequency`` steps, evaluate the actor and (if
         ``config.checkpoints_path`` is set) save a checkpoint.

    Args:
        config: run configuration. ``train`` is invoked with no arguments
            under ``__main__``, so the ``pyrallis.wrap()`` decorator is what
            supplies this argument.
    """
    env = gym.make(config.env_name)
    # Seeding happens before any network is built, so weight initialization
    # is driven by config.seed.
    set_seed(config.seed, env, deterministic_torch=config.deterministic_torch)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    dataset = d4rl.qlearning_dataset(env)

    if config.normalize_reward:
        modify_reward(dataset, config.env_name)

    # Observation statistics come from the dataset. The eval env is wrapped
    # with the same statistics below, so the actor sees consistently scaled
    # inputs during training and evaluation.
    state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    actor_critic_kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": config.hidden_dim,
    }

    # Construction order (actor, critic_1, critic_2) fixes which random draws
    # each network's initialization consumes after set_seed; keep it stable
    # for reproducibility.
    actor = Actor(**actor_critic_kwargs)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.learning_rate)
    critic_1 = Critic(**actor_critic_kwargs)
    critic_2 = Critic(**actor_critic_kwargs)
    critic_1.to(config.device)
    critic_2.to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=config.learning_rate)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=config.learning_rate)

    awac = AdvantageWeightedActorCritic(
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic_1=critic_1,
        critic_1_optimizer=critic_1_optimizer,
        critic_2=critic_2,
        critic_2_optimizer=critic_2_optimizer,
        gamma=config.gamma,
        tau=config.tau,
        awac_lambda=config.awac_lambda,
    )
    wandb_init(asdict(config))

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        # Persist the final (post-__post_init__) config next to the checkpoints.
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    for t in trange(config.num_train_ops, ncols=80):
        batch = replay_buffer.sample(config.batch_size)
        # The buffer was allocated on config.device, so this is effectively a
        # no-op safeguard.
        batch = [b.to(config.device) for b in batch]
        update_result = awac.update(batch)
        wandb.log(update_result, step=t)
        if (t + 1) % config.eval_frequency == 0:
            eval_scores = eval_actor(
                env, actor, config.device, config.n_test_episodes, config.test_seed
            )

            wandb.log({"eval_score": eval_scores.mean()}, step=t)
            # NOTE(review): presumably the D4RL env's get_normalized_score is
            # reachable through the TransformObservation wrapper via attribute
            # forwarding; confirm for the gym version in use.
            if hasattr(env, "get_normalized_score"):
                normalized_eval_scores = env.get_normalized_score(eval_scores) * 100.0
                wandb.log(
                    {"d4rl_normalized_score": normalized_eval_scores.mean()}, step=t
                )

            if config.checkpoints_path is not None:
                torch.save(
                    awac.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )

    wandb.finish()
if __name__ == "__main__":
    # train() is called with no arguments. Its TrainConfig argument is supplied
    # by the pyrallis.wrap() decorator (presumably parsed from the command
    # line / config file).
    train()