Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

polish infer_rec and add ic15_dict #4

Merged
merged 9 commits into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
num_workers: 8
img_set_dir: .
label_file_path: ./train_data/hard_label.txt
img_set_dir: ./train_data
label_file_path: ./train_data/rec_gt_train.txt

EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
img_set_dir: .
label_file_path: ./train_data/label_val_all.txt
img_set_dir: ./train_data
label_file_path: ./train_data/rec_gt_test.txt

TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@ Global:
epoch_num: 300
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_model_dir: output_ic15
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
character_dict_path: ./ppocr/utils/ic15_dict.txt
loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights:

reader_yml: ./configs/rec/rec_icdar15_reader.yml
pretrain_weights: ./pretrain_models/CRNN/best_accuracy
checkpoints:
save_inference_dir:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel

Expand Down
61 changes: 35 additions & 26 deletions ppocr/data/rec/dataset_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import lmdb

from ppocr.utils.utility import initial_logger
from ppocr.utils.utility import get_image_file_list
logger = initial_logger()

from .img_tools import process_image, get_img_data
Expand Down Expand Up @@ -143,8 +144,9 @@ def __init__(self, params):
self.num_workers = 1
else:
self.num_workers = params['num_workers']
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
if params['mode'] != 'test':
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.char_ops = params['char_ops']
self.image_shape = params['image_shape']
self.loss_type = params['loss_type']
Expand All @@ -164,29 +166,34 @@ def __call__(self, process_id):

def sample_iter_reader():
if self.mode == 'test':
print("infer_img:", self.infer_img)
img = cv2.imread(self.infer_img)
norm_img = process_image(img, self.image_shape)
yield norm_img
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
substr = label_infor.decode('utf-8').strip("\n").split("\t")
img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path)
if img is None:
continue
label = substr[1]
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
yield norm_img
else:
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
substr = label_infor.decode('utf-8').strip("\n").split("\t")
img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path)
if img is None:
logger.info("{} does not exist!".format(img_path))
continue
label = substr[1]
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs

def batch_iter_reader():
batch_outs = []
Expand All @@ -198,4 +205,6 @@ def batch_iter_reader():
if len(batch_outs) != 0:
yield batch_outs

return batch_iter_reader
if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
36 changes: 36 additions & 0 deletions ppocr/utils/ic15_dict.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
0
1
2
3
4
5
6
7
8
9
12 changes: 0 additions & 12 deletions set_env.sh

This file was deleted.

2 changes: 1 addition & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main():
metrics = eval_det_run(exe, config, eval_info_dict, "test")
else:
reader_type = config['Global']['reader_yml']
if "chinese" in reader_type:
if "benchmark" not in reader_type:
eval_reader = reader_main(config=config, mode="eval")
eval_info_dict = {'program': eval_program, \
'reader': eval_reader, \
Expand Down
19 changes: 12 additions & 7 deletions tools/infer_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import multiprocessing
import numpy as np


def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
Expand All @@ -47,7 +46,7 @@ def set_paddle_flags(**kwargs):
from ppocr.utils.save_load import init_model
from ppocr.utils.character import CharacterOps
from ppocr.utils.utility import create_module

from ppocr.utils.utility import get_image_file_list
logger = initial_logger()


Expand Down Expand Up @@ -79,9 +78,15 @@ def main():

init_model(config, eval_prog, exe)

blobs = reader_main(config, 'test')
imgs = next(blobs())
for img in imgs:
blobs = reader_main(config, 'test')()
infer_img = config['TestReader']['infer_img']
infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
print("infer_img:",infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
Expand All @@ -101,8 +106,8 @@ def main():
preds_text = preds_text.reshape(-1)
preds_text = char_ops.decode(preds_text)

print(preds)
print(preds_text)
print("\t index:",preds)
print("\t word :",preds_text)

# save for inference model
target_var = []
Expand Down