Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 committed Sep 27, 2021
2 parents 9a44e27 + 1b2ca6e commit 8308f33
Show file tree
Hide file tree
Showing 20 changed files with 1,261 additions and 18 deletions.
106 changes: 106 additions & 0 deletions configs/rec/rec_resnet_stn_bilstm_att.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
Global:
use_gpu: True
epoch_num: 400
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/seed
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: EN_symbol
max_text_length: 100
infer_mode: False
use_space_char: False
eval_filter: True
save_res_path: ./output/rec/predicts_seed.txt


Optimizer:
name: Adadelta
weight_deacy: 0.0
momentum: 0.9
lr:
name: Piecewise
decay_epochs: [4,5,8]
values: [1.0, 0.1, 0.01]
regularizer:
name: 'L2'
factor: 2.0e-05


Architecture:
model_type: seed
algorithm: ASTER
Transform:
name: STN_ON
tps_inputsize: [32, 64]
tps_outputsize: [32, 100]
num_control_points: 20
tps_margins: [0.05,0.05]
stn_activation: none
Backbone:
name: ResNet_ASTER
Head:
name: AsterHead # AttentionHead
sDim: 512
attDim: 512
max_len_labels: 100

Loss:
name: AsterLoss

PostProcess:
name: SEEDLabelDecode

Metric:
name: RecMetric
main_indicator: acc
is_filter: True

Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- Fasttext:
path: "./cc.en.300.bin"
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SEEDLabelEncode: # Class handling label
- SEEDResize:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 256
drop_last: True
num_workers: 6

Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SEEDLabelEncode: # Class handling label
- SEEDResize:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: True
batch_size_per_card: 256
num_workers: 4
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, SEEDResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
Expand Down
42 changes: 39 additions & 3 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(self,
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
self.unknown = "UNKNOWN"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(self,
super(NRTRLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)

def __call__(self, data):
text = data['label']
text = self.encode(text)
Expand All @@ -185,10 +187,12 @@ def __call__(self, data):
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data

def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character


class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

Expand Down Expand Up @@ -337,6 +341,39 @@ def get_beg_end_flag_idx(self, beg_or_end):
return idx


class SEEDLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SEEDLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)

def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
return dict_character

def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text)) + 1 # conclue eos
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
)
data['label'] = np.array(text)
return data


class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

Expand Down Expand Up @@ -416,7 +453,6 @@ def load_char_elem_dict(self, character_dict_path):
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])

for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
Expand Down Expand Up @@ -588,7 +624,7 @@ def __call__(self, data):
data['length'] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]

padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
Expand Down
15 changes: 14 additions & 1 deletion ppocr/data/imaug/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import six
import cv2
import numpy as np
import fasttext


class DecodeImage(object):
Expand Down Expand Up @@ -83,12 +84,13 @@ def __call__(self, data):
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data


class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
Expand Down Expand Up @@ -133,6 +135,17 @@ def __call__(self, data):
return data


class Fasttext(object):
def __init__(self, path="None", **kwargs):
self.fast_model = fasttext.load_model(path)

def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data


class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
Expand Down
26 changes: 25 additions & 1 deletion ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def __call__(self, data):
return data


class SEEDResize(object):
def __init__(self, image_shape, infer_mode=False, **kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode

def __call__(self, data):
img = data['image']
norm_img = resize_no_padding_img(img, self.image_shape)
data['image'] = norm_img
return data


class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
Expand Down Expand Up @@ -109,7 +121,8 @@ def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):

def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
Expand Down Expand Up @@ -175,6 +188,17 @@ def resize_norm_img(img, image_shape):
return padding_im


def resize_no_padding_img(img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image


def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
Expand Down
6 changes: 5 additions & 1 deletion ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@
# table loss
from .table_att_loss import TableAttentionLoss

from .rec_aster_loss import AsterLoss


def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss',
'SARLoss', 'AsterLoss'
]

config = copy.deepcopy(config)
Expand Down
Loading

0 comments on commit 8308f33

Please sign in to comment.