-
Notifications
You must be signed in to change notification settings - Fork 128
/
rocketqa.py
167 lines (141 loc) · 6.71 KB
/
rocketqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import sys
import json
import paddle
import urllib
import numpy as np
import tarfile
import warnings
import hashlib
from tqdm import tqdm
from rocketqa.predict.dual_encoder import DualEncoder
from rocketqa.predict.cross_encoder import CrossEncoder
# Import-time side effect: switch Paddle into static-graph mode.
# NOTE(review): presumably required by the DualEncoder / CrossEncoder
# predictors imported above -- confirm before moving or removing this call.
paddle.enable_static()
# Import-time side effect: install an "ignore" filter for every warning
# category, which applies to the whole process, not only to this module.
warnings.simplefilter('ignore')
# Download location of every built-in model archive, keyed by model name.
# A `_de` suffix marks a dual-encoder, a `_ce` suffix a cross-encoder.
__MODELS = {
    # RocketQA v1 dual-encoder trained on MSMARCO
    "v1_marco_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/v1_marco_de.tar.gz",
    # RocketQA v1 cross-encoder trained on MSMARCO
    "v1_marco_ce": "http://rocketqa.bj.bcebos.com/RocketQAModels/v1_marco_ce.tar.gz",
    # RocketQA v1 dual-encoder trained on Natural Question
    "v1_nq_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/v1_nq_de.tar.gz",
    # RocketQA v1 cross-encoder trained on Natural Question
    "v1_nq_ce": "http://rocketqa.bj.bcebos.com/RocketQAModels/v1_nq_ce.tar.gz",
    # PAIR dual-encoder trained on MSMARCO
    "pair_marco_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/pair_marco_de.tar.gz",
    # PAIR dual-encoder trained on Natural Question
    "pair_nq_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/pair_nq_de.tar.gz",
    # RocketQA v2 dual-encoder trained on MSMARCO
    "v2_marco_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/v2_marco_de.tar.gz",
    # RocketQA v2 cross-encoder trained on MSMARCO
    "v2_marco_ce": "http://rocketqa.bj.bcebos.com/RocketQAModels/v2_marco_ce.tar.gz",
    # RocketQA v2 dual-encoder trained on Natural Question
    "v2_nq_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/v2_nq_de.tar.gz",
    # RocketQA zh dual-encoder trained on Dureader
    "zh_dureader_de": "http://rocketqa.bj.bcebos.com/RocketQAModels/zh_dureader_de.tar.gz",
    # RocketQA zh cross-encoder trained on Dureader
    "zh_dureader_ce": "http://rocketqa.bj.bcebos.com/RocketQAModels/zh_dureader_ce.tar.gz",
}
# Expected MD5 hex digest of each downloaded `.tar.gz` archive, keyed by the
# same model names as `__MODELS`; compared against the file on disk before
# the archive is unpacked.
__MODELS_MD5 = {
    # RocketQA v1
    "v1_marco_de": "d8210e4080935bd7fdad7a394cd60b66",
    "v1_marco_ce": "caec5aedc46f22edd7107ecd793fc7fb",
    "v1_nq_de": "cfeb70f82087b8a47bb0d6d6cfcd61c5",
    "v1_nq_ce": "15aac78d70cc25994016b8a30d80f12c",
    # PAIR
    "pair_marco_de": "b4080ffa2999525e5ba2aa1f4e03a9e8",
    "pair_nq_de": "d770bc379ec6def7e0588ec02c80ace2",
    # RocketQA v2
    "v2_marco_de": "4ce64ff35d1d831f0ca989e49abde227",
    "v2_marco_ce": "915ea7ff214a4a92a3a1e1d56c3fb469",
    "v2_nq_de": "8f177aa75cadaad6656dcd981edc983b",
    # RocketQA zh
    "zh_dureader_de": "673ff667bdb3b315a7e2e1b5624babc4",
    "zh_dureader_ce": "5e8e6a026e1cb7600fc7e9250f79beb1",
}
def available_models():
    """
    List the built-in RocketQA model names accepted by `load_model`.

    Returns:
        The key view of the module-level model registry.
    """
    model_names = __MODELS.keys()
    return model_names
def load_model(model, use_cuda=False, device_id=0, batch_size=1):
    """
    Load a RocketQA model or a user-specified checkpoint.

    Args:
        model: A model name returned by `rocketqa.available_models()`, or the
            path of a user-specified checkpoint config (a JSON file that
            contains a `model_type` entry).
        use_cuda: Whether to use GPU
        device_id: The device to put the model
        batch_size: Batch size during inference

    Returns:
        A `DualEncoder` or `CrossEncoder` instance, depending on the model type.

    Raises:
        Exception: If a built-in model cannot be downloaded, if the config
            file is missing or cannot be parsed, or if its `model_type` is
            absent or not one of `dual_encoder` / `cross_encoder`.
    """
    model_type = ''
    encoder_conf = {}

    if model in __MODELS:
        # Built-in model: it lives in (or is downloaded to) ~/.rocketqa/<name>/
        model_name = model
        print(f"RocketQA model [{model_name}]", file=sys.stderr)
        model_path = os.path.expanduser('~/.rocketqa/') + model_name + '/'
        if not os.path.exists(model_path):
            if not __download(model_name):
                # Adjacent literals instead of a backslash continuation, which
                # would embed the source indentation in the message.
                raise Exception(
                    f"RocketQA model [{model_name}] download failed, "
                    f"please check model dir [{model_path}]")
        encoder_conf['conf_path'] = model_path + 'config.json'
        encoder_conf['model_path'] = model_path
        # The encoder kind is encoded in the built-in model name.
        if "_de" in model_name:
            model_type = 'dual_encoder'
        elif "_ce" in model_name:
            model_type = 'cross_encoder'
    else:
        # Anything else is treated as the path of a checkpoint config file.
        print("User-specified model", file=sys.stderr)
        conf_path = model
        model_name = model
        if not os.path.isfile(conf_path):
            raise Exception(f"Config file [{conf_path}] not found")
        try:
            with open(conf_path, 'r', encoding='utf8') as json_file:
                config_dict = json.load(json_file)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON syntax errors and decoding errors.
            raise Exception(str(e) + f"\nConfig file [{conf_path}] load failed") from e

        encoder_conf['conf_path'] = conf_path
        # The model directory is the directory part of the config path.
        # rfind() returns -1 when there is no '/', so index 0 is a valid hit.
        split_p = conf_path.rfind('/')
        if split_p >= 0:
            encoder_conf['model_path'] = conf_path[:split_p + 1]

        if not isinstance(config_dict, dict) or "model_type" not in config_dict:
            raise Exception("[model_type] not found in config file")
        model_type = config_dict["model_type"]
        if model_type not in ("dual_encoder", "cross_encoder"):
            raise Exception(
                f"model_type [{model_type}] is illegal, "
                "must be `dual_encoder` or `cross_encoder`")

    encoder_conf["use_cuda"] = use_cuda
    encoder_conf["device_id"] = device_id
    encoder_conf["batch_size"] = batch_size
    encoder_conf["model_name"] = model_name

    if model_type == "dual_encoder":
        encoder = DualEncoder(**encoder_conf)
    elif model_type == "cross_encoder":
        encoder = CrossEncoder(**encoder_conf)
    else:
        # Only reachable if a built-in name carries neither `_de` nor `_ce`.
        raise Exception(f"Cannot determine the model type of [{model_name}]")

    print("Load model done", file=sys.stderr)
    return encoder
def __download(model_name):
    """
    Download, verify and unpack a built-in model archive into ~/.rocketqa/.

    An archive that is already present is not downloaded again. The archive's
    MD5 digest is compared with the expected one before it is unpacked.

    Args:
        model_name: A key of the module-level model registry.

    Returns:
        True if the archive was unpacked, False if unpacking failed (the error
        is printed to stderr).

    Raises:
        Exception: If the archive's MD5 digest does not match the expected one.
        Network and file errors raised while downloading propagate unchanged.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    cache_dir = os.path.expanduser('~/.rocketqa/')
    os.makedirs(cache_dir, exist_ok=True)
    download_dst = os.path.join(cache_dir, model_name + '.tar.gz')
    download_url = __MODELS[model_name]

    if not os.path.exists(download_dst):
        print(f"Download RocketQA model [{model_name}]", file=sys.stderr)
        # Stream into a side file and rename at the end, so an interrupted
        # download never leaves a truncated archive at the final path.
        partial_dst = download_dst + '.part'
        try:
            with urllib.request.urlopen(download_url) as source, \
                    open(partial_dst, "wb") as output:
                # The header is optional; tqdm accepts total=None.
                content_length = source.info().get("Content-Length")
                if content_length and content_length.isdigit():
                    total = int(content_length)
                else:
                    total = None
                with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                          unit_divisor=1024) as loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break
                        output.write(buffer)
                        loop.update(len(buffer))
            os.replace(partial_dst, download_dst)
        except BaseException:
            # Also covers Ctrl-C: drop the incomplete file, then re-raise.
            if os.path.exists(partial_dst):
                os.remove(partial_dst)
            raise

    file_md5 = __get_file_md5(download_dst)
    if file_md5 != __MODELS_MD5[model_name]:
        raise Exception(
            f"Model file [{download_dst}] exists, but md5 does not match; "
            "delete it and retry")

    try:
        with tarfile.open(download_dst) as archive:
            if hasattr(tarfile, 'data_filter'):
                # Refuse members that would escape the destination directory.
                archive.extractall(cache_dir, filter='data')
            else:
                # Older Python without extraction filters.
                archive.extractall(cache_dir)
    except Exception as e:
        print(str(e), file=sys.stderr)
        return False
    return True
def __get_file_md5(fname):
    """
    Compute the MD5 hex digest of the file at `fname`.

    The file is read in 4096-byte chunks, so it is never held in memory whole.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: stream.read(4096), b''):
            digest.update(chunk)
    return digest.hexdigest()
# Running this file directly as a script does nothing.
if __name__ == "__main__":
    pass