diff --git a/ppocr/utils/export_model.py b/ppocr/utils/export_model.py
new file mode 100644
index 0000000000..ce4c16e5ce
--- /dev/null
+++ b/ppocr/utils/export_model.py
@@ -0,0 +1,381 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import yaml
+import json
+import copy
+import paddle
+import paddle.nn as nn
+from paddle.jit import to_static
+
+from collections import OrderedDict
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.logging import get_logger
+
+
+def represent_dictionary_order(self, dict_data):
+    return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items())
+
+
+def setup_orderdict():
+    yaml.add_representer(OrderedDict, represent_dictionary_order)
+
+
+def dump_infer_config(config, path, logger):
+    setup_orderdict()
+    infer_cfg = OrderedDict()
+    if config["Global"].get("hpi_config_path", None):
+        with open(config["Global"]["hpi_config_path"], "r") as hpi_config_handle:
+            hpi_config = yaml.safe_load(hpi_config_handle)
+        rec_resize_img_dict = next(
+            (
+                item
+                for item in config["Eval"]["dataset"]["transforms"]
+                if "RecResizeImg" in item
+            ),
+            None,
+        )
+        if rec_resize_img_dict:
+            dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
+            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
+                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
+                    "dynamic_shapes"
+                ]["x"] = [dynamic_shapes for i in range(3)]
+                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
+                    "max_batch_size"
+                ] = 1
+            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
+                hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
+                    "x"
+                ] = [dynamic_shapes for i in range(3)]
+                hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
+        else:
+            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
+                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
+                del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
+            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
+                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
+                del hpi_config["Hpi"]["backend_config"]["tensorrt"]
+        infer_cfg["Hpi"] = hpi_config["Hpi"]
+    if config["Global"].get("pdx_model_name", None):
+        infer_cfg["Global"] = {}
+        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
+
+    infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
+    postprocess = OrderedDict()
+    for k, v in config["PostProcess"].items():
+        postprocess[k] = v
+
+    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
+        tokenizer_file = config["Global"].get("rec_char_dict_path")
+        if tokenizer_file is not None:
+            with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
+                character_dict = json.load(tokenizer_config_handle)
+                postprocess["character_dict"] = character_dict
+    else:
+        if config["Global"].get("character_dict_path") is not None:
+            with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
+                lines = f.readlines()
+                character_dict = [line.strip("\n") for line in lines]
+            postprocess["character_dict"] = character_dict
+
+    infer_cfg["PostProcess"] = postprocess
+
+    with open(path, "w") as f:
+        yaml.dump(
+            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
+        )
+    logger.info("Export inference config file to {}".format(path))
+
+
+def export_single_model(
+    model, arch_config, save_path, logger, input_shape=None, quanter=None
+):
+    if arch_config["algorithm"] == "SRN":
+        max_text_length = arch_config["Head"]["max_text_length"]
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
+            [
+                paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
+                paddle.static.InputSpec(
+                    shape=[None, max_text_length, 1], dtype="int64"
+                ),
+                paddle.static.InputSpec(
+                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
+                ),
+                paddle.static.InputSpec(
+                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
+                ),
+            ],
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "SAR":
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
+            [paddle.static.InputSpec(shape=[None], dtype="float32")],
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
+        other_shape = [
+            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "PREN":
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["model_type"] == "sr":
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "ViTSTR":
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "ABINet":
+        if not input_shape:
+            input_shape = [3, 32, 128]
+        other_shape = [
+            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] in ["SATRN"]:
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "VisionLAN":
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "RobustScanner":
+        max_text_length = arch_config["Head"]["max_text_length"]
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
+            [
+                paddle.static.InputSpec(
+                    shape=[
+                        None,
+                    ],
+                    dtype="float32",
+                ),
+                paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
+            ],
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "CAN":
+        other_shape = [
+            [
+                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
+                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
+                paddle.static.InputSpec(
+                    shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
+                ),
+            ]
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "LaTeXOCR":
+        other_shape = [
+            paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # input_ids
+            paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"),  # bbox
+            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # attention_mask
+            paddle.static.InputSpec(shape=[None, 512], dtype="int64"),  # token_type_ids
+            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"),  # image
+        ]
+        if "Re" in arch_config["Backbone"]["name"]:
+            input_spec.extend(
+                [
+                    paddle.static.InputSpec(
+                        shape=[None, 512, 3], dtype="int64"
+                    ),  # entities
+                    paddle.static.InputSpec(
+                        shape=[None, None, 2], dtype="int64"
+                    ),  # relations
+                ]
+            )
+        if model.backbone.use_visual_backbone is False:
+            input_spec.pop(4)
+        model = to_static(model, input_spec=[input_spec])
+    else:
+        infer_shape = [3, -1, -1]
+        if arch_config["model_type"] == "rec":
+            infer_shape = [3, 32, -1]  # for rec model, H must be 32
+            if (
+                "Transform" in arch_config
+                and arch_config["Transform"] is not None
+                and arch_config["Transform"]["name"] == "TPS"
+            ):
+                logger.info(
+                    "When TPS is used in the network, variable-length input is not supported; the input size must match the training size."
+                )
+                infer_shape[-1] = 100
+        elif arch_config["model_type"] == "table":
+            infer_shape = [3, 488, 488]
+            if arch_config["algorithm"] == "TableMaster":
+                infer_shape = [3, 480, 480]
+            if arch_config["algorithm"] == "SLANet":
+                infer_shape = [3, -1, -1]
+        model = to_static(
+            model,
+            input_spec=[
+                paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
+            ],
+        )
+
+    if (
+        arch_config["model_type"] != "sr"
+        and arch_config["Backbone"]["name"] == "PPLCNetV3"
+    ):
+        # for rep lcnetv3
+        for layer in model.sublayers():
+            if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
+                layer.rep()
+
+    if quanter is None:
+        paddle.jit.save(model, save_path)
+    else:
+        quanter.save_quantized_model(model, save_path)
+    logger.info("Inference model saved to {}".format(save_path))
+    return
"DistillationNRTRLabelDecode": + char_num = char_num - 3 + out_channels_list["CTCLabelDecode"] = char_num + out_channels_list["SARLabelDecode"] = char_num + 2 + out_channels_list["NRTRLabelDecode"] = char_num + 3 + config["Architecture"]["Models"][key]["Head"][ + "out_channels_list" + ] = out_channels_list + else: + config["Architecture"]["Models"][key]["Head"][ + "out_channels" + ] = char_num + # just one final tensor needs to exported for inference + config["Architecture"]["Models"][key]["return_all_feats"] = False + elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head + out_channels_list = {} + char_num = len(getattr(post_process_class, "character")) + if config["PostProcess"]["name"] == "SARLabelDecode": + char_num = char_num - 2 + if config["PostProcess"]["name"] == "NRTRLabelDecode": + char_num = char_num - 3 + out_channels_list["CTCLabelDecode"] = char_num + out_channels_list["SARLabelDecode"] = char_num + 2 + out_channels_list["NRTRLabelDecode"] = char_num + 3 + config["Architecture"]["Head"]["out_channels_list"] = out_channels_list + else: # base rec model + config["Architecture"]["Head"]["out_channels"] = char_num + + # for sr algorithm + if config["Architecture"]["model_type"] == "sr": + config["Architecture"]["Transform"]["infer_mode"] = True + + # for latexocr algorithm + if config["Architecture"].get("algorithm") in ["LaTeXOCR"]: + config["Architecture"]["Backbone"]["is_predict"] = True + config["Architecture"]["Backbone"]["is_export"] = True + config["Architecture"]["Head"]["is_export"] = True + if base_model is not None: + model = base_model + if isinstance(model, paddle.DataParallel): + model = copy.deepcopy(model._layers) + else: + model = copy.deepcopy(model) + else: + model = build_model(config["Architecture"]) + load_model(config, model, model_type=config["Architecture"]["model_type"]) + model.eval() + + if not save_path: + save_path = config["Global"]["save_inference_dir"] + yaml_path = os.path.join(save_path, "inference.yml") + + arch_config = config["Architecture"] + + if ( + arch_config["algorithm"] in ["SVTR", "CPPD"] + and arch_config["Head"]["name"] != "MultiHead" + ): + input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][ + "image_shape" + ] + elif arch_config["algorithm"].lower() == "ABINet".lower(): + rec_rs = [ + c + for c in config["Eval"]["dataset"]["transforms"] + if "ABINetRecResizeImg" in c + ] + input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None + else: + input_shape = None + + if arch_config["algorithm"] in [ + "Distillation", + ]: # distillation model + archs = list(arch_config["Models"].values()) + for idx, name in enumerate(model.model_name_list): + sub_model_save_path = os.path.join(save_path, name, "inference") + export_single_model( + model.model_list[idx], archs[idx], sub_model_save_path, logger + ) + else: + save_path = os.path.join(save_path, "inference") + export_single_model( + model, arch_config, save_path, logger, input_shape=input_shape + ) + dump_infer_config(config, yaml_path, logger) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index c0fdaf530a..f7ee432271 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -20,6 +20,7 @@ import os import pickle import six +import json import paddle @@ -248,6 +249,15 @@ def save_model( if prefix == "best_accuracy": arch.backbone.model.save_pretrained(best_model_path) + save_model_info = kwargs.pop("save_model_info", False) + if save_model_info: + with open(os.path.join(model_path, 
f"{prefix}.info.json"), "w") as f: + json.dump(kwargs, f) + logger.info("Already save model info in {}".format(model_path)) + if prefix != "latest": + done_flag = kwargs.pop("done_flag", False) + update_train_results(config, prefix, save_model_info, done_flag=done_flag) + # save metric and config with open(metric_prefix + ".states", "wb") as f: pickle.dump(kwargs, f, protocol=2) @@ -255,3 +265,80 @@ def save_model( logger.info("save best model is to {}".format(model_prefix)) else: logger.info("save model in {}".format(model_prefix)) + + +def update_train_results(config, prefix, metric_info, done_flag=False, last_num=5): + if paddle.distributed.get_rank() != 0: + return + + assert last_num >= 1 + train_results_path = os.path.join( + config["Global"]["save_model_dir"], "train_results.json" + ) + save_model_tag = ["pdparams", "pdopt", "pdstates"] + save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"] + if os.path.exists(train_results_path): + with open(train_results_path, "r") as fp: + train_results = json.load(fp) + else: + train_results = {} + train_results["model_name"] = config["Global"]["pdx_model_name"] + label_dict_path = os.path.abspath( + config["Global"].get("character_dict_path", "") + ) + if label_dict_path != "": + if not os.path.exists(label_dict_path): + label_dict_path = "" + label_dict_path = label_dict_path + train_results["label_dict"] = label_dict_path + train_results["train_log"] = "train.log" + train_results["visualdl_log"] = "" + train_results["config"] = "config.yaml" + train_results["models"] = {} + for i in range(1, last_num + 1): + train_results["models"][f"last_{i}"] = {} + train_results["models"]["best"] = {} + train_results["done_flag"] = done_flag + if "best" in prefix: + if "acc" in metric_info["metric"]: + metric_score = metric_info["metric"]["acc"] + elif "precision" in metric_info["metric"]: + metric_score = metric_info["metric"]["precision"] + else: + raise ValueError("No metric score found.") + train_results["models"]["best"]["score"] = metric_score + for tag in save_model_tag: + train_results["models"]["best"][tag] = os.path.join( + prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states" + ) + for tag in save_inference_tag: + train_results["models"]["best"][tag] = os.path.join( + prefix, + "inference", + f"inference.{tag}" if tag != "inference_config" else "inference.yml", + ) + else: + for i in range(last_num - 1, 0, -1): + train_results["models"][f"last_{i + 1}"] = train_results["models"][ + f"last_{i}" + ].copy() + if "acc" in metric_info["metric"]: + metric_score = metric_info["metric"]["acc"] + elif "precision" in metric_info["metric"]: + metric_score = metric_info["metric"]["precision"] + else: + raise ValueError("No metric score found.") + train_results["models"][f"last_{1}"]["score"] = metric_score + for tag in save_model_tag: + train_results["models"][f"last_{1}"][tag] = os.path.join( + prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states" + ) + for tag in save_inference_tag: + train_results["models"][f"last_{1}"][tag] = os.path.join( + prefix, + "inference", + f"inference.{tag}" if tag != "inference_config" else "inference.yml", + ) + + with open(train_results_path, "w") as fp: + json.dump(train_results, fp) diff --git a/tools/export_model.py b/tools/export_model.py index 3ffebd4020..da1bcbb3ea 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -21,328 +21,16 @@ import argparse -import yaml -import json -import paddle -from paddle.jit import to_static -from collections 
diff --git a/tools/export_model.py b/tools/export_model.py
index 3ffebd4020..da1bcbb3ea 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -21,328 +21,16 @@
 import argparse
 
-import yaml
-import json
-import paddle
-from paddle.jit import to_static
-from collections import OrderedDict
-from ppocr.modeling.architectures import build_model
-from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import load_model
-from ppocr.utils.logging import get_logger
 
 from tools.program import load_config, merge_config, ArgsParser
-
-
-def export_single_model(
-    model, arch_config, save_path, logger, input_shape=None, quanter=None
-):
-    if arch_config["algorithm"] == "SRN":
-        max_text_length = arch_config["Head"]["max_text_length"]
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, 64, 256], dtype="float32"),
-            [
-                paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
-                paddle.static.InputSpec(
-                    shape=[None, max_text_length, 1], dtype="int64"
-                ),
-                paddle.static.InputSpec(
-                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
-                ),
-                paddle.static.InputSpec(
-                    shape=[None, 8, max_text_length, max_text_length], dtype="int64"
-                ),
-            ],
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "SAR":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
-            [paddle.static.InputSpec(shape=[None], dtype="float32")],
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 48, -1], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["SVTR", "CPPD"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "PREN":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["model_type"] == "sr":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 16, 64], dtype="float32")
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "ViTSTR":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, 224, 224], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "ABINet":
-        if not input_shape:
-            input_shape = [3, 32, 128]
-        other_shape = [
-            paddle.static.InputSpec(shape=[None] + input_shape, dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["SATRN"]:
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 32, 100], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "VisionLAN":
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 64, 256], dtype="float32"),
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "RobustScanner":
-        max_text_length = arch_config["Head"]["max_text_length"]
-        other_shape = [
-            paddle.static.InputSpec(shape=[None, 3, 48, 160], dtype="float32"),
-            [
-                paddle.static.InputSpec(
-                    shape=[
-                        None,
-                    ],
-                    dtype="float32",
-                ),
-                paddle.static.InputSpec(shape=[None, max_text_length], dtype="int64"),
-            ],
-        ]
-        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "CAN":
-        other_shape = [
-            [
-                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
-                paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
-                paddle.static.InputSpec(
-                    shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64"
-                ),
-            ]
-        ]
-        model = to_static(model, input_spec=other_shape)
dtype="float32"), - paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"), - paddle.static.InputSpec( - shape=[None, arch_config["Head"]["max_text_length"]], dtype="int64" - ), - ] - ] - model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "LaTeXOCR": - other_shape = [ - paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"), - ] - model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: - input_spec = [ - paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids - paddle.static.InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox - paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # attention_mask - paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids - paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="int64"), # image - ] - if "Re" in arch_config["Backbone"]["name"]: - input_spec.extend( - [ - paddle.static.InputSpec( - shape=[None, 512, 3], dtype="int64" - ), # entities - paddle.static.InputSpec( - shape=[None, None, 2], dtype="int64" - ), # relations - ] - ) - if model.backbone.use_visual_backbone is False: - input_spec.pop(4) - model = to_static(model, input_spec=[input_spec]) - else: - infer_shape = [3, -1, -1] - if arch_config["model_type"] == "rec": - infer_shape = [3, 32, -1] # for rec model, H must be 32 - if ( - "Transform" in arch_config - and arch_config["Transform"] is not None - and arch_config["Transform"]["name"] == "TPS" - ): - logger.info( - "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" - ) - infer_shape[-1] = 100 - elif arch_config["model_type"] == "table": - infer_shape = [3, 488, 488] - if arch_config["algorithm"] == "TableMaster": - infer_shape = [3, 480, 480] - if arch_config["algorithm"] == "SLANet": - infer_shape = [3, -1, -1] - model = to_static( - model, - input_spec=[ - paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32") - ], - ) - - if ( - arch_config["model_type"] != "sr" - and arch_config["Backbone"]["name"] == "PPLCNetV3" - ): - # for rep lcnetv3 - for layer in model.sublayers(): - if hasattr(layer, "rep") and not getattr(layer, "is_repped"): - layer.rep() - - if quanter is None: - paddle.jit.save(model, save_path) - else: - quanter.save_quantized_model(model, save_path) - logger.info("inference model is saved to {}".format(save_path)) - return - - -def represent_dictionary_order(self, dict_data): - return self.represent_mapping("tag:yaml.org,2002:map", dict_data.items()) - - -def setup_orderdict(): - yaml.add_representer(OrderedDict, represent_dictionary_order) - - -def dump_infer_config(config, path, logger): - setup_orderdict() - infer_cfg = OrderedDict() - - infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]} - postprocess = OrderedDict() - for k, v in config["PostProcess"].items(): - postprocess[k] = v - - if config["Architecture"].get("algorithm") in ["LaTeXOCR"]: - tokenizer_file = config["Global"].get("rec_char_dict_path") - if tokenizer_file is not None: - with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle: - character_dict = json.load(tokenizer_config_handle) - postprocess["character_dict"] = character_dict - else: - if config["Global"].get("character_dict_path") is not None: - with open(config["Global"]["character_dict_path"], encoding="utf-8") as f: - lines = f.readlines() - character_dict = 
-            postprocess["character_dict"] = character_dict
-
-    infer_cfg["PostProcess"] = postprocess
-
-    with open(path, "w") as f:
-        yaml.dump(
-            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
-        )
-    logger.info("Export inference config file to {}".format(os.path.join(path)))
+from ppocr.utils.export_model import export
 
 
 def main():
     FLAGS = ArgsParser().parse_args()
     config = load_config(FLAGS.config)
     config = merge_config(config, FLAGS.opt)
-    logger = get_logger()
-    # build post process
-
-    post_process_class = build_post_process(config["PostProcess"], config["Global"])
-
-    # build model
-    # for rec algorithm
-    if hasattr(post_process_class, "character"):
-        char_num = len(getattr(post_process_class, "character"))
-        if config["Architecture"]["algorithm"] in [
-            "Distillation",
-        ]:  # distillation model
-            for key in config["Architecture"]["Models"]:
-                if (
-                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
-                ):  # multi head
-                    out_channels_list = {}
-                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
-                        char_num = char_num - 2
-                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
-                        char_num = char_num - 3
-                    out_channels_list["CTCLabelDecode"] = char_num
-                    out_channels_list["SARLabelDecode"] = char_num + 2
-                    out_channels_list["NRTRLabelDecode"] = char_num + 3
-                    config["Architecture"]["Models"][key]["Head"][
-                        "out_channels_list"
-                    ] = out_channels_list
-                else:
-                    config["Architecture"]["Models"][key]["Head"][
-                        "out_channels"
-                    ] = char_num
-                # just one final tensor needs to exported for inference
-                config["Architecture"]["Models"][key]["return_all_feats"] = False
-        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
-            out_channels_list = {}
-            char_num = len(getattr(post_process_class, "character"))
-            if config["PostProcess"]["name"] == "SARLabelDecode":
-                char_num = char_num - 2
-            if config["PostProcess"]["name"] == "NRTRLabelDecode":
-                char_num = char_num - 3
-            out_channels_list["CTCLabelDecode"] = char_num
-            out_channels_list["SARLabelDecode"] = char_num + 2
-            out_channels_list["NRTRLabelDecode"] = char_num + 3
-            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
-        else:  # base rec model
-            config["Architecture"]["Head"]["out_channels"] = char_num
-
-    # for sr algorithm
-    if config["Architecture"]["model_type"] == "sr":
-        config["Architecture"]["Transform"]["infer_mode"] = True
-
-    # for latexocr algorithm
-    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
-        config["Architecture"]["Backbone"]["is_predict"] = True
-        config["Architecture"]["Backbone"]["is_export"] = True
-        config["Architecture"]["Head"]["is_export"] = True
-    model = build_model(config["Architecture"])
-    load_model(config, model, model_type=config["Architecture"]["model_type"])
-    model.eval()
-
-    save_path = config["Global"]["save_inference_dir"]
-    yaml_path = os.path.join(save_path, "inference.yml")
-
-    arch_config = config["Architecture"]
-
-    if (
-        arch_config["algorithm"] in ["SVTR", "CPPD"]
-        and arch_config["Head"]["name"] != "MultiHead"
-    ):
-        input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
-            "image_shape"
-        ]
-    elif arch_config["algorithm"].lower() == "ABINet".lower():
-        rec_rs = [
-            c
-            for c in config["Eval"]["dataset"]["transforms"]
-            if "ABINetRecResizeImg" in c
-        ]
-        input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
-    else:
-        input_shape = None
-
-    if arch_config["algorithm"] in [
-        "Distillation",
-    ]:  # distillation model
-        archs = list(arch_config["Models"].values())
-        for idx, name in enumerate(model.model_name_list):
-            sub_model_save_path = os.path.join(save_path, name, "inference")
-            export_single_model(
-                model.model_list[idx], archs[idx], sub_model_save_path, logger
-            )
-    else:
-        save_path = os.path.join(save_path, "inference")
-        export_single_model(
-            model, arch_config, save_path, logger, input_shape=input_shape
-        )
-    dump_infer_config(config, yaml_path, logger)
+    # export model
+    export(config)
 
 
 if __name__ == "__main__":
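
The slimmed-down CLI above is equivalent to the following programmatic call; a minimal sketch, assuming a hypothetical config path and override:

    # Sketch: what tools/export_model.py now does, driven from Python.
    from tools.program import load_config, merge_config
    from ppocr.utils.export_model import export

    config = load_config("configs/det/det_mv3_db.yml")  # hypothetical config
    config = merge_config(config, {"Global.save_inference_dir": "./inference/db"})
    export(config)
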
prefix="iter_epoch_{}".format(epoch), + prefix=prefix, + save_model_info=model_info, best_model_dict=best_model_dict, epoch=epoch, global_step=global_step, + done_flag=epoch == config["Global"]["epoch_num"], ) if log_writer is not None: log_writer.log_model( diff --git a/tools/train.py b/tools/train.py index 0aaa0089f9..3bc9906d00 100755 --- a/tools/train.py +++ b/tools/train.py @@ -166,6 +166,15 @@ def main(config, device, logger, vdl_writer, seed): amp_dtype = config["Global"].get("amp_dtype", "float16") amp_custom_black_list = config["Global"].get("amp_custom_black_list", []) amp_custom_white_list = config["Global"].get("amp_custom_white_list", []) + if os.path.exists( + os.path.join(config["Global"]["save_model_dir"], "train_results.json") + ): + try: + os.remove( + os.path.join(config["Global"]["save_model_dir"], "train_results.json") + ) + except: + pass if use_amp: AMP_RELATED_FLAGS_SETTING = { "FLAGS_max_inplace_grad_add": 8,