diff --git a/calamari_ocr/ocr/dataset/data.py b/calamari_ocr/ocr/dataset/data.py
index 53fcab62..2f46fdce 100644
--- a/calamari_ocr/ocr/dataset/data.py
+++ b/calamari_ocr/ocr/dataset/data.py
@@ -1,8 +1,9 @@
 import logging
 import os
-from typing import Type, Optional
+from typing import Callable, Dict, Type, Optional
 
 import tensorflow as tf
+from tfaip.util.tftyping import AnyTensor
 from tfaip.data.data import DataBase
 from tfaip.data.databaseparams import DataPipelineParams
 from tfaip.data.pipeline.datapipeline import DataPipeline
@@ -84,6 +85,11 @@ def _target_layer_specs(self):
             "gt_len": tf.TensorSpec([1], dtype=tf.int32),
         }
 
+    def element_length_fn(self) -> Callable[[Dict[str, AnyTensor]], AnyTensor]:  # builds an accessor; self is not read
+        def img_len(x):  # x: the dict described by the return annotation (str -> AnyTensor)
+            return x["img_len"]  # hand back the "img_len" entry unchanged
+        return img_len  # return the function object itself, not the result of calling it
+
     def create_pipeline(
         self,
         pipeline_params: DataPipelineParams,