|
7 | 7 | import sys
|
8 | 8 | import argparse
|
9 | 9 | import os
|
10 |
| - |
11 | 10 | import torch
|
12 | 11 | from transformers import BertTokenizer
|
13 |
| - |
14 | 12 | from tools.bases import args_parse
|
15 |
| - |
16 | 13 | sys.path.append('..')
|
17 |
| - |
18 | 14 | from bbcm.modeling.csc import BertForCsc, SoftMaskedBertModel
|
19 | 15 | from bbcm.utils import get_abs_path
|
20 |
| - |
| 16 | +import json |
| 17 | +import codecs |
| 18 | +import re |
21 | 19 |
|
22 | 20 | def parse_args():
|
23 | 21 | parser = argparse.ArgumentParser(description="bbcm")
|
@@ -85,11 +83,177 @@ def inference(args):
|
85 | 83 | texts.append(line.strip())
|
86 | 84 | else:
|
87 | 85 | texts = args.texts
|
88 |
| - corrected_texts = model.predict(texts) |
89 |
| - print(corrected_texts) |
| 86 | + print("传入 的原始文本:{}".format(texts)) |
| 87 | + corrected_texts = model.predict(texts) # input is list and output is list |
| 88 | + print("模型纠错输出文本:{}".format(corrected_texts)) |
| 89 | + # 输出结果后处理模块 |
| 90 | + corrected_info = output_result(corrected_texts, sources=texts) |
| 91 | + print("模型纠错字段信息:{}".format(corrected_info)) |
90 | 92 | return corrected_texts
|
91 | 93 |
|
| 94 | +def parse_args_test(): |
| 95 | + parser = argparse.ArgumentParser(description="bbcm") |
| 96 | + parser.add_argument( |
| 97 | + "--config_file", default="csc/train_SoftMaskedBert.yml", help="config file", type=str |
| 98 | + ) |
| 99 | + parser.add_argument( |
| 100 | + "--ckpt_fn", default="epoch=09-val_loss=0.03032.ckpt", help="checkpoint file name", type=str |
| 101 | + ) |
| 102 | + args = parser.parse_args() |
| 103 | + return args |
| 104 | + |
| 105 | + |
| 106 | +def inference_test(texts): |
| 107 | + """input is texts list""" |
| 108 | + # 加载推理模型 |
| 109 | + args = parse_args_test() |
| 110 | + # 加载模型参数 |
| 111 | + model = load_model(args) |
| 112 | + #print("传入 的原始文本:{}".format(texts)) |
| 113 | + corrected_texts = model.predict(texts) # input is list and output is list |
| 114 | + #print("模型纠错输出文本:{}".format(corrected_texts)) |
| 115 | + # 输出结果后处理模块 |
| 116 | + corrected_info = output_result(corrected_texts, sources=texts) |
| 117 | + #print("模型纠错字段信息:{}".format(corrected_info)) |
| 118 | + return corrected_texts, corrected_info |
| 119 | + |
| 120 | + |
| 121 | +def load_json(filename, encoding="utf-8"): |
| 122 | + """Load json file""" |
| 123 | + if not os.path.exists(filename): |
| 124 | + return None |
| 125 | + with codecs.open(filename, mode='r', encoding=encoding) as fr: |
| 126 | + return json.load(fr) |
| 127 | + |
| 128 | + |
| 129 | +# 预先加载 - 白名单 - 可根据实际应用场景定向更新后放入此推理代码中备用 |
| 130 | +white_dict = load_json("../configs/dict/white_name_list.json") # 注意这里的路径-否则white_dict is None |
| 131 | +# 编译中文字符 |
| 132 | +re_han = re.compile("[\u4E00-\u9Fa5]+") |
| 133 | + |
| 134 | + |
| 135 | +def load_white_dict(): |
| 136 | + default_lens = 4 # 根据配置的过纠字对应的语义片段长度来设定。默认值,可修改 |
| 137 | + lens_list = list() |
| 138 | + for src in white_dict.keys(): |
| 139 | + for name in white_dict[src]: |
| 140 | + lens_list.append(len(name)) |
| 141 | + max_lens = max(lens_list) if lens_list else default_lens |
| 142 | + return white_dict, max_lens |
| 143 | + |
| 144 | + |
| 145 | +def output_result(results, sources): |
| 146 | + """ |
| 147 | + :param results: 模型纠错结果list |
| 148 | + :param sources: 输入list |
| 149 | + :return: |
| 150 | + """ |
| 151 | + """封装输出格式""" |
| 152 | + default_data = [ |
| 153 | + { |
| 154 | + "src_sentence": "", |
| 155 | + "tgt_sentence": "", |
| 156 | + "fragments": [] |
| 157 | + } |
| 158 | + ] |
| 159 | + if not results: |
| 160 | + return default_data |
| 161 | + data = [] |
| 162 | + # 一个result 生成一个字典dict() |
| 163 | + for idx, result in enumerate(results): |
| 164 | + # 源文本 |
| 165 | + source = sources[idx] |
| 166 | + # 找到diff_info不同的地方 |
| 167 | + fragments_lst = generate_diff_info(source, result) |
| 168 | + dict_res = { |
| 169 | + "src_sentence": source, |
| 170 | + "tgt_sentence": result, |
| 171 | + "fragments": fragments_lst |
| 172 | + } |
| 173 | + data.append(dict_res) |
| 174 | + return data |
| 175 | + |
| 176 | + |
| 177 | +def generate_diff_info(source, result): |
| 178 | + """ |
| 179 | + :param source: 原始输入文本 string |
| 180 | + :param result: 纠错模型输出文本 string |
| 181 | + :return: fragments, 输出[dict_1, dict_2, ....], dict_i 是每个字的纠错输出信息 |
| 182 | + """ |
| 183 | + """基于原始输入文本和纠错后的文本输出differ_info""" |
| 184 | + # 定义默认输出 |
| 185 | + fragments = list() |
| 186 | + # 仅支持输出和输出相同的情况下,如果不同则fragments输出为空 |
| 187 | + # 后处理逻辑1 |
| 188 | + if len(source) != len(result): |
| 189 | + return fragments |
| 190 | + # 后处理逻辑2 - 如果输入的source中没有或仅有一个中文字符则也不处理 |
| 191 | + res_hans = re_han.findall(source) |
| 192 | + if not res_hans: |
| 193 | + return fragments |
| 194 | + if res_hans and len(res_hans[0]) < 2: |
| 195 | + return fragments |
| 196 | + # 后处理逻辑3 - 逐个字段比对,输出不同的字的位置 |
| 197 | + for idx in range(len(source)): |
| 198 | + # 原始字 |
| 199 | + src = source[idx] |
| 200 | + # 模型输出的字 |
| 201 | + tgt = result[idx] |
| 202 | + # 如果字没发生变化则按照没有错误处理 |
| 203 | + if src == tgt: |
| 204 | + continue |
| 205 | + # 过滤掉非汉字 |
| 206 | + if not re_han.findall(src): |
| 207 | + continue |
| 208 | + # 通过白名单过滤掉overcorrection-误杀的情况 |
| 209 | + if model_white_list_filter(source, src, idx): |
| 210 | + continue |
| 211 | + |
| 212 | + # 找到不同的字所在index |
| 213 | + fragment = { |
| 214 | + "error_init_id": idx, # 出错字开始位置索引 |
| 215 | + "error_end_id": idx + 1, # 结束索引 |
| 216 | + "src_fragment": src, # 原字 |
| 217 | + "tgt_fragment": tgt # 纠正后的字 |
| 218 | + } |
| 219 | + fragments.append(fragment) |
| 220 | + return fragments |
| 221 | + |
| 222 | + |
| 223 | +def model_white_list_filter(source, src, src_idx): |
| 224 | + """"source: 原来的句子; texts: 白名单; rules: 白名单规则""" |
| 225 | + """模型输出结果白名单过滤""" |
| 226 | + is_correct = False |
| 227 | + # 加载白名单 |
| 228 | + wh_texts, span_w = load_white_dict() |
| 229 | + source_lens = len(source) |
| 230 | + if src in wh_texts.keys(): |
| 231 | + for src_span in wh_texts[src]: |
| 232 | + # 如果配置的语义片段src_span在 传入的文本text 片段source[span_start:span_end]中,则认为过纠is_correct is True。 |
| 233 | + span_start = src_idx-span_w |
| 234 | + span_end = src_idx+span_w |
| 235 | + span_start = 0 if span_start < 0 else span_start |
| 236 | + span_end = span_end if span_end < source_lens else source_lens |
| 237 | + if src_span in source[span_start:span_end]: |
| 238 | + is_correct = True |
| 239 | + return is_correct |
| 240 | + return is_correct |
| 241 | + |
92 | 242 |
|
93 | 243 | if __name__ == '__main__':
|
94 |
| - arguments = parse_args() |
95 |
| - inference(arguments) |
| 244 | + # 原来推理代码 |
| 245 | + # arguments = parse_args() |
| 246 | + # inference(arguments) |
| 247 | + # 添加代码后的测试代码如下: |
| 248 | + texts = [ |
| 249 | + '真麻烦你了。希望你们好好的跳无', |
| 250 | + '少先队员因该为老人让坐', |
| 251 | + '机七学习是人工智能领遇最能体现智能的一个分知', |
| 252 | + '今天心情很好', |
| 253 | + '汽车新式在这条路上', |
| 254 | + '中国人工只能布局很不错' |
| 255 | + ] |
| 256 | + corrected_texts, corrected_info = inference_test(texts) |
| 257 | + for info in corrected_info: |
| 258 | + print("----------------------") |
| 259 | + print("info:{}".format(info)) |
0 commit comments