Skip to content

Commit

Permalink
merge release/2.6.1 to main (#13523)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored Jul 29, 2024
1 parent 1923008 commit 6c12df4
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 15 deletions.
6 changes: 5 additions & 1 deletion ppocr/modeling/backbones/det_mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppocr.modeling.backbones.rec_hgnet import MeanPool2D

__all__ = ["MobileNetV3"]

Expand Down Expand Up @@ -260,7 +261,10 @@ def forward(self, inputs):
class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
if "npu" in paddle.device.get_device():
self.avg_pool = MeanPool2D(1, 1)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
Expand Down
19 changes: 18 additions & 1 deletion ppocr/modeling/backbones/rec_hgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
ones_ = Constant(value=1.0)


class MeanPool2D(nn.Layer):
    """Global average pooling written as a reshape followed by a mean.

    Everything after the first two dimensions of a 4-D input is flattened and
    averaged, and the result is reshaped to ``[dim0, dim1, w, h]``. Because
    the mean leaves a single value per (dim0, dim1) pair, the final reshape
    can only succeed when ``w * h == 1``.

    Args:
        w: Size of the third dimension of the output.
        h: Size of the fourth dimension of the output.
    """

    def __init__(self, w, h):
        super().__init__()
        self.w = w
        self.h = h

    def forward(self, feat):
        """Return the per-channel mean of ``feat`` with shape ``[n, c, w, h]``."""
        # Four-way unpacking means a non-4-D input fails here, before any op runs.
        n, c, _, _ = feat.shape
        flattened = paddle.reshape(feat, [n, c, -1])
        pooled = paddle.mean(flattened, axis=2)
        return paddle.reshape(pooled, [n, c, self.w, self.h])


class ConvBNAct(nn.Layer):
def __init__(
self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
Expand Down Expand Up @@ -59,7 +73,10 @@ def forward(self, x):
class ESEModule(nn.Layer):
def __init__(self, channels):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
if "npu" in paddle.device.get_device():
self.avg_pool = MeanPool2D(1, 1)
else:
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv = Conv2D(
in_channels=channels,
out_channels=channels,
Expand Down
6 changes: 5 additions & 1 deletion ppocr/modeling/backbones/rec_lcnetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ReLU,
)
from paddle.regularizer import L2Decay
from ppocr.modeling.backbones.rec_hgnet import MeanPool2D

NET_CONFIG_det = {
"blocks2":
Expand Down Expand Up @@ -310,7 +311,10 @@ def _fuse_bn_tensor(self, branch):
class SELayer(nn.Layer):
def __init__(self, channel, reduction=4, lr_mult=1.0):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
if "npu" in paddle.device.get_device():
self.avg_pool = MeanPool2D(1, 1)
else:
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
Expand Down
1 change: 1 addition & 0 deletions ppocr/utils/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def load_model(config, model, optimizer=None, model_type="det"):
pickle.load(f) if six.PY2 else pickle.load(f, encoding="latin1")
)
best_model_dict = states_dict.get("best_model_dict", {})
best_model_dict["acc"] = 0.0
if "epoch" in states_dict:
best_model_dict["start_epoch"] = states_dict["epoch"] + 1
logger.info("resume from {}".format(checkpoints))
Expand Down
31 changes: 21 additions & 10 deletions ppocr/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,27 @@ def _check_image_file(path):

def get_image_file_list(img_file, infer_list=None):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))

if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
if infer_list and not os.path.exists(infer_list):
raise Exception("not found infer list {}".format(infer_list))
if infer_list:
with open(infer_list, "r") as f:
lines = f.readlines()
for line in lines:
image_path = line.strip().split("\t")[0]
image_path = os.path.join(img_file, image_path)
imgs_lists.append(image_path)
else:
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))

img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)

if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
Expand Down
2 changes: 1 addition & 1 deletion ppstructure/kie/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ seqeval
pypandoc
attrdict3
python_docx
paddlenlp>=2.4.1
paddlenlp>=2.5.2
37 changes: 36 additions & 1 deletion tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

import argparse

import yaml
import paddle
from paddle.jit import to_static

from collections import OrderedDict
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
Expand Down Expand Up @@ -201,6 +202,38 @@ def export_single_model(
return


def represent_dictionary_order(self, dict_data):
    """Represent ``dict_data`` as a generic YAML mapping node.

    Written with a ``self`` parameter so it can be used as a representer
    callback: ``self`` must provide ``represent_mapping``. The items view of
    ``dict_data`` (not the mapping itself) is handed to ``represent_mapping``
    together with the standard YAML map tag.

    Args:
        self: Object exposing ``represent_mapping(tag, items)``.
        dict_data: Mapping whose items are to be represented.

    Returns:
        Whatever ``self.represent_mapping`` returns.
    """
    yaml_map_tag = "tag:yaml.org,2002:map"
    entries = dict_data.items()
    return self.represent_mapping(yaml_map_tag, entries)


def setup_orderdict():
    """Register ``represent_dictionary_order`` as the YAML representer for ``OrderedDict``."""
    yaml.add_representer(
        data_type=OrderedDict,
        representer=represent_dictionary_order,
    )


def dump_infer_config(config, path, logger):
    """Write the inference-time configuration to a YAML file.

    The exported document has two top-level sections:

    * ``PreProcess``: ``{"transform_ops": ...}`` taken from
      ``config["Eval"]["dataset"]["transforms"]``.
    * ``PostProcess``: a copy of ``config["PostProcess"]``; when
      ``config["Global"]["character_dict_path"]`` is set, the lines of that
      file are added under ``character_dict``.

    Args:
        config: Loaded configuration; must contain the ``Eval``,
            ``PostProcess`` and ``Global`` entries read above.
        path: Destination path of the YAML file.
        logger: Logger used to report where the file was written.
    """
    setup_orderdict()

    infer_cfg = OrderedDict()
    infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}

    postprocess = OrderedDict(config["PostProcess"].items())

    character_dict_path = config["Global"].get("character_dict_path")
    if character_dict_path is not None:
        with open(character_dict_path, encoding="utf-8") as f:
            # Strip only the newline: other whitespace may be a real character.
            postprocess["character_dict"] = [line.strip("\n") for line in f]

    infer_cfg["PostProcess"] = postprocess

    # allow_unicode=True emits non-ASCII characters verbatim, so the text file
    # must be opened as UTF-8 explicitly instead of relying on the locale's
    # default encoding (which can fail or produce a non-UTF-8 file).
    with open(path, "w", encoding="utf-8") as f:
        yaml.dump(
            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
        )
    logger.info("Export inference config file to {}".format(path))


def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
Expand Down Expand Up @@ -260,6 +293,7 @@ def main():
model.eval()

save_path = config["Global"]["save_inference_dir"]
yaml_path = os.path.join(save_path, "inference.yml")

arch_config = config["Architecture"]

Expand Down Expand Up @@ -294,6 +328,7 @@ def main():
export_single_model(
model, arch_config, save_path, logger, input_shape=input_shape
)
dump_infer_config(config, yaml_path, logger)


if __name__ == "__main__":
Expand Down

0 comments on commit 6c12df4

Please sign in to comment.