Add class LayoutLMv2ForRelationExtraction #15173

Closed · wants to merge 3 commits
5 changes: 3 additions & 2 deletions src/transformers/__init__.py
@@ -22,8 +22,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).

__version__ = "4.16.0.dev0"

__version__ = "4.16.0"
Contributor comment: Don't think you need to change this line.

from typing import TYPE_CHECKING

@@ -1029,6 +1028,7 @@
"LayoutLMv2ForQuestionAnswering",
"LayoutLMv2ForSequenceClassification",
"LayoutLMv2ForTokenClassification",
"LayoutLMv2ForRelationExtraction",
"LayoutLMv2Model",
"LayoutLMv2PreTrainedModel",
]
@@ -2984,6 +2984,7 @@
LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForSequenceClassification,
LayoutLMv2ForTokenClassification,
LayoutLMv2ForRelationExtraction,
LayoutLMv2Model,
LayoutLMv2PreTrainedModel,
)
40 changes: 39 additions & 1 deletion src/transformers/modeling_outputs.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Dict

import torch

@@ -812,3 +812,41 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class ReOutput(ModelOutput):
"""
Base class for outputs of relation extraction models.
NielsRogge (Contributor) commented on lines +817 to +819, Feb 1, 2022:
Suggested change:
- class ReOutput(ModelOutput):
-     """
-     Base class for outputs of relation extraction models.
+ class LayoutLMv2RelationExtractionOutput(ModelOutput):
+     """
+     Class for outputs of [`LayoutLMv2ForRelationExtraction`].

Let's give it a clearer name.

Also, this model output class can be placed at the top of modeling_layoutlmv2.py, since it's quite specific to LayoutLMv2.
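For illustration, a minimal sketch of the two suggestions applied together (the field set is copied from the `ReOutput` dataclass in this diff; the `ModelOutput` import path is an assumption):

```python
# Sketch only: the renamed output class moved to the top of modeling_layoutlmv2.py,
# combining the two review suggestions above. Field set copied from ReOutput in this
# PR; the relative import path for ModelOutput is assumed.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import torch

from ...modeling_outputs import ModelOutput


@dataclass
class LayoutLMv2RelationExtractionOutput(ModelOutput):
    """Class for outputs of [`LayoutLMv2ForRelationExtraction`]."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    entities: Optional[Dict] = None
    relations: Optional[Dict] = None
    pred_relations: Optional[Dict] = None
```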


Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 2)`):
Classification scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
entities (`Dict`): Dictionary of entities, of the form `{"start": [], "end": [], "label": []}`.
"start"/"end" are the start/end indices of each entity's text tokens among all tokens of the image;
each label is one of [0, 1, 2], corresponding to ["HEADER", "QUESTION", "ANSWER"].
relations (`Dict`): Dictionary of relations, of the form `{"head": [], "tail": [], "start_index": [], "end_index": []}`.
"head"/"tail" are entity indices among all entities of the image;
"start_index"/"end_index" are the min/max of the head and tail entities' "start" and "end" values.
pred_relations (`Dict`): Predicted relations, each of the form `{"head_id": int, "head": tuple(int, int), "head_type": int,
"tail_id": int, "tail": tuple(int, int), "tail_type": int, "type": int}`.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
entities: Optional[Dict] = None
relations: Optional[Dict] = None
pred_relations: Optional[Dict] = None
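For reference, a minimal sketch of what these inputs can look like for a single image with two entities. The values are illustrative only; note that the decoder in this PR indexes them per batch element (`entities[b]`), so each is shown as a list with one dict per image:

```python
# Illustrative values only, matching the format documented above: one image whose
# QUESTION entity (label 1) spans tokens [2, 5) and whose ANSWER entity (label 2)
# spans tokens [5, 9), linked by a single head -> tail relation.
entities = [{"start": [2, 5], "end": [5, 9], "label": [1, 2]}]
relations = [
    {
        "head": [0],  # index of the QUESTION entity
        "tail": [1],  # index of the ANSWER entity
        "start_index": [2],  # min of the two entities' "start" values
        "end_index": [9],  # max of the two entities' "end" values
    }
]
```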
2 changes: 2 additions & 0 deletions src/transformers/models/layoutlmv2/__init__.py
@@ -39,6 +39,7 @@
"LayoutLMv2ForQuestionAnswering",
"LayoutLMv2ForSequenceClassification",
"LayoutLMv2ForTokenClassification",
"LayoutLMv2ForRelationExtraction",
"LayoutLMv2Layer",
"LayoutLMv2Model",
"LayoutLMv2PreTrainedModel",
@@ -61,6 +62,7 @@
LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForSequenceClassification,
LayoutLMv2ForTokenClassification,
LayoutLMv2ForRelationExtraction,
LayoutLMv2Layer,
LayoutLMv2Model,
LayoutLMv2PreTrainedModel,
82 changes: 82 additions & 0 deletions src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -36,10 +36,12 @@
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
ReOutput,
Contributor comment: This model output class can be defined within the modeling file itself.

)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...utils import logging
from .configuration_layoutlmv2 import LayoutLMv2Config
from .re import REDecoder
Contributor comment: You can include the decoder inside modeling_layoutlmv2.py.

Our philosophy is a single model = a single script.



# soft dependency
@@ -1226,6 +1228,86 @@ def forward(
)


class LayoutLMv2ForRelationExtraction(LayoutLMv2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.layoutlmv2 = LayoutLMv2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.extractor = REDecoder(config)
# Initialize weights and apply final processing
Contributor suggested change (replacement text identical to the original; presumably a whitespace-only fix):
# Initialize weights and apply final processing

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv2.embeddings.word_embeddings

@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=ReOutput, config_class=_CONFIG_FOR_DOC)
Contributor suggested change:
- @replace_return_docstrings(output_type=ReOutput, config_class=_CONFIG_FOR_DOC)
+ @replace_return_docstrings(output_type=LayoutLMv2RelationExtractionOutput, config_class=_CONFIG_FOR_DOC)

def forward(
self,
input_ids,
bbox,
entities,
relations,
image=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
):
r"""
entities (`Dict`): Dictionary of entities, of the form `{"start": [], "end": [], "label": []}`.
"start"/"end" are the start/end indices of each entity's text tokens among all tokens of the image;
each label is one of [0, 1, 2], corresponding to ["HEADER", "QUESTION", "ANSWER"].
relations (`Dict`): Dictionary of relations, of the form `{"head": [], "tail": [], "start_index": [], "end_index": []}`.
"head"/"tail" are entity indices among all entities of the image;
"start_index"/"end_index" are the min/max of the head and tail entities' "start" and "end" values.
Returns:

Examples:

```python
>>> from transformers import LayoutLMv2Processor, LayoutLMv2ForRelationExtraction
>>> from PIL import Image
>>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
>>> model = LayoutLMv2ForRelationExtraction.from_pretrained("microsoft/layoutlmv2-base-uncased")

>>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
>>> words = ["hello", "world"]
>>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] # make sure to normalize your bounding boxes

>>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
>>> print("encoding image: ", encoding["image"].shape)
>>> # encoding["relations"] and encoding["entities"] are taken from the result of NER
>>> encoding["relations"] = [{"head": [], "tail": [], "start_index": [], "end_index": []}]
>>> encoding["entities"] = [{"start": [], "end": [], "label": []}]
>>> outputs = model(**encoding)
>>> print(outputs)
```"""

outputs = self.layoutlmv2(
input_ids=input_ids,
bbox=bbox,
image=image,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
)

seq_length = input_ids.size(1)
sequence_output, image_output = outputs[0][:, :seq_length], outputs[0][:, seq_length:]
sequence_output = self.dropout(sequence_output)
loss, pred_relations = self.extractor(sequence_output, entities, relations)

return ReOutput(
loss=loss,
entities=entities,
relations=relations,
pred_relations=pred_relations,
hidden_states=outputs[0],
)


@add_start_docstrings(
"""
LayoutLMv2 Model with a span classification head on top for extractive question-answering tasks such as
154 changes: 154 additions & 0 deletions src/transformers/models/layoutlmv2/re.py
@@ -0,0 +1,154 @@
import copy

import torch
from torch import nn
from torch.nn import CrossEntropyLoss


class BiaffineAttention(torch.nn.Module):
Contributor suggested change:
- class BiaffineAttention(torch.nn.Module):
+ class LayoutLMv2BiaffineAttention(torch.nn.Module):

We usually add a model-specific prefix to each class.

"""Implements a biaffine attention operator for binary relation classification.

PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
as a classifier for binary relation classification.

Args:
in_features (int): The size of the feature dimension of the inputs.
out_features (int): The size of the feature dimension of the output.

Shape:
- x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of
additional dimensions.
- x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
additional dimensions.
- Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
of additional dimensions.

Examples:
>>> batch_size, in_features, out_features = 32, 100, 4
>>> biaffine_attention = BiaffineAttention(in_features, out_features)
>>> x_1 = torch.randn(batch_size, in_features)
>>> x_2 = torch.randn(batch_size, in_features)
>>> output = biaffine_attention(x_1, x_2)
>>> print(output.size())
torch.Size([32, 4])
"""

def __init__(self, in_features, out_features):
super().__init__()

self.in_features = in_features
self.out_features = out_features

self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

self.reset_parameters()

def forward(self, x_1, x_2):
return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))

def reset_parameters(self):
self.bilinear.reset_parameters()
self.linear.reset_parameters()


class REDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True)

Comment:
Sorry for butting in, but do you know of any reason to hard code 3 for the num_embeddings parameter (aside from it being the value in unilm)? I came across this PR in prepping to implement a version of this same model for a relation extraction task--so thank you for working on this in the first place! All of that to say, I'm asking for selfish reasons 😅

Comment:
fwiw I'm guessing the LayoutXLM authors used 3 because they were dealing with forms that only had 3 semantic entity classes (headers, keys, and values).

Contributor comment:
Yes in that case, you can replace it with config.num_labels.
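For concreteness, a minimal sketch of the suggested parameterization (assuming `config.num_labels` is set to the number of semantic entity classes, e.g. 3 for header/question/answer):

```python
# Reviewer's suggestion applied: derive the number of entity-label embeddings from the
# config instead of hard-coding 3 (assumes config.num_labels == number of entity classes).
self.entity_emb = nn.Embedding(config.num_labels, config.hidden_size, scale_grad_by_freq=True)
```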

projection = nn.Sequential(
nn.Linear(config.hidden_size * 2, config.hidden_size),
nn.ReLU(),
nn.Dropout(config.hidden_dropout_prob),
nn.Linear(config.hidden_size, config.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.hidden_dropout_prob),
)
self.ffnn_head = copy.deepcopy(projection)
self.ffnn_tail = copy.deepcopy(projection)
self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
self.loss_fct = CrossEntropyLoss()

def build_relation(self, relations, entities):
batch_size = len(relations)
new_relations = []
for b in range(batch_size):
if len(entities[b]["start"]) <= 2:
entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
all_possible_relations = set(
[
(i, j)
for i in range(len(entities[b]["label"]))
for j in range(len(entities[b]["label"]))
if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
]
)
if len(all_possible_relations) == 0:
all_possible_relations = set([(0, 1)])
positive_relations = set(zip(relations[b]["head"], relations[b]["tail"]))
negative_relations = all_possible_relations - positive_relations
positive_relations = positive_relations & all_possible_relations
reordered_relations = list(positive_relations) + list(negative_relations)
relation_per_doc = {"head": [], "tail": [], "label": []}
relation_per_doc["head"] = [i[0] for i in reordered_relations]
relation_per_doc["tail"] = [i[1] for i in reordered_relations]
relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
len(reordered_relations) - len(positive_relations)
)
assert len(relation_per_doc["head"]) != 0
new_relations.append(relation_per_doc)
return new_relations, entities

def get_predicted_relations(self, logits, relations, entities):
pred_relations = []
for i, pred_label in enumerate(logits.argmax(-1)):
if pred_label != 1:
continue
rel = {}
rel["head_id"] = relations["head"][i]
rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
rel["head_type"] = entities["label"][rel["head_id"]]

rel["tail_id"] = relations["tail"][i]
rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
rel["tail_type"] = entities["label"][rel["tail_id"]]
rel["type"] = 1
pred_relations.append(rel)
return pred_relations

def forward(self, hidden_states, entities, relations):
batch_size, max_n_words, context_dim = hidden_states.size()
device = hidden_states.device
relations, entities = self.build_relation(relations, entities)
loss = 0
all_pred_relations = []
for b in range(batch_size):
head_entities = torch.tensor(relations[b]["head"], device=device)
tail_entities = torch.tensor(relations[b]["tail"], device=device)
relation_labels = torch.tensor(relations[b]["label"], device=device)
entities_start_index = torch.tensor(entities[b]["start"], device=device)
entities_labels = torch.tensor(entities[b]["label"], device=device)
head_index = entities_start_index[head_entities]
head_label = entities_labels[head_entities]
head_label_repr = self.entity_emb(head_label)

tail_index = entities_start_index[tail_entities]
tail_label = entities_labels[tail_entities]
tail_label_repr = self.entity_emb(tail_label)

head_repr = torch.cat(
(hidden_states[b][head_index], head_label_repr),
dim=-1,
)
tail_repr = torch.cat(
(hidden_states[b][tail_index], tail_label_repr),
dim=-1,
)
heads = self.ffnn_head(head_repr)
tails = self.ffnn_tail(tail_repr)
logits = self.rel_classifier(heads, tails)
loss += self.loss_fct(logits, relation_labels)
pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
all_pred_relations.append(pred_relations)
return loss, all_pred_relations
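As a side note, the candidate-pair construction in `build_relation` can be checked with a small standalone snippet (illustrative values only): every (QUESTION, ANSWER) index pair is a candidate, annotated pairs become positives, and the remaining candidates become negatives.

```python
# Standalone walk-through of the pairing logic in REDecoder.build_relation.
# Labels: 0 = HEADER, 1 = QUESTION (candidate head), 2 = ANSWER (candidate tail).
entities = {"start": [0, 4, 9], "end": [4, 9, 12], "label": [0, 1, 2]}
relations = {"head": [1], "tail": [2]}  # one annotated QUESTION -> ANSWER link

all_possible = {
    (i, j)
    for i in range(len(entities["label"]))
    for j in range(len(entities["label"]))
    if entities["label"][i] == 1 and entities["label"][j] == 2
}
positive = set(zip(relations["head"], relations["tail"])) & all_possible
negative = all_possible - positive

print(sorted(positive))  # [(1, 2)] -> trained with label 1
print(sorted(negative))  # []       -> any other QUESTION/ANSWER pair would get label 0
```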
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -1965,6 +1965,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class LayoutLMv2ForRelationExtraction(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class LayoutLMv2Model(metaclass=DummyObject):
_backends = ["torch"]
