Add LayoutLMForQuestionAnswering model #18407
Conversation
@Narsil I've left a few TODOs -- (1) supporting TensorFlow, (2) filling in docs, (3) filling in tests -- which I'll gladly do. I just wanted to post sooner rather than later to start getting feedback on the approach.
The documentation is not available anymore as the PR was closed or merged.
Ok, for this part I will let @NielsRogge comment, as I am not the best person to answer how it should be done.
@NielsRogge @Narsil gentle nudge on this PR. I plan to fix the tests + write docs as a next step, but wanted to get some quick feedback on whether this approach is acceptable for inclusion.
src/transformers/__init__.py
Outdated
```diff
@@ -2314,6 +2315,7 @@
     "TFLayoutLMForMaskedLM",
     "TFLayoutLMForSequenceClassification",
     "TFLayoutLMForTokenClassification",
+    # XXX "TFLayoutLMForQuestionAnswering",
```
To be done in a separate PR?
Yes, let's do it in a separate PR. I'll remove these commented-out values.
src/transformers/__init__.py
Outdated
```diff
@@ -4525,7 +4528,7 @@
     )
     from .generation_tf_utils import tf_top_k_top_p_filtering
     from .keras_callbacks import KerasMetricCallback, PushToHubCallback
-    from .modeling_tf_layoutlm import (
+    from .modeling_tf_layoutlm import (  # TODO TFLayoutLMForQuestionAnswering,
```
Same comment here.
```diff
@@ -104,7 +107,7 @@
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_tf_layoutlm import (
+        from .modeling_tf_layoutlm import (  # TODO LayoutLMForQuestionAnswering,
```
Same comment here.
```python
if input_ids is not None:
    input_shape = input_ids.size()
else:
    input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]
# only take the text part of the output representations
```
This shouldn't be done for LayoutLMv1; you can just use outputs[0] (see LayoutLMForTokenClassification).
Unlike LayoutLMv2 and v3, the first version only forwards text tokens through the Transformer encoder, so there is no image part to slice off. A short sketch of the resulting head follows below.
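A minimal sketch of what that simplification implies, assuming a `qa_outputs` linear head like the one used by other `*ForQuestionAnswering` models in the library (the names here are illustrative, not the PR's exact code):

```python
import torch
from torch import nn


class QAHeadSketch(nn.Module):
    """Illustrative span-classification head for LayoutLM v1."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # a single linear layer emits both start and end logits per token
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, sequence_output: torch.Tensor):
        # for v1, outputs[0] already covers only text tokens, so no slicing
        logits = self.qa_outputs(sequence_output)  # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)
```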
Good catch!
I've left a first brief review.
I'm OK with supporting LayoutLM for the VQA pipeline, although it's entirely different from ViLT (which is the only model supported by the pipeline for now).
LayoutLM solves it as an extractive task (SQuAD-like), predicting start and end positions, so we can probably borrow a lot from the existing QA pipeline (see the decoding sketch below). ViLT, on the other hand, solves it as a multi-label classification problem.
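For the extractive route, the post-processing could look roughly like this sketch (a naive argmax decode, assuming the model returns `start_logits`/`end_logits` as in this PR; the real QA pipeline does more careful span scoring):

```python
import torch


def decode_span(start_logits, end_logits, input_ids, tokenizer):
    """Naive extractive-QA decoding: pick the best start/end and detokenize."""
    start = int(torch.argmax(start_logits, dim=-1)[0])
    end = int(torch.argmax(end_logits, dim=-1)[0])
    return tokenizer.decode(input_ids[0, start : end + 1])
```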
Thanks @NielsRogge! We're discussing the pipeline part in PR #18414. Would love your feedback there too!
@NielsRogge @Narsil I just updated it to include tests + documentation. If it's okay, I'd like to defer the TensorFlow implementation for now (due to some personal lack of familiarity). As a result, however, I'm failing a consistency check:
Could you help me resolve this?
@NielsRogge @Narsil, I went ahead and implemented support for TensorFlow, and the checks are now passing. Would appreciate a re-review.
@NielsRogge gentle nudge on this PR :)
```python
@add_start_docstrings(
    """
    LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
    [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
```
Suggested change:
```diff
-    [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
+    [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to
```
Good catch
```
Example:

In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
```
Suggested change:
```diff
-Example:
-
-In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
-a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
+Example:
+
+In the example below, we prepare a question + context pair for the LayoutLM model. It will give us
+a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
```
Good catch
```python
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased", add_prefix_space=True)
>>> model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")

>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
```
Wondering if this shouldn't be `nielsr/funsd`.
I'll change it.
```python
>>> outputs = model(**encoding)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
```
To confirm the code examples work as expected, it would be great to add LayoutLM (v1) to the doc tests. Details here: https://github.com/huggingface/transformers/tree/main/docs#testing-documentation-examples
Done
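For reference, a complete doctest along these lines might look like the sketch below. The checkpoint and dataset names are taken from the excerpts above, while the question text, the dataset's `words` column, and the final decoding step are illustrative assumptions rather than the PR's exact example:

```python
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering

>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased", add_prefix_space=True)
>>> model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")

>>> dataset = load_dataset("nielsr/funsd", split="train")
>>> words = dataset[0]["words"]  # assumes a pre-tokenized "words" column

>>> question = "what's his name?"
>>> encoding = tokenizer(question.split(), words, is_split_into_words=True, return_tensors="pt")

>>> # note: bounding boxes (bbox) are omitted for brevity; LayoutLM defaults them to zeros
>>> outputs = model(**encoding)
>>> start = int(torch.argmax(outputs.start_logits, dim=-1)[0])
>>> end = int(torch.argmax(outputs.end_logits, dim=-1)[0])
>>> answer = tokenizer.decode(encoding["input_ids"][0, start : end + 1])
```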
```python
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = TFLayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")

>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
```
Same comment here as above
fixed
PR looks almost ready, main comment is about adding the model to the doc tests.
Thanks @NielsRogge! I just updated with your comments, added to the list of doc tests, and verified locally that they are (now) passing.
(Force-pushed dd25a0c → bc2090f.)
Thanks @NielsRogge, just rebased.
@NielsRogge I believe all outstanding comments have been addressed. Are we ready to merge this in?
I've pinged @sgugger for a final review; however, he's off this week, so this will be merged next week :)
Very impressive model contribution, thanks @ankrgyl!
Tried it locally and it seems to work very well 👍
Thank you for merging it in! @LysandreJik or @NielsRogge, are you planning to do any sort of announcement? I'm asking because we're going to publicly announce the project we've been working on (https://github.com/impira/docquery) in the next few days, and it would be great to collaborate.
I'd like to communicate on that once the pipeline is merged, because the Space above is using that, right? Also, the doc tests don't seem to pass:
Hi @ankrgyl, thanks a lot for adding `LayoutLMForQuestionAnswering`! For the doctest, here is how you can run it. First:

```bash
python utils/prepare_for_doc_test.py src/transformers/utils/doc.py
```

Then, for `modeling_layoutlm.py`:

```bash
python utils/prepare_for_doc_test.py src/transformers/models/layoutlm/modeling_layoutlm.py
pytest --doctest-modules src/transformers/models/layoutlm/modeling_layoutlm.py -sv --doctest-continue-on-failure
```

For `modeling_tf_layoutlm.py`:

```bash
python utils/prepare_for_doc_test.py src/transformers/models/layoutlm/modeling_tf_layoutlm.py
pytest --doctest-modules src/transformers/models/layoutlm/modeling_tf_layoutlm.py -sv --doctest-continue-on-failure
```

Thank you again! If you have trouble debugging this, let me know :-)
Hi @NielsRogge @ydshieh, I'm very sorry about that -- what happened is that we've updated the weights on the underlying model, and it's returning a different name from the same document (the question itself is slightly ambiguous). I've confirmed that if I pin the revision in the tests, they pass. I've just submitted #18854 to resolve that. I'll investigate the weights in
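For context, pinning a checkpoint revision is done through the `revision` argument of `from_pretrained`; a minimal sketch (the commit SHA below is a placeholder, not the actual pinned revision):

```python
from transformers import LayoutLMForQuestionAnswering

# "0123abcd" is a placeholder commit SHA on the Hub, not the real pin
model = LayoutLMForQuestionAnswering.from_pretrained(
    "microsoft/layoutlm-base-uncased",
    revision="0123abcd",  # resolve the checkpoint at this exact commit
)
```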
@NielsRogge the Space is indeed using the pipeline (and incorporates `LayoutLMForQuestionAnswering`).
No problem, thanks for the fix.
Great!
* Add LayoutLMForQuestionAnswering model
* Fix output
* Remove TF TODOs
* Add test cases
* Add docs
* TF implementation
* Fix PT/TF equivalence
* Fix loss
* make fixup
* Fix up documentation code examples
* Fix up documentation examples + test them
* Remove LayoutLMForQuestionAnswering from the auto mapping
* Docstrings
* Add better docstrings
* Undo whitespace changes
* Update tokenizers in comments
* Fixup code and remove `from_pt=True`
* Fix tests
* Revert some unexpected docstring changes
* Fix tests by overriding _prepare_for_class

Co-authored-by: Ankur Goyal <ankur@impira.com>
What does this PR do?
This PR adds a `LayoutLMForQuestionAnswering` class that follows the implementations of `LayoutLMv2ForQuestionAnswering` and `LayoutLMv3ForQuestionAnswering`, so that `LayoutLM` can be fine-tuned for the question answering task.

Fixes #18380
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Link to it if that's the case: LayoutLM-based visual question answering model, weights, and pipeline #18380
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@Narsil