diff --git a/docs/source/en/main_classes/pipelines.mdx b/docs/source/en/main_classes/pipelines.mdx index e5ee3902028e34..acb4e10c80bbd6 100644 --- a/docs/source/en/main_classes/pipelines.mdx +++ b/docs/source/en/main_classes/pipelines.mdx @@ -446,6 +446,12 @@ Pipelines available for multimodal tasks include the following. - __call__ - all +### DocumentTokenClassificationPipeline + +[[autodoc]] DocumentTokenClassificationPipeline + - __call__ + - all + ### FeatureExtractionPipeline [[autodoc]] FeatureExtractionPipeline diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index b39920151db424..80beec71ba23aa 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -310,6 +310,14 @@ The following auto classes are available for the following multimodal tasks. [[autodoc]] TFAutoModelForDocumentQuestionAnswering +### AutoModelForDocumentTokenClassification + +[[autodoc]] AutoModelForDocumentTokenClassification + +### TFAutoModelForDocumentTokenClassification + +[[autodoc]] TFAutoModelForDocumentTokenClassification + ### AutoModelForVisualQuestionAnswering [[autodoc]] AutoModelForVisualQuestionAnswering diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bf53737e968992..c0b22afb2ebde0 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -481,6 +481,7 @@ "CsvPipelineDataFormat", "DepthEstimationPipeline", "DocumentQuestionAnsweringPipeline", + "DocumentTokenClassificationPipeline", "FeatureExtractionPipeline", "FillMaskPipeline", "ImageClassificationPipeline", @@ -938,6 +939,7 @@ "MODEL_FOR_CTC_MAPPING", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", @@ -970,6 +972,7 @@ "AutoModelForCTC", "AutoModelForDepthEstimation", "AutoModelForDocumentQuestionAnswering", + "AutoModelForDocumentTokenClassification", "AutoModelForImageClassification", "AutoModelForImageSegmentation", "AutoModelForInstanceSegmentation", @@ -2531,6 +2534,7 @@ [ "TF_MODEL_FOR_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "TF_MODEL_FOR_MASKED_LM_MAPPING", @@ -2550,6 +2554,7 @@ "TFAutoModel", "TFAutoModelForCausalLM", "TFAutoModelForDocumentQuestionAnswering", + "TFAutoModelForDocumentTokenClassification", "TFAutoModelForImageClassification", "TFAutoModelForMaskedLM", "TFAutoModelForMultipleChoice", @@ -3796,6 +3801,7 @@ CsvPipelineDataFormat, DepthEstimationPipeline, DocumentQuestionAnsweringPipeline, + DocumentTokenClassificationPipeline, FeatureExtractionPipeline, FillMaskPipeline, ImageClassificationPipeline, @@ -4183,6 +4189,7 @@ MODEL_FOR_CTC_MAPPING, MODEL_FOR_DEPTH_ESTIMATION_MAPPING, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, @@ -4215,6 +4222,7 @@ AutoModelForCTC, AutoModelForDepthEstimation, AutoModelForDocumentQuestionAnswering, + AutoModelForDocumentTokenClassification, AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForInstanceSegmentation, @@ -5497,6 +5505,7 @@ from .models.auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, @@ -5516,6 +5525,7 @@ TFAutoModel, TFAutoModelForCausalLM, TFAutoModelForDocumentQuestionAnswering, + TFAutoModelForDocumentTokenClassification, TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index da8ceb8e7e6258..6683aa2b85e6cc 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -50,6 +50,7 @@ "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", @@ -103,6 +104,7 @@ "AutoModelForVision2Seq", "AutoModelForVisualQuestionAnswering", "AutoModelForDocumentQuestionAnswering", + "AutoModelForDocumentTokenClassification", "AutoModelWithLMHead", "AutoModelForZeroShotObjectDetection", ] @@ -123,6 +125,7 @@ "TF_MODEL_FOR_PRETRAINING_MAPPING", "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", @@ -140,6 +143,7 @@ "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", "TFAutoModelForDocumentQuestionAnswering", + "TFAutoModelForDocumentTokenClassification", "TFAutoModelForQuestionAnswering", "TFAutoModelForSemanticSegmentation", "TFAutoModelForSeq2SeqLM", @@ -208,6 +212,7 @@ MODEL_FOR_CTC_MAPPING, MODEL_FOR_DEPTH_ESTIMATION_MAPPING, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, @@ -240,6 +245,7 @@ AutoModelForCTC, AutoModelForDepthEstimation, AutoModelForDocumentQuestionAnswering, + AutoModelForDocumentTokenClassification, AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForInstanceSegmentation, @@ -273,6 +279,7 @@ from .modeling_tf_auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, @@ -292,6 +299,7 @@ TFAutoModel, TFAutoModelForCausalLM, TFAutoModelForDocumentQuestionAnswering, + TFAutoModelForDocumentTokenClassification, TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4d61c4c972ed05..c1efb96990962d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -716,6 +716,12 @@ ] ) +MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("layoutlmv3", "LayoutLMv3ForTokenClassification"), + ] +) + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping @@ -926,6 +932,9 @@ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES ) +MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES +) MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES @@ -1060,6 +1069,15 @@ class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', ) +class AutoModelForDocumentTokenClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING + +AutoModelForDocumentTokenClassification = auto_class_update( + AutoModelForDocumentTokenClassification, + head_doc="document token classification", + checkpoint_for_example='microsoft/layoutlmv3-base", revision="07c9b08', +) + class AutoModelForTokenClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index c77fba4f66fac6..b95facb6dc2cf3 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -344,6 +344,11 @@ ] ) +TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"), + ] +) TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ @@ -442,6 +447,9 @@ TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES ) +TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES +) TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES ) @@ -561,6 +569,14 @@ class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', ) +class TFAutoModelForDocumentTokenClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING + +TFAutoModelForDocumentTokenClassification = auto_class_update( + TFAutoModelForDocumentTokenClassification, + head_doc="document token classification", + checkpoint_for_example='microsoft/layoutlmv3-base", revision="07c9b08', +) class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 8b06009a4cd14b..61cada6386397c 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -62,6 +62,7 @@ from .conversational import Conversation, ConversationalPipeline from .depth_estimation import DepthEstimationPipeline from .document_question_answering import DocumentQuestionAnsweringPipeline +from .document_token_classification import DocumentTokenClassificationPipeline from .feature_extraction import FeatureExtractionPipeline from .fill_mask import FillMaskPipeline from .image_classification import ImageClassificationPipeline @@ -123,6 +124,7 @@ AutoModelForCausalLM, AutoModelForCTC, AutoModelForDocumentQuestionAnswering, + AutoModelForDocumentTokenClassification, AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForMaskedLM, @@ -240,6 +242,18 @@ }, "type": "multimodal", }, + "document-token-classification": { + "impl": DocumentTokenClassificationPipeline, + "pt": (AutoModelForDocumentTokenClassification,) if is_torch_available() else (), + "tf": (), + "default": { + "model": { + "pt": ("microsoft/layoutlmv3-base", "07c9b08"), + "tf": ("microsoft/layoutlmv3-base", "07c9b08"), + }, + }, + "type": "multimodal", + }, "fill-mask": { "impl": FillMaskPipeline, "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (), diff --git a/src/transformers/pipelines/document_token_classification.py b/src/transformers/pipelines/document_token_classification.py new file mode 100644 index 00000000000000..1b06652a7c8f57 --- /dev/null +++ b/src/transformers/pipelines/document_token_classification.py @@ -0,0 +1,254 @@ +# Copyright 2022 The Loop Team and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List, Optional, Tuple, Union, Dict + +import numpy as np + +from ..utils import ( + ExplicitEnum, + add_end_docstrings, + is_pytesseract_available, + is_torch_available, + is_vision_available, + logging, +) +from .base import PIPELINE_INIT_ARGS, Pipeline, ArgumentHandler, Dataset, types + +if is_vision_available(): + from PIL import Image + + from ..image_utils import load_image + +if is_torch_available(): + import torch + + from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING + +TESSERACT_LOADED = False +if is_pytesseract_available(): + TESSERACT_LOADED = True + import pytesseract + +logger = logging.get_logger(__name__) + + +class ModelType(ExplicitEnum): + LayoutLMv3 = "layoutlmv3" + + +class DocumentTokenClassificationArgumentHandler(ArgumentHandler): + """ + Handles arguments for token classification. + """ + + def __call__(self, inputs: Union[str, List[str]], **kwargs): + + if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0: + return list(inputs) + elif isinstance(inputs, str) or isinstance(inputs, Image.Image) or isinstance(inputs, dict): + return [inputs] + elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType): + return inputs + else: + raise ValueError("At least one input is required.") + return inputs + + +@add_end_docstrings(PIPELINE_INIT_ARGS) +class DocumentTokenClassificationPipeline(Pipeline): + # TODO: Update task_summary docs to include an example with document token classification + """ + Document Token Classification pipeline using any `AutoModelForDocumentTokenClassification`. The inputs/outputs are + similar to the Token Classification pipeline; however, the pipeline takes an image (and optional OCR'd + words/boxes) as input instead of text context. + + This Document Token Classification pipeline can currently be loaded from [`pipeline`] using the following task + identifier: `"document-token-classification"`. + + The models that this pipeline can use are models that have been fine-tuned on a Document Token Classification task. + See the up-to-date list of available models on + [huggingface.co/models](https://huggingface.co/models?filter=document-token-classification). + """ + + def __init__(self, args_parser=DocumentTokenClassificationArgumentHandler(), *args, **kwargs): + super().__init__(*args, **kwargs) + self.check_model_type(MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING) + self.image_processor = self.feature_extractor + self.image_processor.apply_ocr = False + self._args_parser = args_parser + if self.model.config.model_type == "layoutlmv3": + self.model_type = ModelType.LayoutLMv3 + else: + raise ValueError(f"Model type {self.model.config.model_type} is not supported by this pipeline.") + + def _sanitize_parameters( + self, + padding=None, + doc_stride=None, + lang: Optional[str] = None, + tesseract_config: Optional[str] = None, + max_seq_len=None, + **kwargs, + ): + preprocess_params, postprocess_params = {}, {} + if padding is not None: + preprocess_params["padding"] = padding + if doc_stride is not None: + preprocess_params["doc_stride"] = doc_stride + if max_seq_len is not None: + preprocess_params["max_seq_len"] = max_seq_len + if lang is not None: + preprocess_params["lang"] = lang + if tesseract_config is not None: + preprocess_params["tesseract_config"] = tesseract_config + + return preprocess_params, {}, postprocess_params + + def __call__( + self, + inputs: Union["Image.Image", List["Image.Image"], str, Dict, List[dict]], + **kwargs, + ): + """ + Classifies the list of tokens (word_boxes) given a document. A document is defined as an image and an + optional list of (word, box) tuples which represent the text in the document. If the `word_boxes` are not + provided, it will use the Tesseract OCR engine (if available) to extract the words and boxes automatically for + LayoutLM-like models which require them as input. + + You can invoke the pipeline several ways: + + - `pipeline(inputs=image)` + - `pipeline(inputs=[image])` + - `pipeline(inputs={"image": image})` + - `pipeline(inputs={"image": image, "word_boxes": word_boxes})` + - `pipeline(inputs={"image": image, "words": words, "boxes": boxes})` + - `pipeline(inputs=[{"image": image}])` + - `pipeline(inputs=[{"image": image, "word_boxes": word_boxes}])` + - `pipeline(inputs=[{"image": image, "words": words, "boxes": boxes}])` + + Args: + inputs (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image`, :obj:`List[PIL.Image]`, :obj:`Dict`, :obj:`List[Dict]`): + + Return: + A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys: + + - **words** (:obj:`List[str]`) -- The words in the document. + - **boxes** (:obj:`List[List[int]]`) -- The boxes of the words in the document. + - **word_labels** (:obj:`List[str]`) -- The predicted labels for each word. + """ + inputs = self._args_parser(inputs) + output = super().__call__(inputs, **kwargs) + if isinstance(output, list) and len(output) == 1: + return output[0] + return output + + def preprocess(self, input, lang=None, tesseract_config="", **kwargs): + image = None + if isinstance(input, str) or isinstance(input, Image.Image): + image = load_image(input) + input = {"image": image} + elif input.get("image", None) is not None: + image = load_image(input["image"]) + + words, boxes = None, None + self.image_processor.apply_ocr = False + if "words" in input and "boxes" in input: + words = input["words"] + boxes = input["boxes"] + elif "word_boxes" in input: + words = [x[0] for x in input["word_boxes"]] + boxes = [x[1] for x in input["word_boxes"]] + elif image is not None and not TESSERACT_LOADED: + raise ValueError( + "`word_boxes` not supplied and pytesseract not available to run OCR" + ) + else: + self.image_processor.apply_ocr = True + + # first, apply the image processor + features = self.image_processor( + images=image, + return_tensors=self.framework, + **kwargs, + ) + + encoded_inputs = self.tokenizer( + text=words if words is not None else features["words"], + boxes=boxes if boxes is not None else features["boxes"], + return_tensors=self.framework, + **kwargs, + ) + + if self.model_type == ModelType.LayoutLMv3: + image_field = "pixel_values" + else: + raise ValueError(f"Model type {self.model.config.model_type} is not supported by this pipeline.") + encoded_inputs[image_field] = features.pop("pixel_values") + + # Fields that help with post-processing + encoded_inputs["word_ids"] = encoded_inputs.word_ids() + encoded_inputs["words"] = words if words is not None else features["words"] + encoded_inputs["boxes"] = boxes if boxes is not None else features["boxes"] + + return encoded_inputs + + def _forward(self, model_inputs): + word_ids = model_inputs.pop("word_ids", None) + words = model_inputs.pop("words", None) + boxes = model_inputs.pop("boxes", None) + + model_outputs = self.model(**model_inputs) + + model_outputs["word_ids"] = word_ids + model_outputs["words"] = words + model_outputs["boxes"] = boxes + return model_outputs + + def postprocess(self, model_outputs, **kwargs): + model_outputs = dict(model_outputs) + logits = np.asarray(model_outputs.pop("logits", None)) + words = model_outputs["words"] + boxes = model_outputs["boxes"] + + # if first dimension is 1, remove it + if logits.shape[0] == 1: + logits = logits[0] + + # if words is a list of list of strings, get the first one + if isinstance(words, list) and len(words) != 0 and isinstance(words[0], list): + words = words[0] + model_outputs["words"] = words + + if isinstance(boxes, list) and len(boxes) != 0 and isinstance(boxes[0], list): + boxes = boxes[0] + model_outputs["boxes"] = boxes + + token_predictions = logits.argmax(-1) + + word_ids = model_outputs.pop("word_ids", None) + + # Map Token predictions to word predictions + word_predictions = [None] * len(words) + for word_id, token_prediction in zip(word_ids, token_predictions): + if word_id is not None and word_predictions[word_id] is None: + word_predictions[word_id] = token_prediction + elif word_id is not None and word_predictions[word_id] != token_prediction: + # If conflict, we take the first prediction + pass + + word_labels = [self.model.config.id2label[prediction] for prediction in word_predictions] + model_outputs["word_labels"] = word_labels + return model_outputs diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 178a0b5ae6e559..ceaf47706b1220 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -398,6 +398,9 @@ def __init__(self, *args, **kwargs): MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = None +MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING = None + + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None @@ -529,6 +532,12 @@ class AutoModelForDocumentQuestionAnswering(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AutoModelForDocumentTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + class AutoModelForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 624e08b88e9e31..c7cd01e764df32 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -275,6 +275,9 @@ def __init__(self, *args, **kwargs): TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = None +TF_MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING = None + + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None @@ -343,6 +346,12 @@ class TFAutoModelForDocumentQuestionAnswering(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFAutoModelForDocumentTokenClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + class TFAutoModelForImageClassification(metaclass=DummyObject): _backends = ["tf"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 4a44c15b22150e..24f1aef1ebe7ae 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -38,6 +38,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES, @@ -75,6 +76,7 @@ def _generate_supported_model_class_names( "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES, "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, + "document-token-classification": MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES, "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, @@ -753,6 +755,7 @@ def _generate_dummy_input( elif model_class_name in [ *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES), + *get_values(MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING_NAMES), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES), *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES), @@ -774,6 +777,9 @@ def _generate_dummy_input( image_size = model.config.vision_config.image_size elif hasattr(model.config, "encoder"): image_size = model.config.encoder.image_size + elif getattr(model.config, "model_type")=="layoutlmv3": + image_size = getattr(model.config, "input_size") + image_size = (image_size, image_size) else: image_size = (_generate_random_int(), _generate_random_int()) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index c06bd644c6391b..edcc0e425bf22c 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -81,7 +81,7 @@ def get_checkpoint_from_architecture(architecture): try: - module = importlib.import_module(architecture.__module__) + module = importlib.import_module(str(architecture.__module__)) except ImportError: logger.error(f"Ignoring architecture {architecture}") return diff --git a/tests/pipelines/test_pipelines_document_token_classification.py b/tests/pipelines/test_pipelines_document_token_classification.py new file mode 100644 index 00000000000000..76e4b3107a5d69 --- /dev/null +++ b/tests/pipelines/test_pipelines_document_token_classification.py @@ -0,0 +1,201 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING, AutoTokenizer, AutoFeatureExtractor, is_vision_available, AutoConfig, AutoModelForDocumentTokenClassification +from transformers.pipelines import pipeline +from transformers.models.layoutlmv3.image_processing_layoutlmv3 import apply_tesseract as apply_ocr +from transformers.testing_utils import ( + nested_simplify, + require_pytesseract, + require_tf, + require_torch, + require_vision, + require_detectron2, + slow, +) + +from .test_pipelines_common import ANY, PipelineTestCaseMeta + + +if is_vision_available(): + from PIL import Image + + from transformers.image_utils import load_image +else: + + class Image: + @staticmethod + def open(*args, **kwargs): + pass + + def load_image(_): + return None + + +# This is a pinned image from a specific revision of a document question answering space, hosted by HuggingFace, +# so we can expect it to be available. +INVOICE_URL = ( + "https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png" +) + + +@require_torch +@require_vision +class DocumentTokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): + model_mapping = MODEL_FOR_DOCUMENT_TOKEN_CLASSIFICATION_MAPPING + + + @require_pytesseract + @require_vision + def get_test_pipeline(self, model, tokenizer, feature_extractor): + dtc_pipeline = pipeline( + "document-token-classification", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor + ) + + image = INVOICE_URL + word_boxes = list(zip(*apply_ocr(load_image(image), None, ""))) + examples = [ + { + "image": load_image(image), + }, + { + "image": image, + }, + { + "image": image, + "word_boxes": word_boxes, + }, + ] + return dtc_pipeline, examples + + def run_pipeline_test(self, dtc_pipeline, examples): + outputs = dtc_pipeline(examples) + self.assertEqual( + outputs, + [ + {"words": ANY(list), "word_labels": ANY(list), "boxes": ANY(list)} for _ in examples + ] + ) + + @require_torch + @require_pytesseract + def test_small_model_pt(self): + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LayoutLMv3ForTokenClassification") + config_ms= AutoConfig.from_pretrained("microsoft/layoutlmv3-base") + config.update(config_ms.to_dict()) + model = AutoModelForDocumentTokenClassification.from_config(config) + tokenizer = AutoTokenizer.from_pretrained( + "microsoft/layoutlmv3-base", revision="07c9b08", add_prefix_space=True + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "microsoft/layoutlmv3-base", revision="07c9b08" + ) + dtc_pipeline = pipeline("document-token-classification", + model=model, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + ) + image = INVOICE_URL + outputs = dtc_pipeline(inputs=image) + self.assertEqual(len(outputs["words"]), 95) + self.assertEqual(len(outputs["word_labels"]), 95) + self.assertEqual(len(outputs["boxes"]), 95) + self.assertEqual(set(outputs["word_labels"]), set(['LABEL_0', 'LABEL_1'])) + + outputs = dtc_pipeline({"image": image}) + self.assertEqual(len(outputs["words"]), 95) + self.assertEqual(len(outputs["word_labels"]), 95) + self.assertEqual(len(outputs["boxes"]), 95) + self.assertEqual(set(outputs["word_labels"]), set(['LABEL_0', 'LABEL_1'])) + + # No text detected -> empty list + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + outputs = dtc_pipeline(inputs=image) + self.assertEqual(outputs["words"], []) + self.assertEqual(outputs["boxes"], []) + self.assertEqual(outputs["word_labels"], []) + + # We can pass the words and bounding boxes directly + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + words = [] + boxes = [] + outputs = dtc_pipeline({"image":image, "words":words, "boxes":boxes}) + self.assertEqual(outputs["words"], []) + self.assertEqual(outputs["boxes"], []) + self.assertEqual(outputs["word_labels"], []) + + + @slow + @require_torch + @require_pytesseract + @require_vision + def test_large_model_pt_layoutlm(self): + dtc_pipeline = pipeline( + "document-token-classification", + model="Theivaprakasham/layoutlmv3-finetuned-invoice", + ) + image = INVOICE_URL + + outputs = dtc_pipeline(inputs=image) + self.assertEqual(len(outputs["words"]), 95) + self.assertEqual(len(outputs["word_labels"]), 95) + self.assertEqual(len(outputs["boxes"]), 95) + self.assertEqual(set(outputs["word_labels"]), {'B-BILLER_POST_CODE', 'B-BILLER', 'B-GST', 'O', 'B-TOTAL'}) + self.assertEqual(outputs["word_labels"].count("B-BILLER_POST_CODE"), 2) + self.assertEqual(outputs["word_labels"].count("B-BILLER"), 2) + self.assertEqual(outputs["word_labels"].count("B-GST"), 7) + self.assertEqual(outputs["word_labels"].count("O"), 80) + self.assertEqual(outputs["word_labels"].count("B-TOTAL"), 4) + + + outputs = dtc_pipeline({"image": image}) + self.assertEqual(len(outputs["words"]), 95) + self.assertEqual(len(outputs["word_labels"]), 95) + self.assertEqual(len(outputs["boxes"]), 95) + self.assertEqual(set(outputs["word_labels"]), {'B-BILLER_POST_CODE', 'B-BILLER', 'B-GST', 'O', 'B-TOTAL'}) + self.assertEqual(outputs["word_labels"].count("B-BILLER_POST_CODE"), 2) + self.assertEqual(outputs["word_labels"].count("B-BILLER"), 2) + self.assertEqual(outputs["word_labels"].count("B-GST"), 7) + self.assertEqual(outputs["word_labels"].count("O"), 80) + self.assertEqual(outputs["word_labels"].count("B-TOTAL"), 4) + + outputs = dtc_pipeline( + [{"image": image}, {"image": image}] + ) + self.assertEqual(len(outputs[0]["words"]), 95) + self.assertEqual(len(outputs[0]["word_labels"]), 95) + self.assertEqual(len(outputs[0]["boxes"]), 95) + self.assertEqual(set(outputs[0]["word_labels"]), {'B-BILLER_POST_CODE', 'B-BILLER', 'B-GST', 'O', 'B-TOTAL'}) + self.assertEqual(outputs[0]["word_labels"].count("B-BILLER_POST_CODE"), 2) + self.assertEqual(outputs[0]["word_labels"].count("B-BILLER"), 2) + self.assertEqual(outputs[0]["word_labels"].count("B-GST"), 7) + self.assertEqual(outputs[0]["word_labels"].count("O"), 80) + self.assertEqual(outputs[0]["word_labels"].count("B-TOTAL"), 4) + + self.assertEqual(len(outputs[1]["words"]), 95) + self.assertEqual(len(outputs[1]["word_labels"]), 95) + self.assertEqual(len(outputs[1]["boxes"]), 95) + self.assertEqual(set(outputs[1]["word_labels"]), {'B-BILLER_POST_CODE', 'B-BILLER', 'B-GST', 'O', 'B-TOTAL'}) + self.assertEqual(outputs[1]["word_labels"].count("B-BILLER_POST_CODE"), 2) + self.assertEqual(outputs[1]["word_labels"].count("B-BILLER"), 2) + self.assertEqual(outputs[1]["word_labels"].count("B-GST"), 7) + self.assertEqual(outputs[1]["word_labels"].count("O"), 80) + self.assertEqual(outputs[1]["word_labels"].count("B-TOTAL"), 4) + + @require_tf + @unittest.skip("Document Token Classification not implemented in TF") + def test_small_model_tf(self): + pass