Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LayoutLMForQuestionAnswering model #18407

Merged
merged 20 commits into from
Aug 31, 2022
Merged

Conversation

ankrgyl
Copy link
Contributor

@ankrgyl ankrgyl commented Aug 1, 2022

What does this PR do?

This PR adds a LayoutLMForQuestionAnswering class that follows the implementations of LayoutLMv2ForQuestionAnswering and LayoutLMv3ForQuestionAnswering, so that LayoutLM can be fine-tuned for the question answering task.

Fixes #18380

Before submitting

Who can review?

@Narsil

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 1, 2022

@Narsil I've left a few TODOs -- (1) supporting tensorflow, (2) filling in docs, (3) filling in tests -- which I'll gladly do. I just wanted to post sooner than later to start getting feedback on the approach.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 1, 2022

The documentation is not available anymore as the PR was closed or merged.

@ankrgyl ankrgyl changed the title Add LayoutLMForQuestionAnswering model [WIP] Add LayoutLMForQuestionAnswering model Aug 2, 2022
@Narsil
Copy link
Contributor

Narsil commented Aug 2, 2022

Ok, for this part I will let @NielsRogge comment as I am not the best person to answer how it should be done.

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 3, 2022

@NielsRogge @Narsil gentle nudge on this PR. I plan to fix the tests + write docs as a next step but wanted to get some quick feedback about whether this approach is acceptable for including LayoutLMForQuestionAnswering. Appreciate your consideration!

@@ -2314,6 +2315,7 @@
"TFLayoutLMForMaskedLM",
"TFLayoutLMForSequenceClassification",
"TFLayoutLMForTokenClassification",
# XXX "TFLayoutLMForQuestionAnswering",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be done in a separate PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let's do it in a separate PR. I'll remove these commented out values.

@@ -4525,7 +4528,7 @@
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_layoutlm import (
from .modeling_tf_layoutlm import ( # TODO TFLayoutLMForQuestionAnswering,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here.

@@ -104,7 +107,7 @@
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_layoutlm import (
from .modeling_tf_layoutlm import ( # TODO LayoutLMForQuestionAnswering,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here.

Comment on lines 1276 to 1282
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]
# only take the text part of the output representations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be done for LayoutLMv1, you can just use outputs[0], see LayoutLMForTokenClassification.

Unlike LayoutLMv2 and v3, the first version only forwards text tokens through the Transformer encoder.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

Copy link
Contributor

@NielsRogge NielsRogge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left a first brief review.

I'm OK with supporting LayoutLM for the VQA pipeline, although it's entirely different to ViLT (which is the only model supported by the pipeline for now).

LayoutLM solves it as an extractive task (SQuAD-like), predicting start and end positions, so we can probably borrow a lot from the existing QA pipeline. ViLT on the other hand solves it as a multi-label classification problem.

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 3, 2022

Thanks @NielsRogge!

We're discussing the pipeline part in pull request 18414. Would love your feedback there too!

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 3, 2022

@NielsRogge @Narsil I just updated it to include tests+documentation. If it's okay, I'd like to defer the tensorflow implementation for now (due to some personal lack of familiarity). I am failing a consistency check, however, as a result:

  File "/Users/ankur/projects/transformers/transformers/utils/check_inits.py", line 298, in <module>
    check_all_inits()
  File "/Users/ankur/projects/transformers/transformers/utils/check_inits.py", line 238, in check_all_inits
    raise ValueError("\n\n".join(failures))
ValueError: Problem in src/transformers/models/layoutlm/__init__.py, both halves do not define the same objects.
Differences for tf backend:
  LayoutLMForQuestionAnswering in _import_structure but not in TYPE_HINT.

Could you help me resolve this?

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 5, 2022

@NielsRogge @Narsil, I went ahead and implemented support for TensorFlow and the checks are now passing. Would appreciate a re-review.

@ankrgyl ankrgyl changed the title [WIP] Add LayoutLMForQuestionAnswering model Add LayoutLMForQuestionAnswering model Aug 5, 2022
@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 9, 2022

@NielsRogge gentle nudge on this PR :)

@add_start_docstrings(
"""
LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
[DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
[DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

Comment on lines 1278 to 1281
Example:

In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Example:
In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
Example:
In the example below, we prepare a question + context pair for the LayoutLM model. It will give us
a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased", add_prefix_space=True)
>>> model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")

>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if this shouldn't be nielsr/funsd

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll change

>>> outputs = model(**encoding)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm the code examples work as expected, it would be great to add LayoutLM (v1) to the doc tests. Details here: https://github.com/huggingface/transformers/tree/main/docs#testing-documentation-examples

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = TFLayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")

>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here as above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Contributor

@NielsRogge NielsRogge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR looks almost ready, main comment is about adding the model to the doc tests.

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 10, 2022

Thanks @NielsRogge! I just updated with your comments, added to the list of doc tests, and verified locally that they are (now) passing.

@ankrgyl ankrgyl force-pushed the layoutlmv1-qa branch 2 times, most recently from dd25a0c to bc2090f Compare August 11, 2022 15:57
@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 30, 2022

Thanks @NielsRogge just rebased

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 30, 2022

@NielsRogge I believe all outstanding comments have been addressed. Are we ready to merge this in?

@NielsRogge
Copy link
Contributor

I've pinged @sgugger for a final review, however he's off this week so will be merged next week :)

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very impressive model contribution! Thanks for your contribution, @ankrgyl

Tried it locally and it seems to work very well 👍

@LysandreJik LysandreJik merged commit 5c4c869 into huggingface:main Aug 31, 2022
@ankrgyl
Copy link
Contributor Author

ankrgyl commented Aug 31, 2022

Thank you for merging it in! @LysandreJik or @NielsRogge are you planning to do any sort of announcement? I'm asking because we're going to publicly announce the project we've been working on (https://github.com/impira/docquery) in the next few days, and it would be great to collaborate.

@NielsRogge
Copy link
Contributor

I'd like to communicate on that once the pipeline is merged, because the Space above is using that right?

Also, the doc tests don't seem to pass:

_ [doctest] transformers.models.layoutlm.modeling_layoutlm.LayoutLMForQuestionAnswering.forward _
1328         ...         bbox.append([0] * 4)
1329         >>> encoding["bbox"] = torch.tensor([bbox])
1330 
1331         >>> word_ids = encoding.word_ids(0)
1332         >>> outputs = model(**encoding)
1333         >>> loss = outputs.loss
1334         >>> start_scores = outputs.start_logits
1335         >>> end_scores = outputs.end_logits
1336         >>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
1337         >>> print(" ".join(words[start : end + 1]))
Expected:
    M. Hamann P. Harper, P. Martinez
Got:
    J. S. Wigand

/__w/transformers/transformers/src/transformers/models/layoutlm/modeling_layoutlm.py:1337: DocTestFailure
_ [doctest] transformers.models.layoutlm.modeling_tf_layoutlm.TFLayoutLMForQuestionAnswering.call _
[15](https://github.com/huggingface/transformers/runs/8125145111?check_suite_focus=true#step:9:16)53         ...         bbox.append([0] * 4)
1554         >>> encoding["bbox"] = tf.convert_to_tensor([bbox])
1555 
1556         >>> word_ids = encoding.word_ids(0)
1557         >>> outputs = model(**encoding)
1558         >>> loss = outputs.loss
1559         >>> start_scores = outputs.start_logits
1560         >>> end_scores = outputs.end_logits
1561         >>> start, end = word_ids[tf.math.argmax(start_scores, -1)[0]], word_ids[tf.math.argmax(end_scores, -1)[0]]
1562         >>> print(" ".join(words[start : end + 1]))
Expected:
    M. Hamann P. Harper, P. Martinez
Got:
    <BLANKLINE>

@ydshieh
Copy link
Collaborator

ydshieh commented Sep 1, 2022

Hi @ankrgyl Thanks a lot for adding (TF)LayoutLMForQuestionAnswering !

For the doctest:

  • TFLayoutLMForQuestionAnswering seems to have issue loading the weights for qa_outputs. Could you check if the TF checkpoint in impira/layoutlm-document-qa has weights for this part, or see if you can find what goes wrong? The warning message is

    Some layers of TFLayoutLMForQuestionAnswering were not initialized from the model checkpoint at impira/layoutlm-document- qa and are newly initialized: ['qa_outputs']

    and I actually got some random results for this test.

  • LayoutLMForQuestionAnswering weight loading looks fine, but the output is different from the expected value. Could you take a look here?

Here is how you can run the doctest

First

python utils/prepare_for_doc_test.py src/transformers/utils/doc.py

Then for LayoutLMForQuestionAnswering:

python utils/prepare_for_doc_test.py src/transformers/models/layoutlm/modeling_layoutlm.py
pytest --doctest-modules src/transformers/models/layoutlm/modeling_layoutlm.py -sv --doctest-continue-on-failure

For TFLayoutLMForQuestionAnswering:

python utils/prepare_for_doc_test.py src/transformers/models/layoutlm/modeling_tf_layoutlm.py
pytest --doctest-modules src/transformers/models/layoutlm/modeling_tf_layoutlm.py -sv --doctest-continue-on-failure

Thank you again! If you have trouble on debugging this, let me know :-)

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Sep 1, 2022

Hi @NielsRogge @ydshieh I'm very sorry about that -- what happened is that we've updated the weights on the underlying model and it's returning a different name from the same document (the question itself is slightly ambiguous).

I've confirmed that if I pin the revision in the tests, they pass. I've just submitted #18854 to resolve that.

I'll investigate the weights in impira/layoutlm-document-qa in parallel.

@ankrgyl
Copy link
Contributor Author

ankrgyl commented Sep 1, 2022

I'd like to communicate on that once the pipeline is merged, because the Space above is using that right?

@NielsRogge the Space is indeed using the pipeline (and incorporates Donut too). It makes sense to do the announcement after that lands. We'll still do ours today but simply mention that we are working to upstream changes. Let me know if y'all have any concerns about that.

@ydshieh
Copy link
Collaborator

ydshieh commented Sep 1, 2022

Hi @NielsRogge @ydshieh I'm very sorry about that -- what happened is that we've updated the weights on the underlying model and it's returning a different name from the same document (the question itself is slightly ambiguous).

No problem, thanks for the fix.

I've confirmed that if I pin the revision in the tests, they pass. I've just submitted #18854 to resolve that.

Great!

oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
* Add LayoutLMForQuestionAnswering model

* Fix output

* Remove TF TODOs

* Add test cases

* Add docs

* TF implementation

* Fix PT/TF equivalence

* Fix loss

* make fixup

* Fix up documentation code examples

* Fix up documentation examples + test them

* Remove LayoutLMForQuestionAnswering from the auto mapping

* Docstrings

* Add better docstrings

* Undo whitespace changes

* Update tokenizers in comments

* Fixup code and remove `from_pt=True`

* Fix tests

* Revert some unexpected docstring changes

* Fix tests by overriding _prepare_for_class

Co-authored-by: Ankur Goyal <ankur@impira.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

LayoutLM-based visual question answering model, weights, and pipeline
6 participants