
Enhance StructureSystem to achieve higher OCR recognition accuracy (#11916)

Closes #10270 and #11665.
RussellLuo authored Apr 16, 2024
1 parent 2965012 commit 667fda8
Showing 1 changed file with 79 additions and 46 deletions.
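In outline, the commit replaces per-region OCR with a single whole-image OCR pass whose results are then assigned to layout regions by rectangle intersection. A minimal sketch of that flow, simplified from the diff below; the layout and ocr callables here are hypothetical stand-ins for the real layout_predictor and text_system:

    def structure(img, layout, ocr):
        # One OCR pass over the whole image, so detection sees complete
        # text lines instead of lines clipped at layout-region borders.
        text_res = ocr(img)         # [{'text': ..., 'text_region': quad}, ...]
        results = []
        for region in layout(img):  # [{'label': ..., 'bbox': (x1, y1, x2, y2)}, ...]
            rx1, ry1, rx2, ry2 = region['bbox']
            kept = []
            for t in text_res:
                quad = t['text_region']
                # Points 0 and 2 of a detected quad are opposite corners.
                tx1, ty1 = quad[0]
                tx2, ty2 = quad[2]
                # Keep the line if its rectangle overlaps the region's rectangle.
                if not (tx1 > rx2 or tx2 < rx1 or ty1 > ry2 or ty2 < ry1):
                    kept.append(t)
            results.append({'type': region['label'], 'bbox': region['bbox'], 'res': kept})
        return results

A line that straddles two regions is kept in both, which matches the any-intersection test used by the _has_intersection helper added in this commit.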
125 changes: 79 additions & 46 deletions ppstructure/predict_system.py
@@ -58,7 +58,6 @@ def __init__(self, args):
             logger.warning(
                 "When args.layout is false, args.ocr is automatically set to false"
             )
-            args.drop_score = 0
         # init model
         self.layout_predictor = None
         self.text_system = None
@@ -93,6 +92,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
             'all': 0
         }
         start = time.time()
+
         if self.image_orientation_predictor is not None:
             tic = time.time()
             cls_result = self.image_orientation_predictor.predict(
@@ -108,6 +108,7 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                 img = cv2.rotate(img, cv_rotate_code[angle])
             toc = time.time()
             time_dict['image_orientation'] = toc - tic
+
         if self.mode == 'structure':
             ori_im = img.copy()
             if self.layout_predictor is not None:
@@ -116,6 +117,20 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
             else:
                 h, w = ori_im.shape[:2]
                 layout_res = [dict(bbox=None, label='table')]
+
+            # As reported in issues such as #10270 and #11665, the old
+            # implementation, which recognizes texts from the layout regions,
+            # has problems with OCR recognition accuracy.
+            #
+            # To enhance the OCR recognition accuracy, we implement a patch fix
+            # that first uses text_system to detect and recognize all text information
+            # and then filters out relevant texts according to the layout regions.
+            text_res = None
+            if self.text_system is not None:
+                text_res, ocr_time_dict = self._predict_text(img)
+                time_dict['det'] += ocr_time_dict['det']
+                time_dict['rec'] += ocr_time_dict['rec']
+
             res_list = []
             for region in layout_res:
                 res = ''
@@ -126,6 +141,8 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                 else:
                     x1, y1, x2, y2 = 0, 0, w, h
                     roi_img = ori_im
+                bbox = [x1, y1, x2, y2]
+
                 if region['label'] == 'table':
                     if self.table_system is not None:
                         res, table_time_dict = self.table_system(
@@ -135,67 +152,83 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
                         time_dict['det'] += table_time_dict['det']
                         time_dict['rec'] += table_time_dict['rec']
                 else:
-                    if self.text_system is not None:
-                        if self.recovery:
-                            wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
-                            wht_im[y1:y2, x1:x2, :] = roi_img
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                wht_im)
-                        else:
-                            filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
-                                roi_img)
-                        time_dict['det'] += ocr_time_dict['det']
-                        time_dict['rec'] += ocr_time_dict['rec']
-
-                        # remove style char,
-                        # when using the recognition model trained on the PubtabNet dataset,
-                        # it will recognize the text format in the table, such as <b>
-                        style_token = [
-                            '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
-                            '</b>', '<sub>', '</sup>', '<overline>',
-                            '</overline>', '<underline>', '</underline>', '<i>',
-                            '</i>'
-                        ]
-                        res = []
-                        for box, rec_res in zip(filter_boxes, filter_rec_res):
-                            rec_str, rec_conf = rec_res[0], rec_res[1]
-                            for token in style_token:
-                                if token in rec_str:
-                                    rec_str = rec_str.replace(token, '')
-                            if not self.recovery:
-                                box += [x1, y1]
-                            if self.return_word_box:
-                                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist(),
-                                    'text_word': word_box_content_list,
-                                    'text_word_region': word_box_list
-                                })
-                            else:
-                                res.append({
-                                    'text': rec_str,
-                                    'confidence': float(rec_conf),
-                                    'text_region': box.tolist()
-                                })
+                    if text_res is not None:
+                        # Filter the text results whose regions intersect with the current layout bbox.
+                        res = self._filter_text_res(text_res, bbox)
+
                 res_list.append({
                     'type': region['label'].lower(),
-                    'bbox': [x1, y1, x2, y2],
+                    'bbox': bbox,
                     'img': roi_img,
                     'res': res,
                     'img_idx': img_idx
                 })
+
             end = time.time()
             time_dict['all'] = end - start
             return res_list, time_dict
+
         elif self.mode == 'kie':
             re_res, elapse = self.kie_predictor(img)
             time_dict['kie'] = elapse
             time_dict['all'] = elapse
             return re_res[0], time_dict
+
         return None, None
+
+    def _predict_text(self, img):
+        filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(img)
+
+        # Remove style chars: when using the recognition model trained on
+        # the PubTabNet dataset, it will recognize the text format in the
+        # table, such as <b>.
+        style_token = [
+            '<strike>', '</strike>', '<sup>', '</sub>', '<b>',
+            '</b>', '<sub>', '</sup>', '<overline>',
+            '</overline>', '<underline>', '</underline>', '<i>',
+            '</i>'
+        ]
+        res = []
+        for box, rec_res in zip(filter_boxes, filter_rec_res):
+            rec_str, rec_conf = rec_res[0], rec_res[1]
+            for token in style_token:
+                if token in rec_str:
+                    rec_str = rec_str.replace(token, '')
+            if self.return_word_box:
+                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist(),
+                    'text_word': word_box_content_list,
+                    'text_word_region': word_box_list
+                })
+            else:
+                res.append({
+                    'text': rec_str,
+                    'confidence': float(rec_conf),
+                    'text_region': box.tolist()
+                })
+        return res, ocr_time_dict
+
+    def _filter_text_res(self, text_res, bbox):
+        res = []
+        for r in text_res:
+            box = r['text_region']
+            rect = box[0][0], box[0][1], box[2][0], box[2][1]
+            if self._has_intersection(bbox, rect):
+                res.append(r)
+        return res
+
+    def _has_intersection(self, rect1, rect2):
+        x_min1, y_min1, x_max1, y_max1 = rect1
+        x_min2, y_min2, x_max2, y_max2 = rect2
+        if x_min1 > x_max2 or x_max1 < x_min2:
+            return False
+        if y_min1 > y_max2 or y_max1 < y_min2:
+            return False
+        return True


 def save_structure_res(res, save_folder, img_name, img_idx=0):
     excel_save_folder = os.path.join(save_folder, img_name)
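For a concrete feel for the new filtering step, here is a standalone toy check with made-up coordinates; has_intersection mirrors the StructureSystem._has_intersection helper from the diff above:

    # A layout region and two recognized lines; a text_region is a 4-point
    # quadrilateral, clockwise from the top-left corner.
    layout_bbox = (100, 100, 300, 200)
    lines = [
        {'text': 'inside', 'text_region': [[120, 110], [180, 110], [180, 130], [120, 130]]},
        {'text': 'outside', 'text_region': [[400, 400], [450, 400], [450, 420], [400, 420]]},
    ]

    def has_intersection(rect1, rect2):
        # Axis-aligned overlap test: two rectangles are disjoint iff one lies
        # entirely left/right of, or entirely above/below, the other.
        x_min1, y_min1, x_max1, y_max1 = rect1
        x_min2, y_min2, x_max2, y_max2 = rect2
        if x_min1 > x_max2 or x_max1 < x_min2:
            return False
        if y_min1 > y_max2 or y_max1 < y_min2:
            return False
        return True

    for line in lines:
        box = line['text_region']
        rect = box[0][0], box[0][1], box[2][0], box[2][1]  # opposite corners
        print(line['text'], has_intersection(layout_bbox, rect))
    # prints: inside True, then outside False

Only the first line intersects the region, so only it would be attached to that region's 'res' list.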