-
Notifications
You must be signed in to change notification settings - Fork 28
/
eval.py
102 lines (87 loc) · 3.14 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import logging
import os
import json
from tqdm import tqdm
from utils import (
evaluate
)
from src.biosyn import (
DictionaryDataset,
QueryDataset,
BioSyn
)
# Module-level root logger; its handler and level are set up by init_logging().
LOGGER = logging.getLogger()
def parse_args():
    """Build the CLI parser for BioSyn evaluation and parse sys.argv.

    Returns:
        argparse.Namespace with model/dictionary/data paths plus run,
        output, and tokenizer settings.
    """
    parser = argparse.ArgumentParser(description='BioSyn evaluation')

    # Required
    parser.add_argument('--model_name_or_path', required=True,
                        help='Directory for model')
    parser.add_argument('--dictionary_path', type=str, required=True,
                        help='dictionary path')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='data set to evaluate')

    # Run settings
    parser.add_argument('--use_cuda', action="store_true")
    parser.add_argument('--topk', type=int, default=20)
    parser.add_argument('--score_mode', type=str, default='hybrid',
                        choices=['hybrid', 'dense', 'sparse'])
    parser.add_argument('--output_dir', type=str, default='./output/',
                        help='Directory for output')
    parser.add_argument('--filter_composite', action="store_true",
                        help="filter out composite mention queries")
    parser.add_argument('--filter_duplicate', action="store_true",
                        help="filter out duplicate queries")
    parser.add_argument('--save_predictions', action="store_true",
                        help="whether to save predictions")

    # Tokenizer settings
    parser.add_argument('--max_length', default=25, type=int)

    return parser.parse_args()
def init_logging():
    """Attach a console handler to the module logger and enable INFO output."""
    formatter = logging.Formatter(
        '%(asctime)s: [ %(message)s ]',
        '%m/%d/%Y %I:%M:%S %p',
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    LOGGER.addHandler(handler)
    LOGGER.setLevel(logging.INFO)
def load_dictionary(dictionary_path):
    """Load the reference dictionary and return its underlying data.

    Args:
        dictionary_path: path to the dictionary file.

    Returns:
        The `.data` payload of the loaded DictionaryDataset.
    """
    return DictionaryDataset(dictionary_path=dictionary_path).data
def load_queries(data_dir, filter_composite, filter_duplicate):
    """Load mention queries from a data directory.

    Args:
        data_dir: directory holding the query data set.
        filter_composite: drop composite mention queries when True.
        filter_duplicate: drop duplicate queries when True.

    Returns:
        The `.data` payload of the loaded QueryDataset.
    """
    queries = QueryDataset(
        data_dir=data_dir,
        filter_composite=filter_composite,
        filter_duplicate=filter_duplicate,
    )
    return queries.data
def main(args):
    """Run BioSyn evaluation: load data and model, evaluate, report metrics.

    Args:
        args: parsed argparse.Namespace from parse_args().

    Side effects:
        Logs acc@1/acc@5; optionally writes predictions_eval.json into
        args.output_dir when --save_predictions is set.
    """
    init_logging()
    # Log the run configuration through the configured logger (was a bare print).
    LOGGER.info("%s", args)

    # Load the reference dictionary and the mention queries to evaluate.
    eval_dictionary = load_dictionary(dictionary_path=args.dictionary_path)
    eval_queries = load_queries(
        data_dir=args.data_dir,
        filter_composite=args.filter_composite,
        filter_duplicate=args.filter_duplicate
    )

    biosyn = BioSyn(
        max_length=args.max_length,
        use_cuda=args.use_cuda
    )
    biosyn.load_model(
        model_name_or_path=args.model_name_or_path,
    )

    result_evalset = evaluate(
        biosyn=biosyn,
        eval_dictionary=eval_dictionary,
        eval_queries=eval_queries,
        topk=args.topk,
        score_mode=args.score_mode
    )

    LOGGER.info("acc@1={}".format(result_evalset['acc1']))
    LOGGER.info("acc@5={}".format(result_evalset['acc5']))

    if args.save_predictions:
        # Bug fix: the output directory may not exist on a fresh checkout;
        # open() would raise FileNotFoundError without this.
        os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, "predictions_eval.json")
        with open(output_file, 'w') as f:
            json.dump(result_evalset, f, indent=2)
# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == '__main__':
    args = parse_args()
    main(args)