# cg_trainer.py

import os
import copy
import torch
import random
import math
import gc
import time
import pyhocon
import warnings
import numpy as np
import torch.nn as nn
import torch.optim as optim
from utils import *
from constants import *
from models import *
from scripts import benchmark_model
from transformers import *
from data import load_data
from data.base import *
from scorer import evaluate
from argparse import ArgumentParser
PRETRAINED_MODEL = None
#PRETRAINED_MODEL = PRETRAINED_LIGHTWEIGHT_CNN_TEXT_MODEL
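
# NOTE: The wildcard imports above are assumed to supply the helpers used
# below (Ontology, join, RunningAverage, get_n_params, get_n_tunable_params,
# prepare_configs, BASE_ONTOLOGY_DIR, USE_TRAINDEV, DATASETS, and the dataset
# constants BC5CDR_C, BC5CDR_D, NCBI_D) from utils, constants, models, and
# data.base; none of them are defined in this file.
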
def train(configs):
    dataset_name = configs['dataset']

    # Load the dataset
    start_time = time.time()
    train, dev, test, ontology = load_data(configs['dataset'], configs['use_synthetic_train'])
    if dataset_name in [BC5CDR_C, BC5CDR_D, NCBI_D]:
        # These datasets ship split-specific ontologies for train and dev;
        # the full ontology loaded above serves as the test-split ontology.
        train_ontology = Ontology(join(BASE_ONTOLOGY_DIR, f'{dataset_name}_train.json'))
        dev_ontology = Ontology(join(BASE_ONTOLOGY_DIR, f'{dataset_name}_dev.json'))
        test_ontology = ontology
    else:
        train_ontology = dev_ontology = test_ontology = ontology
    print('Prepared the dataset (%s seconds)' % (time.time() - start_time))

    # Load model
    if configs['lightweight']:
        print('class LightWeightModel')
        model = LightWeightModel(configs)
    elif not configs['online_kd']:
        print('class DualBertEncodersModel')
        model = DualBertEncodersModel(configs)
    else:
        print('class EncodersModelWithOnlineKD')
        model = EncodersModelWithOnlineKD(configs)
    print('Prepared the model (Nb params: {})'.format(get_n_params(model)), flush=True)
    print(f'Nb tunable params: {get_n_tunable_params(model)}')
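
    # Three encoder variants (descriptions inferred from the class names and
    # how they are used later in this file):
    #   LightWeightModel           - used when configs['lightweight'] is set
    #   DualBertEncodersModel      - the default dual BERT bi-encoder
    #   EncodersModelWithOnlineKD  - dual encoders trained with online knowledge
    #                                distillation into early-exit child branches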

    # Reload a pretrained model (if one is configured and exists on disk)
    if PRETRAINED_MODEL and os.path.exists(PRETRAINED_MODEL):
        print('Reload the pretrained model')
        checkpoint = torch.load(PRETRAINED_MODEL, map_location=model.device)
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
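
    # build_index(model, 256) is assumed to encode every ontology concept name
    # with the current model (in batches of 256) and store the vectors in a
    # nearest-neighbor index (the namevecs_index freed at the end of each
    # epoch) that evaluate() searches for candidate retrieval.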

    # Evaluate the initial model on the dev set and the test set
    print('Evaluate the initial model on the dev set and the test set')
    if dataset_name in [BC5CDR_C, BC5CDR_D, NCBI_D]:
        train_ontology.build_index(model, 256)
        dev_ontology.build_index(model, 256)
        test_ontology.build_index(model, 256)
    else:
        ontology.build_index(model, 256)
    with torch.no_grad():
        if configs['hard_negatives_training']:
            train_results = evaluate(model, train, train_ontology, configs)
            print('Train results: {}'.format(train_results))
        dev_results = evaluate(model, dev, dev_ontology, configs)
        print('Dev results: {}'.format(dev_results))
        test_results = evaluate(model, test, test_ontology, configs)
        print('Test results: {}'.format(test_results))
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare the optimizer and the scheduler
    optimizer = model.get_optimizer(len(train))
    num_epoch_steps = math.ceil(len(train) / configs['batch_size'])
    print('Prepared the optimizer and the scheduler', flush=True)

    # Start training
    accumulated_loss = RunningAverage()
    iters, batch_loss, best_dev_score, final_test_results = 0, 0, 0, None
    gradient_accumulation_steps = configs['gradient_accumulation_steps']
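
    # Gradient accumulation: losses from gradient_accumulation_steps
    # consecutive forward/backward passes are summed before a single
    # optimizer.step(), so the effective batch size is
    # batch_size * gradient_accumulation_steps. Dividing each iter_loss by
    # gradient_accumulation_steps below makes the accumulated gradient match
    # the gradient of the mean loss over that larger batch; e.g. batch_size=8
    # with 4 accumulation steps updates as if with one batch of 32.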
    for epoch_ix in range(configs['epochs']):
        print('Starting epoch {}'.format(epoch_ix + 1), flush=True)
        for i in range(num_epoch_steps):
            iters += 1
            instances = train.next_items(configs['batch_size'])

            # Compute iter_loss
            iter_loss = model(instances, train_ontology, is_training=True)[0]
            iter_loss = iter_loss / gradient_accumulation_steps
            iter_loss.backward()
            batch_loss += iter_loss.item()

            # Update params
            if iters % gradient_accumulation_steps == 0:
                accumulated_loss.update(batch_loss)
                torch.nn.utils.clip_grad_norm_(model.parameters(), configs['max_grad_norm'])
                optimizer.step()
                optimizer.zero_grad()
                batch_loss = 0

            # Report loss
            if iters % configs['report_frequency'] == 0:
                print('{} Average Loss = {}'.format(iters, round(accumulated_loss(), 3)), flush=True)
                accumulated_loss = RunningAverage()

        # Only evaluate every epoch_evaluation_frequency epochs
        if (epoch_ix + 1) % configs['epoch_evaluation_frequency'] > 0: continue

        # Rebuild the index of the ontology with the updated encoder
        print('Starting building the index of the ontology')
        if dataset_name in [BC5CDR_C, BC5CDR_D, NCBI_D]:
            train_ontology.build_index(model, 256)
            dev_ontology.build_index(model, 256)
            test_ontology.build_index(model, 256)
        else:
            ontology.build_index(model, 256)

        # Evaluation after each epoch
        with torch.no_grad():
            if configs['hard_negatives_training']:
                train_results = evaluate(model, train, train_ontology, configs)
                print('Train results: {}'.format(train_results))
            start_dev_eval_time = time.time()
            print('Evaluation on the dev set')
            if dataset_name in [BC5CDR_C, BC5CDR_D, NCBI_D] and USE_TRAINDEV:
                # With USE_TRAINDEV (training on train+dev), there is no held-out
                # dev split, so the test set is used for model selection here.
                dev_results = evaluate(model, test, test_ontology, configs)
            else:
                dev_results = evaluate(model, dev, dev_ontology, configs)
            print(dev_results)
            dev_score = dev_results['top1_accuracy']
            print('Evaluation on the dev set took %s seconds' % (time.time() - start_dev_eval_time))
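
            # Online KD trains early-exit "child" branches alongside the full
            # model; the block below also scores the model when it exits after
            # the first 3 layers and averages that with the full-depth score,
            # so checkpoint selection favors models that are accurate at both
            # depths. (benchmark_model is assumed to return an accuracy on the
            # same scale as dev_score.)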
            # if online_kd is enabled
            if configs['online_kd']:
                print('Evaluation using only the first 3 layers')
                model.enable_child_branch_exit(3)
                dev_score_3_layers = benchmark_model(model, 128, [configs['dataset']], 'dev')
                print(dev_score_3_layers)
                model.disable_child_branch_exit()
                dev_score = (dev_score + dev_score_3_layers) / 2.0

            # Save the model if it achieves a better dev score
            if dev_score > best_dev_score:
                best_dev_score = dev_score
                print('Evaluation on the test set')
                test_results = evaluate(model, test, test_ontology, configs)
                final_test_results = test_results
                print(test_results)
                # Save the model
                save_path = join(configs['save_dir'], 'model.pt')
                torch.save({'model_state_dict': model.state_dict()}, save_path)
                print('Saved the model', flush=True)

        # Free the memory of the index (it is rebuilt at the next evaluation)
        if dataset_name not in [BC5CDR_C, BC5CDR_D, NCBI_D]:
            del ontology.namevecs_index
            ontology.namevecs_index = None
        gc.collect()
        torch.cuda.empty_cache()

    print(final_test_results)
    return final_test_results


if __name__ == "__main__":
    # Parse the arguments
    parser = ArgumentParser()
    parser.add_argument('-c', '--cg_config', default='lightweight_cnn_text')
    parser.add_argument('-d', '--dataset', default=BC5CDR_C, choices=DATASETS)
    args = parser.parse_args()

    # Prepare the configs
    configs = prepare_configs(args.cg_config, args.dataset)

    # Train
    train(configs)
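
# Example invocation (flags as defined above; the dataset must be one of the
# names in DATASETS and defaults to BC5CDR_C):
#   python cg_trainer.py -c lightweight_cnn_text -d <dataset_name>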