forked from varshakishore/dsi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_acc.py
282 lines (203 loc) · 8.34 KB
/
evaluate_acc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import datasets
from transformers import BertTokenizer
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm
from BertModel import QueryClassifier
import random
import logging
logger = logging.getLogger(__name__)
import argparse
import os
import joblib
from utils import *
from dsi_model_continual import DSIqgTrainDataset, GenPassageDataset, IndexingCollator
from direct_optimize import initialize_model, initialize_nq320k
from dsi_model_v1 import validate
def get_arguments(argv=None):
    """Parse command-line options for the incremental-accuracy evaluation.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_size",
        default=2000,
        type=int,
        required=False,
        help="batch_size",
    )
    parser.add_argument(
        "--model_name",
        default='bert-base-uncased',
        choices=['T5-base', 'bert-base-uncased'],
        help="Model name",
    )
    parser.add_argument(
        "--seed",
        default=42,
        type=int,
        required=False,
        help="random seed",
    )
    parser.add_argument(
        "--initialize_embeddings",
        default=None,
        type=str,
        help="folder for the embedding matrix",
    )
    parser.add_argument(
        "--initialize_model",
        default=None,
        type=str,
        help="path to saved model",
    )
    parser.add_argument(
        "--eval_step",
        default=1000,
        type=int,
        help="step for evaluation"
    )
    parser.add_argument(
        "--dataset",
        default='NQ320k',
        choices=['NQ320k', 'MSMARCO'],
        help='which dataset to use')
    parser.add_argument(
        "--write_path",
        default=None,
        type=str,
        help='folder to write the results'
    )
    # loaddataset() reads args.data_path and args.doc_type, so both must be
    # declared here; without them the script fails with AttributeError.
    parser.add_argument(
        "--data_path",
        required=True,
        type=str,
        help="root folder containing the <doc_type>/ query files and doc_class.pkl",
    )
    parser.add_argument(
        "--doc_type",
        required=True,
        choices=['new_docs', 'old_docs', 'tune_docs'],
        help="document subset folder under data_path to evaluate on",
    )
    return parser.parse_args(argv)
def _load_json_split(args, filename):
    """Load ``<args.data_path>/<args.doc_type>/<filename>`` via ``datasets.load_dataset('json', ...)`` and return its 'train' split."""
    return datasets.load_dataset(
        'json',
        data_files=os.path.join(args.data_path, args.doc_type, filename),
        ignore_verifications=False,
        cache_dir='cache'
    )['train']


def loaddataset(args, tokenizer):
    """Load the evaluation query sets and wrap them in project dataset classes.

    Reads passages_seen.json, passages_unseen.json, valqueries.json and
    testqueries.json from ``<args.data_path>/<args.doc_type>/``, plus the
    doc_class.pkl mapping chosen by ``args.doc_type``.

    Args:
        args: parsed options; uses ``data_path`` and ``doc_type``.
        tokenizer: tokenizer handed to every dataset wrapper.

    Returns:
        ``[val_dataset, test_dataset, seenq_dataset, unseenq_dataset]``.

    Raises:
        ValueError: if ``args.doc_type`` is not one of
            'new_docs', 'old_docs' or 'tune_docs'.
    """
    # Resolve the doc_class location first. An unknown doc_type used to leave
    # `doc_class` unbound (UnboundLocalError) only after all files were loaded.
    if args.doc_type in ("new_docs", "old_docs"):
        # old_docs reuses the new_docs mapping, as in the original code;
        # presumably it covers both old and new docs (confirm).
        doc_class_dir = 'new_docs'
    elif args.doc_type == "tune_docs":
        doc_class_dir = 'tune_docs'
    else:
        raise ValueError(
            f'doc_type={args.doc_type} must be new_docs, old_docs or tune_docs')

    seen_gen_qs = _load_json_split(args, 'passages_seen.json')
    unseen_gen_qs = _load_json_split(args, 'passages_unseen.json')
    print(f'passages loaded with length {len(seen_gen_qs)}')
    print(f'unseen passages loaded with length {len(unseen_gen_qs)}')
    val_data = _load_json_split(args, 'valqueries.json')
    print(f'validation set loaded with length {len(val_data)}')
    test_data = _load_json_split(args, 'testqueries.json')
    print(f'test set loaded with length {len(test_data)}')
    print('datasets loaded')

    doc_class = joblib.load(os.path.join(args.data_path, doc_class_dir, 'doc_class.pkl'))

    val_dataset = DSIqgTrainDataset(tokenizer=tokenizer, datadict=val_data, doc_class=doc_class)
    # Bug fix: the test dataset was previously built from val_data.
    test_dataset = DSIqgTrainDataset(tokenizer=tokenizer, datadict=test_data, doc_class=doc_class)
    seenq_dataset = GenPassageDataset(tokenizer=tokenizer, datadict=seen_gen_qs, doc_class=doc_class)
    unseenq_dataset = GenPassageDataset(tokenizer=tokenizer, datadict=unseen_gen_qs, doc_class=doc_class)
    return [val_dataset, test_dataset, seenq_dataset, unseenq_dataset]
def filter_Dataloader(args, old_class_num, class_num, datasets, tokenizer):
    """Split each dataset into old-doc and new-doc parts and wrap them in DataLoaders.

    Each dataset is filtered on ``example[1]`` (assumed to be the document
    class id; confirm against the dataset classes). Examples with an id of
    ``class_num`` or more are dropped. The rest are split into "old"
    (id < ``old_class_num``) and "new" (id >= ``old_class_num``).

    Args:
        args: parsed options; uses ``batch_size``.
        old_class_num: number of classes present before new docs were added.
        class_num: total number of classes evaluated at this step.
        datasets: list of datasets (as returned by ``loaddataset``).
        tokenizer: tokenizer passed to ``IndexingCollator``.

    Returns:
        ``(old_dataloaders, new_dataloaders)``: two lists, each in the same
        order as ``datasets``. Loaders are unshuffled and keep the last
        partial batch.
    """
    new_class_num = class_num - old_class_num
    print(f'Old class number: {old_class_num}')
    print(f'New class number: {new_class_num}')
    print(f'Total class number: {class_num}')
    # use the data only till the class_num docs
    print('Filtering datasets')
    filtered_datasets = [dataset.filter(lambda example: example[1] < class_num) for dataset in datasets]
    new_datasets = [dataset.filter(lambda example: example[1] >= old_class_num) for dataset in filtered_datasets]
    old_datasets = [dataset.filter(lambda example: example[1] < old_class_num) for dataset in filtered_datasets]
    print("Filtered set:")
    print('\n')
    print(f"Old-Val-{len(old_datasets[0])}, New-Val-{len(new_datasets[0])}")
    print(f"Old-Test-{len(old_datasets[1])}, New-Test-{len(new_datasets[1])}")
    # Separators added: the counts used to be glued to the labels.
    print(f"seen queries from old docs: {len(old_datasets[2])}, seen queries from new docs: {len(new_datasets[2])}")
    print(f"unseen queries from old docs: {len(old_datasets[3])}, unseen queries from new docs: {len(new_datasets[3])}")

    def _make_loader(data):
        # One collator per loader, as before; sharing one instance would
        # assume IndexingCollator is stateless.
        return DataLoader(
            data,
            batch_size=args.batch_size,
            collate_fn=IndexingCollator(tokenizer, padding='longest'),
            shuffle=False,
            drop_last=False,
        )

    old_dataloaders = [_make_loader(old_data) for old_data in old_datasets]
    new_dataloaders = [_make_loader(new_data) for new_data in new_datasets]
    return old_dataloaders, new_dataloaders
def Getmodel(args, embedding_matrix, class_num):
    """Build a ``QueryClassifier`` over ``class_num`` classes, wrapped for multi-GPU use.

    The saved weights at ``args.initialize_model`` are loaded non-strictly.
    The classifier projection is then replaced by the first ``class_num``
    rows of ``embedding_matrix`` (assumed to be a torch tensor, since
    ``.detach()`` is called on it). The model is wrapped in
    ``torch.nn.DataParallel`` and moved to CUDA.

    Returns:
        The ``DataParallel``-wrapped model.
    """
    target_device = torch.device("cuda")
    print(f'Loading Model and Tokenizer for {args.model_name}')

    classifier = QueryClassifier(class_num)
    load_saved_weights(classifier, args.initialize_model, strict_set=False)

    # Use the precomputed rows for the classes under evaluation as the
    # projection weights; stage them on CPU before moving the whole model.
    projection_rows = embedding_matrix[:class_num].detach().to('cpu')
    classifier.classifier.weight.data = projection_rows

    # Module.to() moves parameters in place and returns the module itself.
    wrapped = torch.nn.DataParallel(classifier).to(target_device)
    print(f'Device: {target_device}')
    print('model loaded')
    return wrapped
def _evaluation_steps(old_class_num, total_class_num, eval_step):
    """Return the class counts at which the model is evaluated.

    Counts start at ``old_class_num`` and advance by ``eval_step``. The full
    ``total_class_num`` is always appended as the last step (a bare ``range``
    excludes it, so the final state with every new doc was never evaluated).
    The result is empty when there are no new classes.

    Raises:
        ValueError: if ``eval_step`` is not positive.
    """
    if eval_step <= 0:
        raise ValueError(f'eval_step={eval_step} must be positive')
    steps = list(range(old_class_num, total_class_num, eval_step))
    if total_class_num > old_class_num:
        steps.append(total_class_num)
    return steps


def main():
    """Evaluate hits@1/5/10 and mrr@10 as new document classes are added.

    At each step the model is rebuilt with the first ``class_num`` rows of the
    saved classifier embedding matrix. Every query split is then evaluated
    separately for old and new documents. The per-step results are written to
    ``<write_path>/eval_results.txt``.
    """
    args = get_arguments()
    if args.model_name != 'bert-base-uncased':
        # Only the BERT tokenizer path exists; previously `tokenizer` was left
        # undefined and the script crashed later with NameError.
        raise ValueError(f'model_name={args.model_name} is not supported; use bert-base-uncased')
    if args.write_path is None:
        raise ValueError('--write_path must be set so results can be written')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir='cache')
    if args.dataset == "NQ320k":
        old_class_num = 98743
    elif args.dataset == "MSMARCO":
        old_class_num = 289424
    else:
        raise ValueError(f'dataset={args.dataset} must be NQ320k or MSMARCO')

    # Create the output folder before the long evaluation, so a bad path fails
    # fast instead of after every step has run.
    os.makedirs(args.write_path, exist_ok=True)

    eval_datasets = loaddataset(args, tokenizer)
    embedding_matrix = joblib.load(os.path.join(args.initialize_embeddings, 'classifier_layer.pkl'))
    print('query embeddings loaded')

    # Order matches old_dataloaders + new_dataloaders (val, test, seen, unseen).
    splits = ['val queries for old docs', 'test queries for old docs', 'seen generated queries for old docs', 'unseen generated queries for old docs',
              'val queries for new docs', 'test queries for new docs', 'seen generated queries for new docs', 'unseen generated queries for new docs']
    h1, h5, h10, m10 = [], [], [], []
    progress = []
    for class_num in _evaluation_steps(old_class_num, embedding_matrix.shape[0], args.eval_step):
        added_num = class_num - old_class_num
        progress.append(added_num)
        model = Getmodel(args, embedding_matrix, class_num)
        old_dataloaders, new_dataloaders = filter_Dataloader(args, old_class_num, class_num, eval_datasets, tokenizer)
        # list.extend() returns None; concatenate instead.
        dataloaders = old_dataloaders + new_dataloaders
        hits1, hits5, hits10, mrr10 = [], [], [], []
        for split_name, loader in zip(splits, dataloaders):
            print('*' * 100)
            print(split_name)
            acc1, acc5, acc10, mrr_10 = validate(args, model, loader)
            hits1.append(acc1.item())
            hits5.append(acc5.item())
            hits10.append(acc10.item())
            mrr10.append(mrr_10.item())
        print(f'At step {added_num}:')
        print(splits)
        print(f'hits@1: {hits1}')
        print(f'hits@5: {hits5}')
        print(f'hits@10: {hits10}')
        print(f'mrr@10: {mrr10}')
        h1.append(hits1)
        h5.append(hits5)
        h10.append(hits10)
        m10.append(mrr10)
        # Drop this step's model before the next iteration builds a new one,
        # so two models are not alive at the same time.
        del model

    with open(os.path.join(args.write_path, 'eval_results.txt'), 'w') as results:
        results.write(f'Evaluation steps: {progress}' + '\n')
        results.write(f'splits: {splits}' + '\n')
        results.write(f'hits@1: {h1}' + '\n')
        results.write(f'hits@5: {h5}' + '\n')
        results.write(f'hits@10: {h10}' + '\n')
        results.write(f'mrr@10: {m10}' + '\n')
    print('results written.')


if __name__ == "__main__":
    main()