# coding: utf-8
# @email: enoche.chow@gmail.com
r"""
################################
"""
import os
import itertools
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt
from time import time
from logging import getLogger
from utils.utils import get_local_time, early_stopping, dict2str
from utils.topk_evaluator import TopKEvaluator
class AbstractTrainer(object):
r"""Trainer Class is used to manage the training and evaluation processes of recommender system models.
AbstractTrainer is an abstract class in which the fit() and evaluate() method should be implemented according
to different training and evaluation strategies.
"""
def __init__(self, config, model):
self.config = config
self.model = model
def fit(self, train_data):
r"""Train the model based on the train data.
"""
        raise NotImplementedError('Method [fit] should be implemented.')
def evaluate(self, eval_data):
r"""Evaluate the model based on the eval data.
"""
        raise NotImplementedError('Method [evaluate] should be implemented.')
class Trainer(AbstractTrainer):
r"""The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines common
functions for training and evaluation processes of most recommender system models, including fit(), evaluate(),
and some other features helpful for model training and evaluation.
Generally speaking, this class can serve most recommender system models, If the training process of the model is to
simply optimize a single loss without involving any complex training strategies, such as adversarial learning,
pre-training and so on.
Initializing the Trainer needs two parameters: `config` and `model`. `config` records the parameters information
for controlling training and evaluation, such as `learning_rate`, `epochs`, `eval_step` and so on.
More information can be found in [placeholder]. `model` is the instantiated object of a Model Class.
"""
def __init__(self, config, model, mg=False):
super(Trainer, self).__init__(config, model)
self.logger = getLogger()
self.learner = config['learner']
self.learning_rate = config['learning_rate']
self.epochs = config['epochs']
self.eval_step = min(config['eval_step'], self.epochs)
self.stopping_step = config['stopping_step']
self.clip_grad_norm = config['clip_grad_norm']
self.valid_metric = config['valid_metric'].lower()
self.valid_metric_bigger = config['valid_metric_bigger']
self.test_batch_size = config['eval_batch_size']
self.device = config['device']
        self.weight_decay = 0.0
        if config['weight_decay'] is not None:
            wd = config['weight_decay']
            # parse strings such as '1e-4' with float() rather than eval()
            self.weight_decay = float(wd) if isinstance(wd, str) else wd
self.req_training = config['req_training']
self.start_epoch = 0
self.cur_step = 0
tmp_dd = {}
for j, k in list(itertools.product(config['metrics'], config['topk'])):
tmp_dd[f'{j.lower()}@{k}'] = 0.0
self.best_valid_score = -1
self.best_valid_result = tmp_dd
        self.best_test_upon_valid = dict(tmp_dd)  # independent copy, so the two result dicts are not aliases
self.train_loss_dict = dict()
self.optimizer = self._build_optimizer()
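        # Exponential LR decay: the multiplier is rate ** (epoch / steps), so with the
        # example values from the original comment, [0.96, 50], the learning rate
        # shrinks by 4% every 50 epochs: lr(epoch) = base_lr * 0.96 ** (epoch / 50).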
        lr_scheduler = config['learning_rate_scheduler']  # [decay_rate, decay_steps]; assumes decay_steps != 0
        fac = lambda epoch: lr_scheduler[0] ** (epoch / lr_scheduler[1])
scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)
self.lr_scheduler = scheduler
self.eval_type = config['eval_type']
self.evaluator = TopKEvaluator(config)
self.item_tensor = None
self.tot_item_num = None
self.mg = mg
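        # hyperparameters for the mirror-gradient style update in _train_epoch (used when mg=True):
        # alpha1 scales the first (descent) step, alpha2 the reversed second step,
        # and beta sets how often (every beta-th batch) the two-pass update runs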
self.alpha1 = config['alpha1']
self.alpha2 = config['alpha2']
self.beta = config['beta']
def _build_optimizer(self):
r"""Init the Optimizer
Returns:
torch.optim: the optimizer
"""
if self.learner.lower() == 'adam':
optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'sgd':
optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'adagrad':
optimizer = optim.Adagrad(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
elif self.learner.lower() == 'rmsprop':
optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            self.logger.warning('Received unrecognized optimizer, falling back to the default Adam optimizer')
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
return optimizer
def _train_epoch(self, train_data, epoch_idx, loss_func=None):
r"""Train the model in an epoch
Args:
train_data (DataLoader): The train data.
epoch_idx (int): The current epoch id.
loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
:attr:`self.model.calculate_loss`. Defaults to ``None``.
Returns:
            float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
            multiple parts and the model returns these parts instead of their sum, it will return a tuple whose
            entries are the per-part sums over all batches.
"""
if not self.req_training:
return 0.0, []
self.model.train()
loss_func = loss_func or self.model.calculate_loss
total_loss = None
loss_batches = []
for batch_idx, interaction in enumerate(train_data):
self.optimizer.zero_grad()
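            # keep an untouched copy of the batch; the mirror-gradient branch below
            # (only taken when self.mg is True) evaluates the loss on it a second time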
second_inter = interaction.clone()
losses = loss_func(interaction)
if isinstance(losses, tuple):
loss = sum(losses)
loss_tuple = tuple(per_loss.item() for per_loss in losses)
total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
else:
loss = losses
total_loss = losses.item() if total_loss is None else total_loss + losses.item()
if self._check_nan(loss):
self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
return loss, torch.tensor(0.0)
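            # Mirror-gradient style update (a hedged reading of the code below): every
            # beta-th batch first takes a normal step scaled by alpha1, then a second,
            # negatively scaled (-alpha2) step on the cloned batch, nudging the
            # parameters away from the freshly reached point.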
if self.mg and batch_idx % self.beta == 0:
first_loss = self.alpha1 * loss
first_loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
losses = loss_func(second_inter)
if isinstance(losses, tuple):
loss = sum(losses)
else:
loss = losses
if self._check_nan(loss):
self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
return loss, torch.tensor(0.0)
second_loss = -1 * self.alpha2 * loss
second_loss.backward()
else:
loss.backward()
if self.clip_grad_norm:
clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
self.optimizer.step()
loss_batches.append(loss.detach())
# for test
#if batch_idx == 0:
# break
return total_loss, loss_batches
def _valid_epoch(self, valid_data):
r"""Valid the model with valid data
Args:
valid_data (DataLoader): the valid data
Returns:
float: valid score
dict: valid result
"""
valid_result = self.evaluate(valid_data)
        valid_score = valid_result[self.valid_metric] if self.valid_metric else valid_result['ndcg@20']  # evaluator keys are lowercase
return valid_score, valid_result
    def _check_nan(self, loss):
        # return an explicit bool instead of the implicit None on the non-NaN path
        return bool(torch.isnan(loss))
    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
        train_loss_output = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
        if isinstance(losses, tuple):
            # append rather than overwrite, so the timing prefix is kept
            train_loss_output += ', '.join('train_loss%d: %.4f' % (idx + 1, loss) for idx, loss in enumerate(losses))
        else:
            train_loss_output += 'train loss: %.4f' % losses
        return train_loss_output + ']'
def fit(self, train_data, valid_data=None, test_data=None, saved=False, verbose=True):
r"""Train the model based on the train data and the valid data.
Args:
train_data (DataLoader): the train data
valid_data (DataLoader, optional): the valid data, default: None.
If it's None, the early_stopping is invalid.
test_data (DataLoader, optional): None
verbose (bool, optional): whether to write training and evaluation information to logger, default: True
saved (bool, optional): whether to save the model parameters, default: True
Returns:
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
"""
for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
self.model.pre_epoch_processing()
train_loss, _ = self._train_epoch(train_data, epoch_idx)
if torch.is_tensor(train_loss):
                # NaN loss detected: _train_epoch returned the offending tensor, stop training
break
#for param_group in self.optimizer.param_groups:
# print('======lr: ', param_group['lr'])
self.lr_scheduler.step()
self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
training_end_time = time()
train_loss_output = \
self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
post_info = self.model.post_epoch_processing()
if verbose:
self.logger.info(train_loss_output)
if post_info is not None:
self.logger.info(post_info)
            # eval: to guarantee the reported test result comes from the best validation model, set self.eval_step == 1
if (epoch_idx + 1) % self.eval_step == 0:
valid_start_time = time()
valid_score, valid_result = self._valid_epoch(valid_data)
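                # early_stopping (utils.utils) tracks the best score and the number of consecutive
                # non-improving evaluations; judging from its use here, stop_flag fires after
                # `stopping_step` such evaluations and update_flag is set when the score improves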
self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
valid_score, self.best_valid_score, self.cur_step,
max_step=self.stopping_step, bigger=self.valid_metric_bigger)
valid_end_time = time()
valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \
(epoch_idx, valid_end_time - valid_start_time, valid_score)
valid_result_output = 'valid result: \n' + dict2str(valid_result)
# test
_, test_result = self._valid_epoch(test_data)
if verbose:
self.logger.info(valid_score_output)
self.logger.info(valid_result_output)
self.logger.info('test result: \n' + dict2str(test_result))
if update_flag:
update_output = '██ ' + self.config['model'] + '--Best validation results updated!!!'
if verbose:
self.logger.info(update_output)
self.best_valid_result = valid_result
self.best_test_upon_valid = test_result
if stop_flag:
stop_output = '+++++Finished training, best eval result in epoch %d' % \
(epoch_idx - self.cur_step * self.eval_step)
if verbose:
self.logger.info(stop_output)
break
return self.best_valid_score, self.best_valid_result, self.best_test_upon_valid
@torch.no_grad()
def evaluate(self, eval_data, is_test=False, idx=0):
r"""Evaluate the model based on the eval data.
Returns:
dict: eval result, key is the eval metric and value in the corresponding metric value
"""
self.model.eval()
# batch full users
batch_matrix_list = []
for batch_idx, batched_data in enumerate(eval_data):
# predict: interaction without item ids
scores = self.model.full_sort_predict(batched_data)
masked_items = batched_data[1]
# mask out pos items
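            # masked_items holds (row_indices, col_indices) of already-interacted items;
            # assigning a large negative score keeps them out of the top-k ranking below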
scores[masked_items[0], masked_items[1]] = -1e10
# rank and get top-k
_, topk_index = torch.topk(scores, max(self.config['topk']), dim=-1) # nusers x topk
batch_matrix_list.append(topk_index)
return self.evaluator.evaluate(batch_matrix_list, eval_data, is_test=is_test, idx=idx)
def plot_train_loss(self, show=True, save_path=None):
r"""Plot the train loss in each epoch
Args:
show (bool, optional): whether to show this figure, default: True
save_path (str, optional): the data path to save the figure, default: None.
If it's None, it will not be saved.
"""
epochs = list(self.train_loss_dict.keys())
epochs.sort()
values = [float(self.train_loss_dict[epoch]) for epoch in epochs]
plt.plot(epochs, values)
plt.xticks(epochs)
plt.xlabel('Epoch')
plt.ylabel('Loss')
        # save before show(): some backends clear the figure once the window is closed
        if save_path:
            plt.savefig(save_path)
        if show:
            plt.show()
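# A minimal usage sketch (illustrative only: `config`, `model`, and the data
# loaders are assumed to be built by the surrounding MMRec framework):
#
#     trainer = Trainer(config, model)
#     best_score, best_valid, best_test = trainer.fit(
#         train_loader, valid_data=valid_loader, test_data=test_loader)
#     trainer.plot_train_loss(show=False, save_path='train_loss.png')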