-
Notifications
You must be signed in to change notification settings - Fork 381
/
trex_reward_model.py
440 lines (394 loc) · 19.9 KB
/
trex_reward_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
from copy import deepcopy
from typing import Tuple, Optional, List, Dict
from easydict import EasyDict
import pickle
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from ding.utils import REWARD_MODEL_REGISTRY
from ding.utils import SequenceType
from ding.model.common import FCEncoder
from ding.utils import build_logger
from ding.utils.data import default_collate
from .base_reward_model import BaseRewardModel
from .rnd_reward_model import collect_states
class TrexConvEncoder(nn.Module):
    r"""
    Overview:
        The ``Convolution Encoder`` used in models. Used to encode raw 2-dim observation.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_size_list: SequenceType = [16, 16, 16, 16, 64, 1],
            activation: Optional[nn.Module] = nn.LeakyReLU()
    ) -> None:
        r"""
        Overview:
            Init the Trex Convolution Encoder according to arguments. TrexConvEncoder is different \
                from the ConvEncoder in model.common.encoder, their stride and kernel size parameters \
                are different
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, some ``output size``
            - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``. The first four \
                entries are the conv output channels, the last two are the widths of the two linear layers. \
                The default list is only read, never mutated.
            - activation (:obj:`nn.Module`):
                The type of activation to use in the conv ``layers``,
                if ``None`` then default set to ``nn.LeakyReLU()``
        """
        super(TrexConvEncoder, self).__init__()
        # Honour the documented contract: without this fallback a ``None`` entry would end up
        # in the layer stack and the dry-run forward pass below would fail.
        if activation is None:
            activation = nn.LeakyReLU()
        self.obs_shape = obs_shape
        self.act = activation
        self.hidden_size_list = hidden_size_list

        kernel_size = [7, 5, 3, 3]
        stride = [3, 2, 1, 1]
        layers = []
        input_size = obs_shape[0]  # in_channel
        for i, (kernel, step) in enumerate(zip(kernel_size, stride)):
            layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel, step))
            layers.append(self.act)
            input_size = hidden_size_list[i]
        layers.append(nn.Flatten())
        self.main = nn.Sequential(*layers)

        # The linear head needs the flattened conv output width, which depends on ``obs_shape``.
        flatten_size = self._get_flatten_size()
        self.mid = nn.Sequential(
            nn.Linear(flatten_size, hidden_size_list[-2]), self.act,
            nn.Linear(hidden_size_list[-2], hidden_size_list[-1])
        )

    def _get_flatten_size(self) -> int:
        r"""
        Overview:
            Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to \
                ``nn.Linear``. A random tensor of shape ``(1, *obs_shape)`` is pushed through ``self.main`` \
                without tracking gradients.
        Returns:
            - outputs (:obj:`int`): Number of features per sample after ``self.main``
        """
        test_data = torch.randn(1, *self.obs_shape)
        with torch.no_grad():
            output = self.main(test_data)
        return output.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Return embedding tensor of the env observation
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation
        Returns:
            - outputs (:obj:`torch.Tensor`): Embedding tensor
        """
        x = self.main(x)
        x = self.mid(x)
        return x
class TrexModel(nn.Module):
    """
    Overview:
        Reward network used by T-REX. It wraps an encoder that is applied to a trajectory tensor and \
            exposes helpers to aggregate the encoder output over that trajectory.
    Interfaces:
        ``__init__``, ``cum_return``, ``forward``
    """

    def __init__(self, obs_shape):
        """
        Overview:
            Pick the encoder from ``obs_shape``: an int or a length-1 shape gets ``FCEncoder`` followed \
                by a ``Linear(64, 1)`` head, a length-3 shape gets ``TrexConvEncoder``.
        Arguments:
            - obs_shape: Observation shape, either an int or a sequence
        Raises:
            - KeyError: If ``obs_shape`` is neither of the supported forms
        """
        super(TrexModel, self).__init__()
        is_vector_obs = isinstance(obs_shape, int) or len(obs_shape) == 1
        if is_vector_obs:
            fc_backbone = FCEncoder(obs_shape, [512, 64])
            self.encoder = nn.Sequential(fc_backbone, nn.Linear(64, 1))
            return
        if len(obs_shape) == 3:
            # Conv Encoder
            self.encoder = TrexConvEncoder(obs_shape)
            return
        raise KeyError(
            "not support obs_shape for pre-defined encoder: {}, please customize your own Trex model".
            format(obs_shape)
        )

    def cum_return(self, traj: torch.Tensor, mode: str = 'sum') -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Run the encoder on ``traj`` and return its output together with the absolute values.
        Arguments:
            - traj (:obj:`torch.Tensor`): Input passed to the encoder as-is
            - mode (:obj:`str`): ``'sum'`` reduces both results to their total, ``'batch'`` returns \
                them unreduced
        Returns:
            - (:obj:`Tuple[torch.Tensor, torch.Tensor]`): ``(rewards, abs_rewards)``
        Raises:
            - KeyError: If ``mode`` is not ``'sum'`` or ``'batch'`` (raised after the encoder has run)
        """
        per_step = self.encoder(traj)
        per_step_abs = torch.abs(per_step)
        if mode == 'batch':
            return per_step, per_step_abs
        if mode == 'sum':
            return torch.sum(per_step), torch.sum(per_step_abs)
        raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode))

    def forward(self, traj_i: torch.Tensor, traj_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Compute the summed return of both trajectories and use the pair as logits.
        Arguments:
            - traj_i (:obj:`torch.Tensor`): First trajectory
            - traj_j (:obj:`torch.Tensor`): Second trajectory
        Returns:
            - (:obj:`Tuple[torch.Tensor, torch.Tensor]`): A 2-element tensor ``[return_i, return_j]`` \
                and the sum of the absolute rewards of both trajectories
        """
        return_i, abs_i = self.cum_return(traj_i)
        return_j, abs_j = self.cum_return(traj_j)
        # Stacking two 0-dim tensors along dim 0 yields the 2-element logits vector.
        logits = torch.stack((return_i, return_j), 0)
        return logits, abs_i + abs_j
@REWARD_MODEL_REGISTRY.register('trex')
class TrexRewardModel(BaseRewardModel):
    """
    Overview:
        The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
            ``__init__``, ``_train``,
    Config:
        == ====================  ======  =============  ============================================
        ID Symbol                Type    Default Value  Description
        == ====================  ======  =============  ============================================
        1  ``type``              str     trex           | Reward model register name, refer
                                                        | to registry ``REWARD_MODEL_REGISTRY``
        2  ``learning_rate``     float   0.00001        | learning rate for optimizer
        3  | ``update_per_``     int     100            | Number of updates per collect
           | ``collect``
        4  ``num_trajs``         int     0              | Number of downsampled full trajectories
        5  ``num_snippets``      int     6000           | Number of short subtrajectories to sample
        == ====================  ======  =============  ============================================

    .. note::
        ``__init__`` additionally reads ``reward_model.min_snippet_length`` and
        ``reward_model.max_snippet_length`` from the config; they have no default in ``config`` below.
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='trex',
        # (float) The step size of gradient descent.
        learning_rate=1e-5,
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        update_per_collect=100,
        # (int) Number of downsampled full trajectories.
        num_trajs=0,
        # (int) Number of short subtrajectories to sample.
        num_snippets=6000,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Build the reward network and optimizer, then immediately load the expert data and create \
                the preference training pairs (see ``load_expert_data``).
        Arguments:
            - config (:obj:`EasyDict`): Training config; ``policy.model.obs_shape``, ``exp_name`` and the \
                ``reward_model`` sub-config are read from it
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, stored as ``self.tb_logger``
        """
        super(TrexRewardModel, self).__init__()
        self.cfg = config
        assert device in ["cpu", "cuda"] or "cuda" in device
        self.device = device
        self.tb_logger = tb_logger
        self.reward_model = TrexModel(self.cfg.policy.model.obs_shape)
        self.reward_model.to(self.device)
        self.pre_expert_data = []
        self.train_data = []
        self.expert_data_loader = None
        self.opt = optim.Adam(self.reward_model.parameters(), config.reward_model.learning_rate)
        self.train_iter = 0
        self.learning_returns = []
        self.training_obs = []
        self.training_labels = []
        self.num_trajs = self.cfg.reward_model.num_trajs
        self.num_snippets = self.cfg.reward_model.num_snippets
        # minimum length of the short subtrajectories (snippets) to sample
        self.min_snippet_length = config.reward_model.min_snippet_length
        # maximum length of the short subtrajectories (snippets) to sample
        self.max_snippet_length = config.reward_model.max_snippet_length
        self.l1_reg = 0
        self.data_for_save = {}
        self._logger, self._tb_logger = build_logger(
            path='./{}/log/{}'.format(self.cfg.exp_name, 'trex_reward_model'), name='trex_reward_model'
        )
        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Load the pickled demonstrations and their returns from the experiment directory and build \
                the training pairs from them.
        Effects:
            Sets ``self.pre_expert_data`` (from ``episodes_data.pkl``) and ``self.learning_returns`` \
                (from ``learning_returns.pkl``), then extends ``self.training_obs`` and \
                ``self.training_labels`` via ``create_training_data``.
        """
        # NOTE(security): ``pickle.load`` executes arbitrary code embedded in the file. These files must
        # only ever come from a trusted data-collection run inside ``exp_name``.
        with open(os.path.join(self.cfg.exp_name, 'episodes_data.pkl'), 'rb') as f:
            self.pre_expert_data = pickle.load(f)
        with open(os.path.join(self.cfg.exp_name, 'learning_returns.pkl'), 'rb') as f:
            self.learning_returns = pickle.load(f)
        self.create_training_data()
        self._logger.info("num_training_obs: {}".format(len(self.training_obs)))
        self._logger.info("num_labels: {}".format(len(self.training_labels)))

    def create_training_data(self):
        """
        Overview:
            Sample pairs of (sub)trajectories from two different bins of ``self.pre_expert_data`` and \
                label each pair with ``int(bi <= bj)``, where ``bi``/``bj`` are the bin indices.
        Returns:
            - (:obj:`tuple`): ``(self.training_obs, self.training_labels)``; both lists are appended to, \
                not replaced
        Raises:
            - AssertionError: If fewer than two bins are available
            - ValueError: Propagated from the samplers, e.g. when ``min_snippet_length`` is not strictly \
                smaller than the effective ``max_snippet_length`` (``np.random.randint`` excludes the \
                upper bound)
        """
        num_trajs = self.num_trajs
        num_snippets = self.num_snippets
        min_snippet_length = self.min_snippet_length
        max_snippet_length = self.max_snippet_length

        demo_lengths = [[len(d) for d in demos] for demos in self.pre_expert_data]
        self._logger.info("demo_lengths: {}".format(demo_lengths))
        # Flatten before taking the minimum: bins may hold different numbers of trajectories, and
        # reducing a ragged nested list with ``np.min`` does not work.
        shortest_demo = min(length for lengths in demo_lengths for length in lengths)
        max_snippet_length = min(shortest_demo, max_snippet_length)
        self._logger.info("min snippet length: {}".format(min_snippet_length))
        self._logger.info("max snippet length: {}".format(max_snippet_length))

        # collect training data
        max_traj_length = 0
        num_bins = len(self.pre_expert_data)
        assert num_bins >= 2

        # add full trajs (for use on Enduro)
        si = np.random.randint(6, size=num_trajs)
        sj = np.random.randint(6, size=num_trajs)
        step = np.random.randint(3, 7, size=num_trajs)
        for n in range(num_trajs):
            # pick two random demonstrations
            bi, bj = np.random.choice(num_bins, size=(2, ), replace=False)
            ti = np.random.choice(len(self.pre_expert_data[bi]))
            tj = np.random.choice(len(self.pre_expert_data[bj]))
            # create random partial trajs by finding random start frame and random skip frame
            traj_i = self.pre_expert_data[bi][ti][si[n]::step[n]]  # slice(start,stop,step)
            traj_j = self.pre_expert_data[bj][tj][sj[n]::step[n]]

            label = int(bi <= bj)

            self.training_obs.append((traj_i, traj_j))
            self.training_labels.append(label)
            max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))

        # fixed size snippets with progress prior
        rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets)
        for n in range(num_snippets):
            # pick two random demonstrations
            bi, bj = np.random.choice(num_bins, size=(2, ), replace=False)
            ti = np.random.choice(len(self.pre_expert_data[bi]))
            tj = np.random.choice(len(self.pre_expert_data[bj]))
            demo_i = self.pre_expert_data[bi][ti]
            demo_j = self.pre_expert_data[bj][tj]
            # create random snippets
            # find min length of both demos to ensure we can pick a demo no earlier
            # than that chosen in worse preferred demo
            min_length = min(len(demo_i), len(demo_j))
            if bi < bj:  # pick tj snippet to be later than ti
                ti_start = np.random.randint(min_length - rand_length[n] + 1)
                tj_start = np.random.randint(ti_start, len(demo_j) - rand_length[n] + 1)
            else:  # ti is better so pick later snippet in ti
                tj_start = np.random.randint(min_length - rand_length[n] + 1)
                ti_start = np.random.randint(tj_start, len(demo_i) - rand_length[n] + 1)
            # skip everyother framestack to reduce size
            traj_i = demo_i[ti_start:ti_start + rand_length[n]:2]
            traj_j = demo_j[tj_start:tj_start + rand_length[n]:2]

            max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
            label = int(bi <= bj)
            self.training_obs.append((traj_i, traj_j))
            self.training_labels.append(label)
        self._logger.info(("maximum traj length: {}".format(max_traj_length)))
        return self.training_obs, self.training_labels

    def _save_checkpoint(self) -> None:
        """
        Overview:
            Save the reward network weights to ``<exp_name>/ckpt_reward_model/latest.pth.tar``, creating \
                the directory when needed.
        """
        self._logger.info("check pointing")
        os.makedirs(os.path.join(self.cfg.exp_name, 'ckpt_reward_model'), exist_ok=True)
        torch.save(self.reward_model.state_dict(), os.path.join(self.cfg.exp_name, 'ckpt_reward_model/latest.pth.tar'))

    def _train(self):
        """
        Overview:
            Train the reward network on the stored preference pairs with a cross-entropy loss on the \
                two predicted returns, for ``reward_model.update_per_collect`` passes over the shuffled \
                data, one pair per optimizer step.
        Effects:
            Updates ``self.reward_model`` through ``self.opt`` and writes checkpoints to disk. Does \
                nothing (besides logging) when there is no training data.
        """
        device = self.device
        self._logger.info("device: {}".format(device))
        training_data = list(zip(self.training_obs, self.training_labels))
        if not training_data:
            # e.g. after ``clear_data``; unpacking an empty ``zip`` would otherwise fail.
            self._logger.info("no training data available, skip reward model training")
            return
        loss_criterion = nn.CrossEntropyLoss()

        cum_loss = 0.0
        for epoch in range(self.cfg.reward_model.update_per_collect):  # todo
            np.random.shuffle(training_data)
            for i, ((traj_i, traj_j), label) in enumerate(training_data):
                # traj_i and traj_j are lists of arrays generated by env.step
                traj_i = torch.from_numpy(np.array(traj_i)).float().to(device)
                traj_j = torch.from_numpy(np.array(traj_j)).float().to(device)

                # label is a boolean integer: 0 or 1
                labels = torch.tensor([label]).to(device)

                # forward + backward + zero out gradient + optimize
                outputs, abs_rewards = self.reward_model(traj_i, traj_j)
                outputs = outputs.unsqueeze(0)
                loss = loss_criterion(outputs, labels) + self.l1_reg * abs_rewards
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                # print stats to see if learning
                cum_loss += loss.item()
                if i % 100 == 99:
                    self._logger.info("[epoch {}:{}] loss {}".format(epoch, i, cum_loss))
                    self._logger.info("abs_returns: {}".format(abs_rewards))
                    cum_loss = 0.0
                    # NOTE(review): periodic checkpoint placement reconstructed from a source whose
                    # indentation was lost - confirm it belongs inside this branch.
                    self._save_checkpoint()
        # Make sure ``latest.pth.tar`` reflects the final weights even when the number of pairs
        # is not a multiple of 100.
        self._save_checkpoint()
        self._logger.info("finished training")

    def train(self):
        """
        Overview:
            Train the reward model, then log the predicted return of the first trajectory of every \
                bin next to the recorded return, and the accuracy on the training pairs.
        """
        self._train()
        # print out predicted cumulative returns and actual returns
        sorted_returns = sorted(self.learning_returns, key=lambda s: s[0])
        demonstrations = [
            x for _, x in sorted(zip(self.learning_returns, self.pre_expert_data), key=lambda pair: pair[0][0])
        ]
        with torch.no_grad():
            pred_returns = [self.predict_traj_return(self.reward_model, traj[0]) for traj in demonstrations]
        for i, p in enumerate(pred_returns):
            self._logger.info("{} {} {}".format(i, p, sorted_returns[i][0]))
        first_demo_lengths = [len(d[0]) for d in self.pre_expert_data]
        info = {
            "demo_length": first_demo_lengths,
            "min_snippet_length": self.min_snippet_length,
            "max_snippet_length": min(np.min(first_demo_lengths), self.max_snippet_length),
            "len_num_training_obs": len(self.training_obs),
            "lem_num_labels": len(self.training_labels),
            "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
        }
        self._logger.info(
            "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
        )

    def predict_traj_return(self, net, traj):
        """
        Overview:
            Predict the return of one trajectory as the Python-float sum of the per-step rewards.
        Arguments:
            - net: Network exposing ``cum_return(tensor, mode='batch')``
            - traj: Trajectory convertible with ``np.array``
        Returns:
            - (:obj:`float`): Sum of the predicted per-step rewards
        """
        device = self.device
        with torch.no_grad():
            traj_tensor = torch.from_numpy(np.array(traj)).float().to(device)
            # ``reshape(-1)`` instead of ``squeeze()``: a one-step trajectory must still yield a list,
            # otherwise ``tolist()`` returns a bare float and ``sum`` fails.
            rewards_from_obs = net.cum_return(traj_tensor, mode='batch')[0].reshape(-1).tolist()
        # Summing the Python floats (rather than ``mode='sum'`` on the tensor) is deliberate:
        # the two give results of different precision.
        return sum(rewards_from_obs)  # rewards_from_obs is a list of floats

    def calc_accuracy(self, reward_network, training_inputs, training_outputs):
        """
        Overview:
            Fraction of pairs for which the trajectory with the larger predicted return matches the label.
        Arguments:
            - reward_network: Network called as ``reward_network.forward(traj_i, traj_j)``
            - training_inputs: Sequence of ``(traj_i, traj_j)`` pairs
            - training_outputs: Sequence of labels aligned with ``training_inputs``
        Returns:
            - (:obj:`float`): Accuracy in ``[0, 1]``; ``0.0`` when there are no pairs
        """
        if len(training_inputs) == 0:
            return 0.0
        device = self.device
        num_correct = 0.
        with torch.no_grad():
            for (traj_i, traj_j), label in zip(training_inputs, training_outputs):
                traj_i = torch.from_numpy(np.array(traj_i)).float().to(device)
                traj_j = torch.from_numpy(np.array(traj_j)).float().to(device)

                # forward to get logits
                outputs, abs_return = reward_network.forward(traj_i, traj_j)
                _, pred_label = torch.max(outputs, 0)
                if pred_label.item() == label:
                    num_correct += 1.
        return num_correct / len(training_inputs)

    def pred_data(self, data):
        """
        Overview:
            Compare the recorded and the predicted return of every episode in ``data``.
        Arguments:
            - data: Sequence of episodes; each episode is collated with ``default_collate`` and must \
                provide ``obs`` and ``reward``
        Returns:
            - (:obj:`dict`): ``{'real': [...], 'pred': [...]}`` with one entry per episode
        """
        # Collate each episode once and reuse it for both the observations and the rewards.
        collated = [default_collate(episode) for episode in data]
        res = [torch.sum(episode['reward']).item() for episode in collated]
        pred_returns = [self.predict_traj_return(self.reward_model, episode['obs']) for episode in collated]
        return {'real': res, 'pred': pred_returns}

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward key in each row of a copy of the data.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, with at least \
                ``obs`` and ``action`` keys.
        Returns:
            - (:obj:`List[Dict]`): The copied data whose ``reward`` entries hold the predicted rewards; \
                the input ``data`` keeps its original rewards.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)

        res = collect_states(train_data_augmented)
        res = torch.stack(res).to(self.device)
        with torch.no_grad():
            rewards, abs_rewards = self.reward_model.cum_return(res, mode='batch')

        for item, rew in zip(train_data_augmented, rewards):  # TODO optimise this loop as well ?
            item['reward'] = rew

        return train_data_augmented

    def collect_data(self, data: list) -> None:
        """
        Overview:
            No-op. The training pairs are built from the pickled expert data in ``load_expert_data``.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc); ignored
        """
        pass

    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. \
            This is a side effect function which empties ``self.training_obs`` and \
            ``self.training_labels`` in place.
        """
        self.training_obs.clear()
        self.training_labels.clear()