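"""BERT question answering with two interchangeable inference backends:
native PyTorch and ONNX Runtime. Run as a script, this module answers the
same question with both backends and prints the two results."""
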
from typing import Optional, Dict, List
import onnxruntime
import numpy as np
import torch
from transformers import BertTokenizer, BertForQuestionAnswering


class QaInferenceSession:
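    """Interface shared by the backends: run() maps a question and a context
    passage to the extracted answer string."""
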
def __init__(self, model_filepath: str, tokenizer_filepath: str) -> None:
self.model_filepath = model_filepath
self.tokenizer_filepath = tokenizer_filepath

    def run(self, question: str, text: str) -> str:
        raise NotImplementedError


class QaTorchInferenceSession(QaInferenceSession):
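    """QA inference using the native PyTorch model."""
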
def __init__(self,
model_filepath: str,
tokenizer_filepath: str,
device: Optional[str] = "cuda:0") -> None:
        super().__init__(model_filepath=model_filepath,
                         tokenizer_filepath=tokenizer_filepath)
self.model = BertForQuestionAnswering.from_pretrained(
self.model_filepath).eval()
self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_filepath)
self.device = device
self.model.to(self.device)

    def prepare_qa_inputs(self, question: str,
                          text: str) -> Dict[str, torch.Tensor]:
inputs = self.tokenizer(question, text, return_tensors="pt")
        if self.device is not None:
            # Move every input tensor onto the target device.
            inputs = {
                name: tensor.to(self.device)
                for name, tensor in inputs.items()
            }
return inputs

    def run(self, question: str, text: str) -> str:
inputs = self.prepare_qa_inputs(question=question, text=text)
all_tokens = self.tokenizer.convert_ids_to_tokens(
inputs["input_ids"].cpu().numpy()[0])
        # Inference only, so skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits
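        # The highest-scoring start and end logits delimit the answer span;
        # +1 makes the end index exclusive for Python slicing.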
answer_start_idx = torch.argmax(start_scores, 1)[0]
answer_end_idx = torch.argmax(end_scores, 1)[0] + 1
answer = " ".join(all_tokens[answer_start_idx:answer_end_idx])
return answer


class QaOnnxInferenceSession(QaInferenceSession):
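    """QA inference using ONNX Runtime."""
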
def __init__(
self,
model_filepath: str,
tokenizer_filepath: str,
            num_intra_op_num_threads: int = 1,
            execution_providers: Optional[List[str]] = None) -> None:
        super().__init__(model_filepath=model_filepath,
                         tokenizer_filepath=tokenizer_filepath)
        if execution_providers is None:
            # Default to CUDA first, falling back to the CPU provider.
            # (A list default argument would be mutable and shared across calls.)
            execution_providers = [
                "CUDAExecutionProvider", "CPUExecutionProvider"
            ]
sess_options = onnxruntime.SessionOptions()
sess_options.intra_op_num_threads = num_intra_op_num_threads
        # Log severity level: 0 = Verbose, 1 = Info, 2 = Warning, 3 = Error,
        # 4 = Fatal. The default is 2 (Warning).
        sess_options.log_severity_level = 2
# Set graph optimization level
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        # To serialize the model after graph optimization, set:
        # sess_options.optimized_model_filepath = "<model_output_path>/optimized_model.onnx"
self.num_intra_op_num_threads = num_intra_op_num_threads
self.execution_providers = execution_providers
self.session = onnxruntime.InferenceSession(self.model_filepath,
sess_options)
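        # Restrict execution to the requested providers, in priority order.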
self.session.set_providers(execution_providers)
self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_filepath)

    def run(self, question: str, text: str) -> str:
inputs = self.tokenizer(question, text, return_tensors="np")
all_tokens = self.tokenizer.convert_ids_to_tokens(
inputs["input_ids"][0])
ort_inputs = {
"input_ids": inputs["input_ids"],
"input_mask": inputs["attention_mask"],
"segment_ids": inputs["token_type_ids"],
}
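        # Fetch the start/end logits by their output names in the ONNX graph.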
outputs = self.session.run(["start", "end"], ort_inputs)
start_scores = outputs[0]
end_scores = outputs[1]
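        # As in the PyTorch backend: argmax over the sequence dimension picks
        # the span, and +1 makes the end index exclusive for slicing.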
answer_start_idx = np.argmax(start_scores, 1)[0]
answer_end_idx = np.argmax(end_scores, 1)[0] + 1
answer = " ".join(all_tokens[answer_start_idx:answer_end_idx])
return answer


if __name__ == "__main__":
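    # Paths to the exported ONNX model, saved PyTorch model, and tokenizer;
    # these are assumed to have been created by a prior export step.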
onnx_model_filepath = "./saved_models/bert-base-cased-squad2_model.onnx"
torch_model_filepath = "./saved_models/bert-base-cased-squad2_model.pt"
tokenizer_filepath = "./saved_models/bert-base-cased-squad2_tokenizer.pt"
onnx_inference_session = QaOnnxInferenceSession(
model_filepath=onnx_model_filepath,
tokenizer_filepath=tokenizer_filepath)
torch_inference_session = QaTorchInferenceSession(
model_filepath=torch_model_filepath,
tokenizer_filepath=tokenizer_filepath)
question = "What publication printed that the wealthiest 1% have more money than those in the bottom 90%?"
text = "According to PolitiFact the top 400 richest Americans \"have more wealth than half of all Americans combined.\" According to the New York Times on July 22, 2014, the \"richest 1 percent in the United States now own more wealth than the bottom 90 percent\". Inherited wealth may help explain why many Americans who have become rich may have had a \"substantial head start\". In September 2012, according to the Institute for Policy Studies, \"over 60 percent\" of the Forbes richest 400 Americans \"grew up in substantial privilege\"."
onnx_answer = onnx_inference_session.run(question=question, text=text)
torch_answer = torch_inference_session.run(question=question, text=text)
    # The two backends should agree on the extracted answer span.
    print(f"ONNX answer: {onnx_answer}")
    print(f"PyTorch answer: {torch_answer}")