Skip to content

Commit

Permalink
fix lateocr bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Sep 28, 2024
1 parent 2b51369 commit fedc397
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ppocr/modeling/backbones/rec_resnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def __init__(
self.eps = eps

def forward(self, x):
if not self.training:
self.export = True
if self.same_pad:
if self.export:
x = pad_same_export(x, self._kernel_size, self._stride, self._dilation)
Expand Down Expand Up @@ -201,6 +203,8 @@ def __init__(
)

def forward(self, x):
if not self.training:
self.export = True
if self.export:
x = pad_same_export(x, self.ksize, self.stride, value=-float("inf"))
else:
Expand Down
2 changes: 2 additions & 0 deletions ppocr/modeling/heads/rec_latexocr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def forward(
mem=None,
seq_len=0,
):
if not self.training:
self.is_export = True
b, n, _, h, talking_heads, collab_heads, has_context = (
*x.shape,
self.heads,
Expand Down
1 change: 1 addition & 0 deletions ppocr/utils/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def dump_infer_config(config, path, logger):
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
del hpi_config["Hpi"]["backend_config"]["tensorrt"]
hpi_config["Hpi"]["selected_backends"]["gpu"] = "paddle_infer"
infer_cfg["Hpi"] = hpi_config["Hpi"]
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {}
Expand Down
6 changes: 5 additions & 1 deletion ppocr/utils/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
metric_score = metric_info["metric"]["acc"]
elif "precision" in metric_info["metric"]:
metric_score = metric_info["metric"]["precision"]
elif "exp_rate" in metric_info["metric"]:
metric_score = metric_info["metric"]["exp_rate"]
else:
raise ValueError("No metric score found.")
train_results["models"]["best"]["score"] = metric_score
Expand All @@ -326,8 +328,10 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
metric_score = metric_info["metric"]["acc"]
elif "precision" in metric_info["metric"]:
metric_score = metric_info["metric"]["precision"]
elif "exp_rate" in metric_info["metric"]:
metric_score = metric_info["metric"]["exp_rate"]
else:
raise ValueError("No metric score found.")
metric_score = 0
train_results["models"][f"last_{1}"]["score"] = metric_score
for tag in save_model_tag:
train_results["models"][f"last_{1}"][tag] = os.path.join(
Expand Down

0 comments on commit fedc397

Please sign in to comment.