From 2a2bf09292e70d4306a185cbb114347ac4dea373 Mon Sep 17 00:00:00 2001
From: Ankur Goyal
Date: Tue, 6 Sep 2022 07:42:38 -0700
Subject: [PATCH] Address comments

---
 src/transformers/pipelines/__init__.py       |  4 +--
 .../pipelines/document_question_answering.py | 32 ++++++++-----------
 2 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index ca8102d57695fa..e3f9e603b5111d 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -222,9 +222,7 @@
         "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
         "tf": (),
         "default": {
-            "model": {
-                "pt": ("impira/layoutlm-document-qa", "3a93017")
-            },  # TODO Update with custom pipeline removed, just before we land
+            "model": {"pt": ("impira/layoutlm-document-qa", "3a93017")},
         },
         "type": "multimodal",
     },
diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py
index 3329ce2dc48103..b0fe18cb9dd6c2 100644
--- a/src/transformers/pipelines/document_question_answering.py
+++ b/src/transformers/pipelines/document_question_answering.py
@@ -120,6 +120,8 @@ def __init__(self, *args, **kwargs):
 
         if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
             self.model_type = ModelType.VisionEncoderDecoder
+            if self.model.config.encoder.model_type != "donut-swin":
+                raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
         elif self.model.config.__class__.__name__ == "LayoutLMConfig":
             self.model_type = ModelType.LayoutLM
         else:
@@ -240,12 +242,7 @@ def __call__(
             inputs = image
         return super().__call__(inputs, **kwargs)
 
-    def preprocess(
-        self,
-        input,
-        lang=None,
-        tesseract_config="",
-    ):
+    def preprocess(self, input, lang=None, tesseract_config=""):
         image = None
         image_features = {}
         if input.get("image", None) is not None:
@@ -342,14 +339,14 @@ def preprocess(
             if "boxes" not in tokenizer_kwargs:
                 bbox = []
                 for batch_index in range(num_spans):
-                    for i, s, w in zip(
+                    for input_id, sequence_id, word_id in zip(
                         encoding.input_ids[batch_index],
                         encoding.sequence_ids(batch_index),
                         encoding.word_ids(batch_index),
                     ):
-                        if s == 1:
-                            bbox.append(boxes[w])
-                        elif i == self.tokenizer.sep_token_id:
+                        if sequence_id == 1:
+                            bbox.append(boxes[word_id])
+                        elif input_id == self.tokenizer.sep_token_id:
                             bbox.append([1000] * 4)
                         else:
                             bbox.append([0] * 4)
@@ -361,12 +358,7 @@ def preprocess(
 
         word_ids = [encoding.word_ids(i) for i in range(num_spans)]
 
-        return {
-            **encoding,
-            "p_mask": p_mask,
-            "word_ids": word_ids,
-            "words": words,
-        }
+        return {**encoding, "p_mask": p_mask, "word_ids": word_ids, "words": words}
 
     def _forward(self, model_inputs):
         p_mask = model_inputs.pop("p_mask", None)
@@ -396,8 +388,10 @@ def postprocess(self, model_outputs, top_k=1, **kwargs):
         return answers
 
     def postprocess_donut(self, model_outputs, **kwargs):
-        # postprocess
         sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
+
+        # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
+        # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
         sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
         sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
         ret = {
@@ -431,8 +425,8 @@ def postprocess_extractive_qa(
         )
 
         word_ids = model_outputs["word_ids"][0]
-        for s, e, score in zip(starts, ends, scores):
-            word_start, word_end = word_ids[s], word_ids[e]
+        for start, end, score in zip(starts, ends, scores):
+            word_start, word_end = word_ids[start], word_ids[end]
             if word_start is not None and word_end is not None:
                 answers.append(
                     {
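
Note (not part of the patch): a minimal sketch of how the patched pipeline can be
exercised when verifying this change locally. The model name and revision come from
the default registered in `__init__.py` above; the image path and question below are
placeholders.

    # Hypothetical smoke test; "invoice.png" stands in for any document image.
    from transformers import pipeline

    qa = pipeline(
        "document-question-answering",
        model="impira/layoutlm-document-qa",
        revision="3a93017",
    )
    print(qa(image="invoice.png", question="What is the invoice number?"))

With this change, constructing the pipeline from a VisionEncoderDecoder checkpoint
whose encoder is not "donut-swin" now fails fast in __init__ with a ValueError rather
than failing later during Donut-specific postprocessing.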