-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
129 lines (106 loc) · 4.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from argparse import ArgumentParser
from model import Retriever
from util import interact
from config import RetrieverConfig as Config
from processing import prepare, get_loaders, load_data, encode_plus
from ignite.metrics import Loss, Recall
from ignite.utils import setup_logger
from ignite.engine import Engine, Events
from sentence_transformers import SentenceTransformer
from torch.utils.tensorboard import SummaryWriter
from metrics import RecallAt
from sklearn.neighbors import BallTree
import pickle
import numpy
import os
def run():
    """Train the Retriever model, evaluate it at the end of every epoch,
    save a checkpoint, then drop into an interactive query session.

    All hyperparameters come from Config (possibly overridden via CLI in
    the __main__ block). Side effects: writes TensorBoard scalars, reads/
    writes a local data-loader cache file, and saves Config.checkpoint.
    """
    writer = SummaryWriter()
    device = Config.device
    model = Retriever()
    print(f'Initializing model on {device}')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR)
    loss_fn = torch.nn.L1Loss().to(device)
    print('Creating sentence transformer')
    # The sentence encoder is frozen — it only supplies fixed embeddings,
    # so its parameters are excluded from gradient computation.
    encoder = SentenceTransformer(Config.sentence_transformer).to(device)
    for parameter in encoder.parameters():
        parameter.requires_grad = False
    print('Loading data')
    # Preparing the loaders is expensive, so the result is cached on disk.
    # NOTE(review): pickle.load is only safe because this cache is produced
    # locally by the branch below — never load an untrusted '_full_dump'.
    cache_path = '_full_dump'
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as pin:
            train_loader, train_utts, val_loader, val_utts = pickle.load(pin)
    else:
        data = load_data(Config.data_source)
        train_loader, train_utts, val_loader, val_utts = get_loaders(data, encoder, Config.batch_size)
        with open(cache_path, 'wb') as pout:
            pickle.dump((train_loader, train_utts, val_loader, val_utts), pout, protocol=-1)

    def train_step(engine, batch):
        """One optimization step: pull the prediction toward the positive
        target and push it away from the negative sample.

        Returns the scalar combined loss for Ignite's engine.state.output.
        """
        model.train()
        optimizer.zero_grad()
        x, not_ys, y = batch
        yhat = model(x[0])
        loss = loss_fn(yhat, y)
        # Distance to the negative sample is *subtracted*, so minimizing the
        # total pushes yhat away from not_ys[0], scaled by negative_weight.
        # NOTE(review): the combined loss can go negative — confirm intended.
        gains = loss_fn(not_ys[0], yhat) * Config.negative_weight
        loss -= gains
        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_step(engine, batch):
        """Forward pass only; yields (prediction, target) for the metrics."""
        model.eval()
        with torch.no_grad():
            x, _, y = batch
            yhat = model(x[0])
            return yhat, y

    trainer = Engine(train_step)
    trainer.logger = setup_logger('trainer')
    evaluator = Engine(eval_step)
    evaluator.logger = setup_logger('evaluator')
    # Nearest-neighbour index over the training-utterance embedding space.
    # Assumes train_utts keys are fixed-length numeric tuples (embeddings)
    # — TODO confirm against get_loaders. Used by RecallAt and interact().
    latent_space = BallTree(numpy.array(list(train_utts.keys())))
    l1 = Loss(loss_fn)
    recall = RecallAt(latent_space)
    recall.attach(evaluator, 'recall')
    l1.attach(evaluator, 'l1')

    @trainer.on(Events.ITERATION_COMPLETED(every=1000))
    def log_training(engine):
        """Every 1000 iterations: print batch loss / lr, log to TensorBoard."""
        batch_loss = engine.state.output
        lr = optimizer.param_groups[0]['lr']
        e = engine.state.epoch
        n = engine.state.max_epochs
        i = engine.state.iteration
        print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))
        writer.add_scalar('Training/loss', batch_loss, i)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        """Run the evaluator at the end of each epoch and report metrics.

        NOTE(review): this evaluates val_loader but labels the output
        'Training Results' — either the loader or the label looks wrong;
        confirm intent before changing.
        """
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print(f"Training Results - Epoch: {engine.state.epoch} "
              f" L1: {metrics['l1']:.2f} "
              f" R@1: {metrics['r1']:.2f} "
              f" R@3: {metrics['r3']:.2f} "
              f" R@10: {metrics['r10']:.2f} ")
        for metric, value in metrics.items():
            writer.add_scalar(f'Training/{metric}', value, engine.state.epoch)

    # Disabled handler (decorator commented out in the original): would
    # repeat the epoch-end evaluation under a 'Validation' label. Kept for
    # reference; restore the decorator to re-enable.
    #@trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print(f"Validation Results - Epoch: {engine.state.epoch} "
              f"L1: {metrics['l1']:.2f} "
              f" R@10: {metrics['r10']:.2f} ")
        for metric, value in metrics.items():
            writer.add_scalar(f'Validation/{metric}', value, engine.state.epoch)

    trainer.run(train_loader, max_epochs=Config.max_epochs)
    torch.save(model.state_dict(), Config.checkpoint)
    print(f'Saved checkpoint at {Config.checkpoint}')
    # Hand off to an interactive query loop against the trained model.
    interact(model, encoder, latent_space, train_utts)
if __name__ == '__main__':
paper = ArgumentParser()
for element in dir(Config):
if element.startswith('__'):
continue
paper.add_argument(f'--{element}', default=getattr(Config, element), type=type(getattr(Config, element)))
args = paper.parse_args()
for arg, value in vars(args).items():
setattr(Config, arg, value)
run()