-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathget_mediqa_gt_processed.py
113 lines (104 loc) · 3.8 KB
/
get_mediqa_gt_processed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import xml.etree.ElementTree as ET
from data_utils.mediqa_utils import submit
import os
import pdb
import json
# Dump the labels stored in the processed MT-DNN data files as "ground truth"
# files, one per dataset, by handing them to `submit` in the same result
# layout that is used for model outputs.

# Which split to export: True -> '*_train.json', False -> '*_dev.json'.
is_train=False

# Directory holding the processed JSON-lines files for every dataset.
processed_path = '../data/mediqa_processed/mt_dnn_mediqa_scibert_v2/'


def _read_uids_and_labels(path):
    """Return the ``uid`` and ``label`` values of a JSON-lines file.

    Args:
        path: Path of a UTF-8 text file with one JSON object per line.
            Every object must provide the keys ``uid`` and ``label``.

    Returns:
        A ``(uids, labels)`` tuple of two lists, both in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a line is not valid JSON (``json.JSONDecodeError``).
        KeyError: If a record lacks ``uid`` or ``label``.
    """
    uids = []
    labels = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            sample = json.loads(line)
            uids.append(sample['uid'])
            labels.append(sample['label'])
    return uids, labels


def _export_ground_truth(dir_path, dataset_name, task, label_key,
                         output_suffix='', **submit_kwargs):
    """Export the labels of one processed dataset through ``submit``.

    The input file is ``<processed_path>/<dataset_name>_<split>.json`` and the
    output file is ``<dir_path>/gt_<split><output_suffix>.csv``, where
    ``<split>`` is ``train`` or ``dev`` depending on the module-level
    ``is_train`` flag.

    Args:
        dir_path: Directory the ground-truth file is written to.
        dataset_name: Stem of the processed file name (e.g. ``'rqe_shuff'``).
        task: Task name forwarded as third positional argument of ``submit``.
        label_key: Key under which the labels are stored in the result dict
            (``'predictions'`` or ``'scores'`` in this script).
        output_suffix: Text inserted between ``gt_<split>`` and ``.csv``.
        **submit_kwargs: Extra keyword arguments forwarded verbatim to
            ``submit`` (this script only uses ``threshold``).
    """
    split = 'train' if is_train else 'dev'
    dev_path = os.path.join(
        processed_path, '{}_{}.json'.format(dataset_name, split))
    output_path = os.path.join(
        dir_path, 'gt_{}{}.csv'.format(split, output_suffix))

    uids, labels = _read_uids_and_labels(dev_path)
    # 'uids' first, then the label column, matching the original dict layout.
    result = {'uids': uids, label_key: labels}
    submit(output_path, result, task, **submit_kwargs)


# Task 1: MedNLI.
_export_ground_truth('../data/mediqa/task1_mednli/', 'mednli', 'mednli',
                     'predictions')

# Task 2: RQE.
_export_ground_truth('../data/mediqa/task2_rqe/', 'rqe', 'rqe',
                     'predictions')

# Task 2: the 'rqe_shuff' variant; the dataset name is part of the output file
# name so it does not overwrite the plain RQE ground truth in the same folder.
dataset_name='rqe_shuff'
_export_ground_truth('../data/mediqa/task2_rqe/', dataset_name, 'rqe',
                     'predictions',
                     output_suffix='_{}'.format(dataset_name))

# MedQuAD.
_export_ground_truth('../data/mediqa/MedQuAD/', 'medquad', 'medquad',
                     'predictions')

# Task 3: QA. Labels are exported under 'scores' and `submit` receives a
# threshold.
# NOTE(review): 2.000001 presumably makes a score of exactly 2 fall below the
# cut-off -- confirm against `submit` in data_utils.mediqa_utils.
_export_ground_truth('../data/mediqa/task3_qa/', 'mediqa', 'mediqa',
                     'scores', threshold=2.000001)

# Task 3: the five numbered 'mediqa_<i>' files (i = 0..4), one ground-truth
# file per index.
for sidx in range(0,5):
    _export_ground_truth('../data/mediqa/task3_qa/',
                         'mediqa_{}'.format(sidx), 'mediqa', 'scores',
                         output_suffix='_{}'.format(sidx),
                         threshold=2.000001)