
Commit 8479beb

Update inference.py
add inference postprocess and white name list config
1 parent 4e5af07 commit 8479beb

File tree

1 file changed (+173, -9 lines)


Diff for: tools/inference.py

@@ -7,17 +7,15 @@
 import sys
 import argparse
 import os
-
 import torch
 from transformers import BertTokenizer
-
 from tools.bases import args_parse
-
 sys.path.append('..')
-
 from bbcm.modeling.csc import BertForCsc, SoftMaskedBertModel
 from bbcm.utils import get_abs_path
-
+import json
+import codecs
+import re
 
 def parse_args():
     parser = argparse.ArgumentParser(description="bbcm")
@@ -85,11 +83,177 @@ def inference(args):
             texts.append(line.strip())
     else:
         texts = args.texts
-    corrected_texts = model.predict(texts)
-    print(corrected_texts)
+    print("Original input texts: {}".format(texts))
+    corrected_texts = model.predict(texts)  # input is a list and output is a list
+    print("Model-corrected texts: {}".format(corrected_texts))
+    # Post-process the model output
+    corrected_info = output_result(corrected_texts, sources=texts)
+    print("Correction fragment info: {}".format(corrected_info))
     return corrected_texts
 
+def parse_args_test():
+    parser = argparse.ArgumentParser(description="bbcm")
+    parser.add_argument(
+        "--config_file", default="csc/train_SoftMaskedBert.yml", help="config file", type=str
+    )
+    parser.add_argument(
+        "--ckpt_fn", default="epoch=09-val_loss=0.03032.ckpt", help="checkpoint file name", type=str
+    )
+    args = parser.parse_args()
+    return args
+
+
+def inference_test(texts):
+    """Input is a list of texts."""
+    # Parse the inference arguments
+    args = parse_args_test()
+    # Load the model from the checkpoint
+    model = load_model(args)
+    # print("Original input texts: {}".format(texts))
+    corrected_texts = model.predict(texts)  # input is a list and output is a list
+    # print("Model-corrected texts: {}".format(corrected_texts))
+    # Post-process the model output
+    corrected_info = output_result(corrected_texts, sources=texts)
+    # print("Correction fragment info: {}".format(corrected_info))
+    return corrected_texts, corrected_info
+
+
+def load_json(filename, encoding="utf-8"):
+    """Load a JSON file."""
+    if not os.path.exists(filename):
+        return None
+    with codecs.open(filename, mode='r', encoding=encoding) as fr:
+        return json.load(fr)
+
+
+# Preload the white list; update it for the target application scenario before shipping this inference code.
+white_dict = load_json("../configs/dict/white_name_list.json")  # mind the relative path, otherwise white_dict is None
+# Regex matching runs of Chinese characters
+re_han = re.compile("[\u4E00-\u9Fa5]+")
+
+
+def load_white_dict():
+    default_lens = 4  # fallback window width; set from the length of the configured over-correction fragments, adjustable
+    lens_list = list()
+    for src in white_dict.keys():
+        for name in white_dict[src]:
+            lens_list.append(len(name))
+    max_lens = max(lens_list) if lens_list else default_lens
+    return white_dict, max_lens
+
+
+def output_result(results, sources):
+    """
+    Wrap the output format.
+    :param results: list of corrected texts from the model
+    :param sources: list of input texts
+    :return: one dict per sentence with src/tgt text and correction fragments
+    """
+    default_data = [
+        {
+            "src_sentence": "",
+            "tgt_sentence": "",
+            "fragments": []
+        }
+    ]
+    if not results:
+        return default_data
+    data = []
+    # Build one dict per result
+    for idx, result in enumerate(results):
+        # Source text
+        source = sources[idx]
+        # Collect the positions where source and result differ
+        fragments_lst = generate_diff_info(source, result)
+        dict_res = {
+            "src_sentence": source,
+            "tgt_sentence": result,
+            "fragments": fragments_lst
+        }
+        data.append(dict_res)
+    return data
+
+
+def generate_diff_info(source, result):
+    """
+    Emit diff info from the original input text and the corrected text.
+    :param source: original input text (string)
+    :param result: corrected text from the model (string)
+    :return: fragments, a list [dict_1, dict_2, ...] with one dict per corrected character
+    """
+    # Default output
+    fragments = list()
+    # Post-processing rule 1: only same-length input and output are supported;
+    # if the lengths differ, fragments stays empty
+    if len(source) != len(result):
+        return fragments
+    # Post-processing rule 2: skip sources with no or only one Chinese character
+    res_hans = re_han.findall(source)
+    if not res_hans:
+        return fragments
+    if res_hans and len(res_hans[0]) < 2:
+        return fragments
+    # Post-processing rule 3: compare character by character and record differing positions
+    for idx in range(len(source)):
+        # Original character
+        src = source[idx]
+        # Character emitted by the model
+        tgt = result[idx]
+        # An unchanged character is treated as correct
+        if src == tgt:
+            continue
+        # Skip non-Chinese characters
+        if not re_han.findall(src):
+            continue
+        # Use the white list to drop over-corrections (false positives)
+        if model_white_list_filter(source, src, idx):
+            continue
+
+        # Record the index of the differing character
+        fragment = {
+            "error_init_id": idx,  # start index of the wrong character
+            "error_end_id": idx + 1,  # end index
+            "src_fragment": src,  # original character
+            "tgt_fragment": tgt  # corrected character
+        }
+        fragments.append(fragment)
+    return fragments
+
+
+def model_white_list_filter(source, src, src_idx):
+    """Filter the model output against the white list.
+    source: the original sentence; src: the flagged character; src_idx: its index.
+    """
+    is_correct = False
+    # Load the white list and the window width
+    wh_texts, span_w = load_white_dict()
+    source_lens = len(source)
+    if src in wh_texts.keys():
+        for src_span in wh_texts[src]:
+            # If a configured fragment src_span occurs in the window
+            # source[span_start:span_end], treat the edit as an
+            # over-correction (is_correct is True).
+            span_start = src_idx - span_w
+            span_end = src_idx + span_w
+            span_start = 0 if span_start < 0 else span_start
+            span_end = span_end if span_end < source_lens else source_lens
+            if src_span in source[span_start:span_end]:
+                is_correct = True
+                return is_correct
+    return is_correct
+
 
 if __name__ == '__main__':
-    arguments = parse_args()
-    inference(arguments)
+    # Original inference entry point:
+    # arguments = parse_args()
+    # inference(arguments)
+    # Test code exercising the added post-processing:
+    texts = [
+        '真麻烦你了。希望你们好好的跳无',
+        '少先队员因该为老人让坐',
+        '机七学习是人工智能领遇最能体现智能的一个分知',
+        '今天心情很好',
+        '汽车新式在这条路上',
+        '中国人工只能布局很不错'
+    ]
+    corrected_texts, corrected_info = inference_test(texts)
+    for info in corrected_info:
+        print("----------------------")
+        print("info:{}".format(info))
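For reference, the structure output_result returns for one sentence, sketched under the assumption that the checkpoint corrects '少先队员因该为老人让坐' to '少先队员应该为老人让座' (the actual corrections depend on the trained model; this pairing is an assumption, not an observed output):

# Hypothetical corrected_info entry for the second test sentence (assumed model output):
expected_info = [
    {
        "src_sentence": "少先队员因该为老人让坐",
        "tgt_sentence": "少先队员应该为老人让座",
        "fragments": [
            {"error_init_id": 4, "error_end_id": 5, "src_fragment": "因", "tgt_fragment": "应"},
            {"error_init_id": 10, "error_end_id": 11, "src_fragment": "坐", "tgt_fragment": "座"},
        ],
    }
]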
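On the new config: white_name_list.json is loaded as a dict mapping an over-corrected character to the semantic fragments in which it must be preserved, and load_white_dict derives the window width from the longest configured fragment. The entries below are illustrative assumptions, not the repository's actual list; a minimal sketch for generating such a file:

import json
import codecs

# Hypothetical white-list entries (assumed, for illustration): each key is a
# character the model tends to over-correct, each value lists fragments
# (e.g. person names) in which the character must be left unchanged.
white_dict_example = {
    "昊": ["张昊", "李昊然"],
    "珺": ["王珺"],
}

with codecs.open("white_name_list.json", mode="w", encoding="utf-8") as fw:
    json.dump(white_dict_example, fw, ensure_ascii=False, indent=2)

# With these entries load_white_dict() would return span_w = 3,
# the length of the longest configured fragment ("李昊然").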

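The window check in model_white_list_filter then slices source around the flagged index before searching for a configured fragment. A quick sketch of that arithmetic, assuming the span_w = 3 and the hypothetical "昊" entry from above:

source = "这是张昊然写的文章"
src_idx = 3                                    # index of "昊", the character the model wants to rewrite
span_w = 3                                     # max fragment length from load_white_dict()
span_start = max(0, src_idx - span_w)          # -> 0
span_end = min(len(source), src_idx + span_w)  # -> 6
window = source[span_start:span_end]           # -> "这是张昊然写"
# "张昊然" occurs in the window, so model_white_list_filter returns True and
# generate_diff_info skips this character as an over-correction.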