Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add class LayoutLMv2ForRelationExtraction #15173

Closed
wants to merge 3 commits into from

Conversation

yuan-wenhua
Copy link

What does this PR do?

Add class LayoutLMv2ForRelationExtraction which in https://github.com/microsoft/unilm/tree/master/layoutlmft

class REDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for butting in, but do you know of any reason to hard code 3 for the num_embeddings parameter (aside from it being the value in unilm)? I came across this PR in prepping to implement a version of this same model for a relation extraction task--so thank you for working on this in the first place! All of that to say, I'm asking for selfish reasons 😅

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fwiw I'm guessing the LayoutXLM authors used 3 because they were dealing with forms that only had 3 semantic entity classes (headers, keys, and values).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes in that case, you can replace it with config.num_labels.

__version__ = "4.16.0.dev0"

__version__ = "4.16.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think you need to change this line.

Comment on lines +817 to +819
class ReOutput(ModelOutput):
"""
Base class for outputs of relation extraction models.
Copy link
Contributor

@NielsRogge NielsRogge Feb 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class ReOutput(ModelOutput):
"""
Base class for outputs of relation extraction models.
class LayoutLMv2RelationExtractionOutput(ModelOutput):
"""
Class for outputs of [`LayoutLMv2ForRelationExtraction`].

Let's give it a clearer name.

Also, this model output class can be placed at the top of modeling_layoutlmv2.py, since it's quite specific to LayoutLMv2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -36,10 +36,12 @@
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
ReOutput,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This model output class can be defined within the modeling file itself.

)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...utils import logging
from .configuration_layoutlmv2 import LayoutLMv2Config
from .re import REDecoder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can include the decoder inside modeling_layoutlmv2.py.

Our philosophy is a single model = a single script.

self.layoutlmv2 = LayoutLMv2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.extractor = REDecoder(config)
# Initialize weights and apply final processing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Initialize weights and apply final processing
# Initialize weights and apply final processing

return self.layoutlmv2.embeddings.word_embeddings

@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=ReOutput, config_class=_CONFIG_FOR_DOC)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@replace_return_docstrings(output_type=ReOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=LayoutLMv2RelationExtractionOutput, config_class=_CONFIG_FOR_DOC)

from torch.nn import CrossEntropyLoss


class BiaffineAttention(torch.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class BiaffineAttention(torch.nn.Module):
class LayoutLMv2BiaffineAttention(torch.nn.Module):

We usually add a model-specific prefix to each class.

Copy link
Contributor

@NielsRogge NielsRogge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! A few comments:

  • the model should be added to the tests (tests/test_modeling_layoutlmv2.py). Normally, if you run make fixup locally, it will automatically complain about the fact that this currently isn't the case.
  • we do have the philosophy of "single model, single file". Hence, you can include the model specific output class, as well as the relation extraction decoder inside modeling_layoutlm_v2.py.

@quasimik
Copy link

I've implemented the requested changes, but I'm not sure how to contribute to this PR from here.

@github-actions
Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Mar 31, 2022
@NielsRogge
Copy link
Contributor

Hi @quasimik,

could you open up a clean PR such that we can add it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants