Address comments
ankrgyl committed Sep 7, 2022
1 parent d271829 commit d24d2db
Showing 2 changed files with 13 additions and 20 deletions.
1 change: 0 additions & 1 deletion src/transformers/pipelines/__init__.py
@@ -224,7 +224,6 @@
"default": {
"model": {
"pt": ("impira/layoutlm-document-qa", "3a93017")
- }, # TODO Update with custom pipeline removed, just before we land
},
"type": "multimodal",
},
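For context, the default pinned above is what a bare pipeline() call for this task resolves to. A minimal usage sketch (the image path and question are made up for illustration; only the checkpoint and revision come from this diff):

    from transformers import pipeline

    # Resolves to the default checkpoint pinned in the registry above:
    # ("impira/layoutlm-document-qa", revision "3a93017").
    doc_qa = pipeline("document-question-answering")

    # Hypothetical inputs, purely for illustration.
    result = doc_qa(image="invoice.png", question="What is the invoice number?")
    print(result)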
32 changes: 13 additions & 19 deletions src/transformers/pipelines/document_question_answering.py
@@ -120,6 +120,8 @@ def __init__(self, *args, **kwargs):

if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
self.model_type = ModelType.VisionEncoderDecoder
+ if self.model.config.encoder.model_type != "donut-swin":
+ raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
elif self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else:
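The added guard only accepts Donut-style encoders inside a VisionEncoderDecoder model. A quick way to see the field it checks (the checkpoint name is an assumption; any Donut checkpoint would do):

    from transformers import VisionEncoderDecoderConfig

    # Donut checkpoints use a DonutSwin encoder, so the guard above passes.
    cfg = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
    print(cfg.encoder.model_type)  # expected: "donut-swin"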
@@ -240,12 +242,7 @@ def __call__(
inputs = image
return super().__call__(inputs, **kwargs)

- def preprocess(
- self,
- input,
- lang=None,
- tesseract_config="",
- ):
+ def preprocess(self, input, lang=None, tesseract_config=""):
image = None
image_features = {}
if input.get("image", None) is not None:
@@ -342,14 +339,14 @@ def preprocess(
if "boxes" not in tokenizer_kwargs:
bbox = []
for batch_index in range(num_spans):
- for i, s, w in zip(
+ for input_id, sequence_id, word_id in zip(
encoding.input_ids[batch_index],
encoding.sequence_ids(batch_index),
encoding.word_ids(batch_index),
):
- if s == 1:
- bbox.append(boxes[w])
- elif i == self.tokenizer.sep_token_id:
+ if sequence_id == 1:
+ bbox.append(boxes[word_id])
+ elif input_id == self.tokenizer.sep_token_id:
bbox.append([1000] * 4)
else:
bbox.append([0] * 4)
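The renamed loop above assigns one bounding box per token: tokens from the document words (sequence id 1) get their word's box, separator tokens get a [1000] * 4 sentinel, and question/special tokens get [0] * 4. A self-contained sketch of that rule, with made-up token data standing in for a real tokenizer encoding:

    SEP_TOKEN_ID = 102  # hypothetical sep_token_id

    input_ids = [101, 2054, 2003, 102, 7592, 2088, 102]  # [CLS] question ... [SEP] doc words [SEP]
    sequence_ids = [None, 0, 0, None, 1, 1, None]        # 1 marks tokens coming from the document words
    word_ids = [None, 0, 1, None, 0, 1, None]            # index into the word list
    boxes = [[10, 10, 50, 20], [60, 10, 120, 20]]        # one normalized box per document word

    bbox = []
    for input_id, sequence_id, word_id in zip(input_ids, sequence_ids, word_ids):
        if sequence_id == 1:
            bbox.append(boxes[word_id])  # document token -> its word's box
        elif input_id == SEP_TOKEN_ID:
            bbox.append([1000] * 4)      # separator -> sentinel box
        else:
            bbox.append([0] * 4)         # question/special token -> empty box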
@@ -361,12 +358,7 @@

word_ids = [encoding.word_ids(i) for i in range(num_spans)]

- return {
- **encoding,
- "p_mask": p_mask,
- "word_ids": word_ids,
- "words": words,
- }
+ return {**encoding, "p_mask": p_mask, "word_ids": word_ids, "words": words}

def _forward(self, model_inputs):
p_mask = model_inputs.pop("p_mask", None)
@@ -396,8 +388,10 @@ def postprocess(self, model_outputs, top_k=1, **kwargs):
return answers

def postprocess_donut(self, model_outputs, **kwargs):
- # postprocess
sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
+
+ # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
+ # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
ret = {
@@ -431,8 +425,8 @@ def postprocess_extractive_qa(
)

word_ids = model_outputs["word_ids"][0]
- for s, e, score in zip(starts, ends, scores):
- word_start, word_end = word_ids[s], word_ids[e]
+ for start, eend, score in zip(starts, ends, scores):
+ word_start, word_end = word_ids[start], word_ids[eend]
if word_start is not None and word_end is not None:
answers.append(
{
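The postprocess_donut change earlier in this file keeps the cleanup of the generated sequence: eos/pad tokens are stripped, then the count=1 regex removes only the first task start token. A standalone sketch on a made-up DocVQA-style output string (token strings and values are illustrative only):

    import re

    # Made-up Donut-style generation output, for illustration.
    sequence = "<s_docvqa><s_question> what is the total?</s_question><s_answer> $11.00</s_answer></s>"
    eos_token, pad_token = "</s>", "<pad>"

    sequence = sequence.replace(eos_token, "").replace(pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # drop the first task start token
    print(sequence)
    # <s_question> what is the total?</s_question><s_answer> $11.00</s_answer>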
