Add batch img inference support for ocr det with readtext_batched #458

Merged 3 commits on Jun 24, 2021
easyocr/detection.py: 83 changes (51 additions, 32 deletions)
```diff
@@ -22,42 +22,55 @@ def copyStateDict(state_dict):
     return new_state_dict
 
 def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
+    if isinstance(image, np.ndarray) and len(image.shape) == 4:  # image is batch of np arrays
+        image_arrs = image
+    else:  # image is single numpy array
+        image_arrs = [image]
+
+    img_resized_list = []
     # resize
-    img_resized, target_ratio, size_heatmap = resize_aspect_ratio(image, canvas_size,\
-                                                                  interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
+    for img in image_arrs:
+        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
+                                                                      interpolation=cv2.INTER_LINEAR,
+                                                                      mag_ratio=mag_ratio)
+        img_resized_list.append(img_resized)
     ratio_h = ratio_w = 1 / target_ratio
 
     # preprocessing
-    x = normalizeMeanVariance(img_resized)
-    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
-    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
+    x = np.array([normalizeMeanVariance(n_img) for n_img in img_resized_list])
+    x = Variable(torch.from_numpy(x).permute(0, 3, 1, 2))  # [b, h, w, c] to [b, c, h, w]
     x = x.to(device)
 
     # forward pass
     with torch.no_grad():
         y, feature = net(x)
 
-    # make score and link map
-    score_text = y[0,:,:,0].cpu().data.numpy()
-    score_link = y[0,:,:,1].cpu().data.numpy()
-
-    # Post-processing
-    boxes, polys, mapper = getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
-
-    # coordinate adjustment
-    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
-    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
-    if estimate_num_chars:
-        boxes = list(boxes)
-        polys = list(polys)
-    for k in range(len(polys)):
-        if estimate_num_chars:
-            boxes[k] = (boxes[k], mapper[k])
-        if polys[k] is None: polys[k] = boxes[k]
+    boxes_list, polys_list = [], []
+    for out in y:
+        # make score and link map
+        score_text = out[:, :, 0].cpu().data.numpy()
+        score_link = out[:, :, 1].cpu().data.numpy()
+
+        # Post-processing
+        boxes, polys, mapper = getDetBoxes(
+            score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
+
+        # coordinate adjustment
+        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
+        if estimate_num_chars:
+            boxes = list(boxes)
+            polys = list(polys)
+        for k in range(len(polys)):
+            if estimate_num_chars:
+                boxes[k] = (boxes[k], mapper[k])
+            if polys[k] is None:
+                polys[k] = boxes[k]
+        boxes_list.append(boxes)
+        polys_list.append(polys)
 
-    return boxes, polys
+    return boxes_list, polys_list
 
-def get_detector(trained_model, device='cpu', quantize=True):
+def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
     net = CRAFT()
 
     if device == 'cpu':
@@ -70,21 +83,27 @@ def get_detector(trained_model, device='cpu', quantize=True):
     else:
         net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
         net = torch.nn.DataParallel(net).to(device)
-        cudnn.benchmark = False
+        cudnn.benchmark = cudnn_benchmark
 
     net.eval()
     return net
 
 def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None):
     result = []
     estimate_num_chars = optimal_num_chars is not None
-    bboxes, polys = test_net(canvas_size, mag_ratio, detector, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars)
-
+    bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
+                                       image, text_threshold,
+                                       link_threshold, low_text, poly,
+                                       device, estimate_num_chars)
     if estimate_num_chars:
-        polys = [p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
-
-    for i, box in enumerate(polys):
-        poly = np.array(box).astype(np.int32).reshape((-1))
-        result.append(poly)
+        polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
+                      for polys in polys_list]
+
+    for polys in polys_list:
+        single_img_result = []
+        for i, box in enumerate(polys):
+            poly = np.array(box).astype(np.int32).reshape((-1))
+            single_img_result.append(poly)
+        result.append(single_img_result)
 
     return result
```
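With this change, `test_net` accepts either a single `[h, w, c]` image or a `[b, h, w, c]` batch: the whole batch is normalized into one tensor, run through CRAFT in a single forward pass, and post-processed per image, so `get_textbox` now returns one list of flattened polygons per input image. A minimal sketch of driving the batched detector directly; the checkpoint path and image sizes are assumptions for illustration:

```python
import numpy as np
from easyocr.detection import get_detector, get_textbox

# cudnn_benchmark is the new flag threaded through get_detector in this PR.
# 'craft_mlt_25k.pth' is an assumed local path to the CRAFT weights.
detector = get_detector('craft_mlt_25k.pth', device='cuda',
                        quantize=False, cudnn_benchmark=True)

# A dummy batch of four same-sized RGB images, shape [b, h, w, c].
batch = np.zeros((4, 480, 640, 3), dtype=np.uint8)

# One polygon list comes back per image in the batch.
textbox_list = get_textbox(detector, batch, canvas_size=2560, mag_ratio=1.0,
                           text_threshold=0.7, link_threshold=0.4, low_text=0.4,
                           poly=False, device='cuda')
assert len(textbox_list) == 4
```

Note that all images in a batch must share one shape: the resized outputs of `resize_aspect_ratio` have to stack into a single `[b, c, h, w]` tensor for the forward pass.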
easyocr/easyocr.py: 81 changes (63 additions, 18 deletions)
```diff
@@ -4,7 +4,8 @@
 from .recognition import get_recognizer, get_text
 from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
                    download_and_unzip, printProgressBar, diff, reformat_input,\
-                   make_rotated_img_list, set_result_with_confidence
+                   make_rotated_img_list, set_result_with_confidence,\
+                   reformat_input_batched
 from .config import *
 from bidi.algorithm import get_display
 import numpy as np
@@ -31,7 +32,7 @@ class Reader(object):
     def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                  user_network_directory=None, recog_network = 'standard',
                  download_enabled=True, detector=True, recognizer=True,
-                 verbose=True, quantize=True):
+                 verbose=True, quantize=True, cudnn_benchmark=False):
         """Create an EasyOCR Reader.
 
         Parameters:
@@ -75,7 +76,7 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
         else:
             self.device = gpu
         self.recognition_models = recognition_models
-        
+
         # check and download detection model
         detector_model = 'craft'
         corrupt_msg = 'MD5 hash mismatch, possible file corruption'
@@ -215,7 +216,7 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
             dict_list[lang] = os.path.join(BASE_PATH, 'dict', lang + ".txt")
 
         if detector:
-            self.detector = get_detector(detector_path, self.device, quantize)
+            self.detector = get_detector(detector_path, self.device, quantize, cudnn_benchmark=cudnn_benchmark)
         if recognizer:
             if recog_network == 'generation1':
                 network_params = {
@@ -271,19 +272,25 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\
         if reformat:
             img, img_cv_grey = reformat_input(img)
 
-        text_box = get_textbox(self.detector, img, canvas_size, mag_ratio,\
-                               text_threshold, link_threshold, low_text,\
-                               False, self.device, optimal_num_chars)
-        horizontal_list, free_list = group_text_box(text_box, slope_ths,\
-                                                    ycenter_ths, height_ths,\
-                                                    width_ths, add_margin, \
-                                                    (optimal_num_chars is None))
-
-        if min_size:
-            horizontal_list = [i for i in horizontal_list if max(i[1]-i[0],i[3]-i[2]) > min_size]
-            free_list = [i for i in free_list if max(diff([c[0] for c in i]), diff([c[1] for c in i]))>min_size]
-
-        return horizontal_list, free_list
+        text_box_list = get_textbox(self.detector, img, canvas_size, mag_ratio,
+                                    text_threshold, link_threshold, low_text,
+                                    False, self.device, optimal_num_chars)
+
+        horizontal_list_agg, free_list_agg = [], []
+        for text_box in text_box_list:
+            horizontal_list, free_list = group_text_box(text_box, slope_ths,
+                                                        ycenter_ths, height_ths,
+                                                        width_ths, add_margin,
+                                                        (optimal_num_chars is None))
+            if min_size:
+                horizontal_list = [i for i in horizontal_list if max(
+                    i[1] - i[0], i[3] - i[2]) > min_size]
+                free_list = [i for i in free_list if max(
+                    diff([c[0] for c in i]), diff([c[1] for c in i])) > min_size]
+            horizontal_list_agg.append(horizontal_list)
+            free_list_agg.append(free_list)
+
+        return horizontal_list_agg, free_list_agg
 
     def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
                   decoder = 'greedy', beamWidth= 5, batch_size = 1,\
@@ -381,11 +388,49 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
                                                  slope_ths, ycenter_ths,\
                                                  height_ths,width_ths,\
                                                  add_margin, False)
+
+        # get the 1st result from hor & free list as self.detect returns a list of depth 3
+        horizontal_list, free_list = horizontal_list[0], free_list[0]
         result = self.recognize(img_cv_grey, horizontal_list, free_list,\
                                 decoder, beamWidth, batch_size,\
                                 workers, allowlist, blocklist, detail, rotation_info,\
                                 paragraph, contrast_ths, adjust_contrast,\
                                 filter_ths, y_ths, x_ths, False, output_format)
 
         return result
+
+    def readtext_batched(self, image, n_width=None, n_height=None,\
+                         decoder = 'greedy', beamWidth= 5, batch_size = 1,\
+                         workers = 0, allowlist = None, blocklist = None, detail = 1,\
+                         rotation_info = None, paragraph = False, min_size = 20,\
+                         contrast_ths = 0.1, adjust_contrast = 0.5, filter_ths = 0.003,\
+                         text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
+                         canvas_size = 2560, mag_ratio = 1.,\
+                         slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
+                         width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, output_format='standard'):
+        '''
+        Parameters:
+            image: file path or numpy array or a byte stream object.
+                When sending a list of images, they must all be the same size;
+                the following parameters will auto-resize them when not None:
+            n_width: int, new width
+            n_height: int, new height
+        '''
+        img, img_cv_grey = reformat_input_batched(image, n_width, n_height)
+
+        horizontal_list_agg, free_list_agg = self.detect(img, min_size, text_threshold,\
+                                                         low_text, link_threshold,\
+                                                         canvas_size, mag_ratio,\
+                                                         slope_ths, ycenter_ths,\
+                                                         height_ths, width_ths,\
+                                                         add_margin, False)
+        result_agg = []
+        # put img_cv_grey in a list if it's a single img
+        img_cv_grey = [img_cv_grey] if len(img_cv_grey.shape) == 2 else img_cv_grey
+        for grey_img, horizontal_list, free_list in zip(img_cv_grey, horizontal_list_agg, free_list_agg):
+            result_agg.append(self.recognize(grey_img, horizontal_list, free_list,\
+                                             decoder, beamWidth, batch_size,\
+                                             workers, allowlist, blocklist, detail, rotation_info,\
+                                             paragraph, contrast_ths, adjust_contrast,\
+                                             filter_ths, y_ths, x_ths, False, output_format))
+
+        return result_agg
```
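Together these pieces give `readtext_batched` the same knobs as `readtext`, but with one result list per input image. A usage sketch under stated assumptions ('page1.png' and 'page2.png' are hypothetical paths; a CUDA GPU is available):

```python
import easyocr

# cudnn_benchmark=True (new in this PR) lets cuDNN tune kernels for a fixed
# input shape, which helps when every batch has the same size.
reader = easyocr.Reader(['en'], gpu=True, cudnn_benchmark=True)

# n_width/n_height auto-resize the inputs so they can be stacked into a batch.
results = reader.readtext_batched(['page1.png', 'page2.png'],
                                  n_width=800, n_height=600)

for image_result in results:                      # one list per input image
    for bbox, text, confidence in image_result:   # standard (detail=1) format
        print(bbox, text, confidence)
```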
easyocr/utils.py: 29 changes (29 additions, 0 deletions)
```diff
@@ -703,6 +703,35 @@ def reformat_input(image):
     return img, img_cv_grey
 
 
+def reformat_input_batched(image, n_width=None, n_height=None):
+    """
+    Reformats an image, a list of images, or a 4D numpy image array and
+    returns the corresponding img and img_cv_grey ndarrays.
+    image:
+        [file path, numpy array, byte stream object,
+        list of file paths, list of numpy arrays, 4D numpy array,
+        list of byte stream objects]
+    """
+    if ((isinstance(image, np.ndarray) and len(image.shape) == 4) or isinstance(image, list)):
+        # process image batches if image is a list of image np arrays, paths or bytes
+        img, img_cv_grey = [], []
+        for single_img in image:
+            clr, gry = reformat_input(single_img)
+            if n_width is not None and n_height is not None:
+                clr = cv2.resize(clr, (n_width, n_height))
+                gry = cv2.resize(gry, (n_width, n_height))
+            img.append(clr)
+            img_cv_grey.append(gry)
+        img, img_cv_grey = np.array(img), np.array(img_cv_grey)
+        # ragged arrays are created when the input images are not all the same size
+        if len(img.shape) == 1 and len(img_cv_grey.shape) == 1:
+            raise ValueError("The input image array contains images of different sizes. " +
+                             "Please resize all images to the same shape or pass n_width, n_height to auto-resize")
+    else:
+        img, img_cv_grey = reformat_input(image)
+    return img, img_cv_grey
+
+
 def make_rotated_img_list(rotationInfo, img_list):
     result_img_list = img_list[:]
```
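A short sketch of the helper's contract with dummy arrays. The ragged-size guard relies on `np.array` collapsing mixed-shape inputs into a 1-D object array, the NumPy behaviour current when this PR was merged (newer NumPy raises the `ValueError` itself, which is the same failure mode for the caller):

```python
import numpy as np
from easyocr.utils import reformat_input_batched

imgs = [np.zeros((480, 640, 3), dtype=np.uint8),   # two dummy images with
        np.zeros((600, 800, 3), dtype=np.uint8)]   # deliberately different sizes

# Without n_width/n_height, mixed sizes are rejected.
try:
    reformat_input_batched(imgs)
except ValueError as err:
    print(err)

# With auto-resize, the images stack into 4-D color and 3-D grey arrays.
img, img_cv_grey = reformat_input_batched(imgs, n_width=640, n_height=480)
print(img.shape, img_cv_grey.shape)   # (2, 480, 640, 3) (2, 480, 640)
```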