From 667fda88ed16dd25be2a79723a71846de3f9bb90 Mon Sep 17 00:00:00 2001
From: Luo Peng
Date: Tue, 16 Apr 2024 10:08:13 +0800
Subject: [PATCH] Enhance StructureSystem to achieve higher OCR recognition accuracy (#11916)

Closes #10270 and #11665.
---
 ppstructure/predict_system.py | 125 +++++++++++++++++++++------------
 1 file changed, 79 insertions(+), 46 deletions(-)

diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index b8b871689c..8d504ff90c 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -58,7 +58,6 @@ def __init__(self, args):
             logger.warning(
                 "When args.layout is false, args.ocr is automatically set to false"
             )
-        args.drop_score = 0
         # init model
         self.layout_predictor = None
         self.text_system = None
@@ -93,6 +92,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
             'all': 0
         }
         start = time.time()
+
         if self.image_orientation_predictor is not None:
             tic = time.time()
             cls_result = self.image_orientation_predictor.predict(
@@ -108,6 +108,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                 img = cv2.rotate(img, cv_rotate_code[angle])
             toc = time.time()
             time_dict['image_orientation'] = toc - tic
+
         if self.mode == 'structure':
             ori_im = img.copy()
             if self.layout_predictor is not None:
@@ -116,6 +117,20 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
             else:
                 h, w = ori_im.shape[:2]
                 layout_res = [dict(bbox=None, label='table')]
+
+            # As reported in issues such as #10270 and #11665, the old
+            # implementation, which recognizes texts from the layout regions,
+            # has problems with OCR recognition accuracy.
+            #
+            # To enhance the OCR recognition accuracy, we implement a patch fix
+            # that first uses text_system to detect and recognize all text information
+            # and then filters out relevant texts according to the layout regions.
+            text_res = None
+            if self.text_system is not None:
+                text_res, ocr_time_dict = self._predict_text(img)
+                time_dict['det'] += ocr_time_dict['det']
+                time_dict['rec'] += ocr_time_dict['rec']
+
             res_list = []
             for region in layout_res:
                 res = ''
@@ -126,6 +141,8 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                 else:
                     x1, y1, x2, y2 = 0, 0, w, h
                     roi_img = ori_im
+                bbox = [x1, y1, x2, y2]
+
                 if region['label'] == 'table':
                     if self.table_system is not None:
                         res, table_time_dict = self.table_system(
@@ -135,67 +152,83 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                         time_dict['det'] += table_time_dict['det']
                         time_dict['rec'] += table_time_dict['rec']
                 else:
-                    if self.text_system is not None:
-                        if self.recovery:
-                            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
-                            wht_im[y1:y2, x1:x2, :] = roi_img
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                wht_im)
-                        else:
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                roi_img)
-                        time_dict['det'] += ocr_time_dict['det']
-                        time_dict['rec'] += ocr_time_dict['rec']
-
-                        # remove style char,
-                        # when using the recognition model trained on the PubtabNet dataset,
-                        # it will recognize the text format in the table, such as <b>
-                        style_token = [
-                            '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
-                            '</b>', '<sub>', '</sup>', '<overline>',
-                            '</overline>', '<underline>', '</underline>', '<i>',
-                            '</i>'
-                        ]
-                        res = []
-                        for box, rec_res in zip(filter_boxes, filter_rec_res):
-                            rec_str, rec_conf = rec_res[0], rec_res[1]
-                            for token in style_token:
-                                if token in rec_str:
-                                    rec_str = rec_str.replace(token, '')
-                            if not self.recovery:
-                                box += [x1, y1]
-                            if self.return_word_box:
-                                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist(),
-                                    'text_word': word_box_content_list,
-                                    'text_word_region': word_box_list
-                                })
-                            else:
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist()
-                                })
+                    if text_res is not None:
+                        # Filter the text results whose regions intersect with the current layout bbox.
+                        res = self._filter_text_res(text_res, bbox)
+
                 res_list.append({
                     'type': region['label'].lower(),
-                    'bbox': [x1, y1, x2, y2],
+                    'bbox': bbox,
                     'img': roi_img,
                     'res': res,
                     'img_idx': img_idx
                 })
+
             end = time.time()
             time_dict['all'] = end - start
             return res_list, time_dict
+
         elif self.mode == 'kie':
             re_res, elapse = self.kie_predictor(img)
             time_dict['kie'] = elapse
             time_dict['all'] = elapse
             return re_res[0], time_dict
+
         return None, None
 
+    def _predict_text(self, img):
+        filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(img)
+
+        # remove style char,
+        # when using the recognition model trained on the PubtabNet dataset,
+        # it will recognize the text format in the table, such as <b>
+        style_token = [
+            '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
+            '</b>', '<sub>', '</sup>', '<overline>',
+            '</overline>', '<underline>', '</underline>', '<i>',
+            '</i>'
+        ]
+        res = []
+        for box, rec_res in zip(filter_boxes, filter_rec_res):
+            rec_str, rec_conf = rec_res[0], rec_res[1]
+            for token in style_token:
+                if token in rec_str:
+                    rec_str = rec_str.replace(token, '')
+            if self.return_word_box:
+                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist(),
+                    'text_word': word_box_content_list,
+                    'text_word_region': word_box_list
+                })
+            else:
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist()
+                })
+        return res, ocr_time_dict
+
+    def _filter_text_res(self, text_res, bbox):
+        res = []
+        for r in text_res:
+            box = r['text_region']
+            rect = box[0][0], box[0][1], box[2][0], box[2][1]
+            if self._has_intersection(bbox, rect):
+                res.append(r)
+        return res
+
+    def _has_intersection(self, rect1, rect2):
+        x_min1, y_min1, x_max1, y_max1 = rect1
+        x_min2, y_min2, x_max2, y_max2 = rect2
+        if x_min1 > x_max2 or x_max1 < x_min2:
+            return False
+        if y_min1 > y_max2 or y_max1 < y_min2:
+            return False
+        return True
+
 
 def save_structure_res(res, save_folder, img_name, img_idx=0):
     excel_save_folder = os.path.join(save_folder, img_name)
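
Reviewer note: the heart of this patch is the switch from running OCR on each cropped layout region to running it once on the whole image and assigning results to regions by bounding-box intersection. Below is a minimal standalone sketch of that filtering step. The module-level helper names and the sample boxes are hypothetical, not part of the patch, but the quad-to-rect logic mirrors _filter_text_res and _has_intersection above.

    # Standalone sketch of the filter-by-intersection step introduced in this patch.

    def has_intersection(rect1, rect2):
        # Axis-aligned overlap test: two rectangles are disjoint iff they are
        # separated along the x axis or along the y axis.
        x_min1, y_min1, x_max1, y_max1 = rect1
        x_min2, y_min2, x_max2, y_max2 = rect2
        if x_min1 > x_max2 or x_max1 < x_min2:
            return False
        if y_min1 > y_max2 or y_max1 < y_min2:
            return False
        return True

    def filter_text_res(text_res, bbox):
        # Keep the OCR results whose text region overlaps the layout bbox.
        # 'text_region' is a 4-point quad; corners 0 and 2 (top-left and
        # bottom-right) give its axis-aligned extent.
        kept = []
        for r in text_res:
            box = r['text_region']
            rect = box[0][0], box[0][1], box[2][0], box[2][1]
            if has_intersection(bbox, rect):
                kept.append(r)
        return kept

    # Hypothetical OCR output: one line inside the layout region, one outside.
    text_res = [
        {'text': 'inside', 'text_region': [[10, 10], [90, 10], [90, 30], [10, 30]]},
        {'text': 'outside', 'text_region': [[10, 200], [90, 200], [90, 220], [10, 220]]},
    ]
    layout_bbox = [0, 0, 100, 100]
    print([r['text'] for r in filter_text_res(text_res, layout_bbox)])  # ['inside']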
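
One behavioral detail worth noting: the disjointness comparisons in _has_intersection are strict, so a text line that merely touches a layout boundary is still kept, and a line that overlaps two adjacent layout regions is assigned to both of them. Using the sketch above (coordinates are made up for illustration):

    print(has_intersection((0, 0, 50, 50), (50, 0, 100, 50)))  # True: edges touch at x = 50
    print(has_intersection((0, 0, 50, 50), (51, 0, 100, 50)))  # False: separated along x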