Add LayoutLMv2ForRelationExtraction #19120

Closed
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/layoutlmv2.mdx
@@ -307,3 +307,7 @@ print(encoding.keys())
## LayoutLMv2ForQuestionAnswering

[[autodoc]] LayoutLMv2ForQuestionAnswering

## LayoutLMv2ForRelationExtraction

[[autodoc]] LayoutLMv2ForRelationExtraction
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1379,6 +1379,7 @@
"LayoutLMv2ForTokenClassification",
"LayoutLMv2Model",
"LayoutLMv2PreTrainedModel",
"LayoutLMv2ForRelationExtraction",
]
)
_import_structure["models.layoutlmv3"].extend(
@@ -4080,6 +4081,7 @@
from .models.layoutlmv2 import (
LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForRelationExtraction,
LayoutLMv2ForSequenceClassification,
LayoutLMv2ForTokenClassification,
LayoutLMv2Model,
2 changes: 2 additions & 0 deletions src/transformers/models/layoutlmv2/__init__.py
@@ -63,6 +63,7 @@
"LayoutLMv2Layer",
"LayoutLMv2Model",
"LayoutLMv2PreTrainedModel",
"LayoutLMv2ForRelationExtraction",
]

if TYPE_CHECKING:
@@ -95,6 +96,7 @@
from .modeling_layoutlmv2 import (
LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForRelationExtraction,
LayoutLMv2ForSequenceClassification,
LayoutLMv2ForTokenClassification,
LayoutLMv2Layer,
286 changes: 285 additions & 1 deletion src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -14,7 +14,9 @@
# limitations under the License.
""" PyTorch LayoutLMv2 model."""

import copy
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
@@ -33,6 +35,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, torch_int_div
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_detectron2_available,
@@ -61,6 +64,41 @@
]


@dataclass
class RelationExtractionOutput(ModelOutput):
"""
Class for outputs of [`LayoutLMv2ForRelationExtraction`].

Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
Reviewer comment: The result of PyTorch loss functions is a 0d tensor, so this is not accurate.

Classification (or regression if config.num_labels==1) loss.
Reviewer comment: Reading the code, it's always a classification loss, so this needs to be adapted.

entities (`list[dict]`):
List of dictionaries (one per example in the batch). Each dictionary contains 3 keys: `start`, `end` and
`label`.
relations (`list[dict]`):
List of dictionaries (one per example in the batch). Each dictionary contains 4 keys: `start_index`,
`end_index`, `head` and `tail`.
pred_relations (`list[dict]`):
List of dictionaries (one per example in the batch). Each dictionary contains 7 keys: `head`, `head_id`,
`head_type`, `tail`, `tail_id`, `tail_type` and `type`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
"""

loss: Optional[torch.FloatTensor] = None
pred_relations: dict = None
entities: dict = None
relations: dict = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
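
For concreteness, a single entry of `pred_relations` as produced by the decoder's `get_predicted_relations` below would look roughly like this (indices are illustrative; `type` is always 1 in this implementation):

```python
# one predicted relation: token spans for head/tail plus their entity labels
pred_relation = {
    "head_id": 0, "head": (0, 3), "head_type": 1,
    "tail_id": 1, "tail": (4, 6), "tail_type": 2,
    "type": 1,
}
```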


class LayoutLMv2Embeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -511,7 +549,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
"""Initialize the weights"""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, (nn.Linear, nn.Bilinear)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -1424,3 +1462,249 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class BiaffineAttention(torch.nn.Module):
"""Implements a biaffine attention operator for binary relation classification.

PyTorch implementation of the biaffine attention operator from "End-to-end neural relation extraction using deep
biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used as a classifier for binary relation
classification.

Args:
in_features (int):
The size of the feature dimension of the inputs.
out_features (int):
The size of the feature dimension of the output.

Shape:
- x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of additional
dimensions.
- x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of additional
dimensions.
- Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
of additional dimensions.
Reviewer comment on lines +1481 to +1486: No one-letter variables please, especially not in the documentation, which we don't mind being long. Here just say `(batch_size, *, features)` where `*` means any number of additional dimensions.

"""

def __init__(self, in_features, out_features):
super(BiaffineAttention, self).__init__()
Reviewer comment: Python 2 was dead two years ago...

Suggested change:
-        super(BiaffineAttention, self).__init__()
+        super().__init__()


self.in_features = in_features
self.out_features = out_features

self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

def forward(self, x_1, x_2):
return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))
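
As a quick orientation, a minimal smoke test of this operator (not part of the PR; shapes follow the constructor arguments):

```python
import torch

attn = BiaffineAttention(in_features=8, out_features=2)
x_1 = torch.randn(4, 8)  # (batch_size, in_features)
x_2 = torch.randn(4, 8)
logits = attn(x_1, x_2)  # bilinear(x_1, x_2) + linear([x_1; x_2])
print(logits.shape)  # torch.Size([4, 2])
```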


class LayoutLMv2RelationExtractionDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True)
projection = nn.Sequential(
nn.Linear(config.hidden_size * 2, config.hidden_size),
nn.ReLU(),
nn.Dropout(config.hidden_dropout_prob),
nn.Linear(config.hidden_size, config.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.hidden_dropout_prob),
)
self.ffnn_head = copy.deepcopy(projection)
self.ffnn_tail = copy.deepcopy(projection)
self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
self.loss_fct = CrossEntropyLoss()
Reviewer comment: No need to hard-code this here. Just use it when necessary in the forward (see the sketch below).
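
A sketch of the inline use the comment suggests, instantiating the loss where it is needed rather than storing it on the module:

```python
# hypothetical: inside forward(), instead of self.loss_fct
loss_fct = CrossEntropyLoss()
loss += loss_fct(logits, relation_labels)
```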


def build_relation(self, relations, entities):
Reviewer comment: This looks like a preprocessing method. It shouldn't be part of the model.

batch_size = len(relations)
new_relations = []
for b in range(batch_size):
Reviewer comment: `b` -> `idx` might be clearer.

if len(entities[b]["start"]) <= 2:
entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
all_possible_relations = set(
[
(i, j)
for i in range(len(entities[b]["label"]))
for j in range(len(entities[b]["label"]))
if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
]
)
if len(all_possible_relations) == 0:
all_possible_relations = set([(0, 1)])
positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
negative_relations = all_possible_relations - positive_relations
positive_relations = set([i for i in positive_relations if i in all_possible_relations])
reordered_relations = list(positive_relations) + list(negative_relations)
relation_per_doc = {"head": [], "tail": [], "label": []}
relation_per_doc["head"] = [i[0] for i in reordered_relations]
relation_per_doc["tail"] = [i[1] for i in reordered_relations]
relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
len(reordered_relations) - len(positive_relations)
)
assert len(relation_per_doc["head"]) != 0
Reviewer comment: No new assert in the code base please. Use a test and raise the appropriate error (see the sketch after this method).

new_relations.append(relation_per_doc)
return new_relations, entities
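
A hedged sketch of the requested change (the exact exception message is an assumption):

```python
# hypothetical replacement for the assert above
if len(relation_per_doc["head"]) == 0:
    raise ValueError("Each document must contain at least one candidate relation.")
```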

def get_predicted_relations(self, logits, relations, entities):
pred_relations = []
for i, pred_label in enumerate(logits.argmax(-1)):
if pred_label != 1:
continue
rel = {}
rel["head_id"] = relations["head"][i]
rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
rel["head_type"] = entities["label"][rel["head_id"]]

rel["tail_id"] = relations["tail"][i]
rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
rel["tail_type"] = entities["label"][rel["tail_id"]]
rel["type"] = 1
pred_relations.append(rel)
return pred_relations
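
For orientation, a minimal example of the `entities`/`relations` structures these helpers consume; under the convention in `build_relation`, head entities carry label 1 and tail entities label 2 (values illustrative, and only the keys used by `build_relation` are shown):

```python
# one example in the batch: two entities (a head and a tail) and one gold relation
entities = [{"start": [0, 4], "end": [3, 6], "label": [1, 2]}]
relations = [{"head": [0], "tail": [1]}]
```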

def forward(self, hidden_states, entities, relations):
batch_size, max_n_words, context_dim = hidden_states.size()
device = hidden_states.device
relations, entities = self.build_relation(relations, entities)
loss = 0
all_pred_relations = []
for b in range(batch_size):
Reviewer comment: Same here, consider `b` -> `idx`.

head_entities = torch.tensor(relations[b]["head"], device=device)
tail_entities = torch.tensor(relations[b]["tail"], device=device)
relation_labels = torch.tensor(relations[b]["label"], device=device)
entities_start_index = torch.tensor(entities[b]["start"], device=device)
entities_labels = torch.tensor(entities[b]["label"], device=device)
Reviewer comment on lines +1573 to +1577: It's up to the user to provide tensors as inputs, and on the right device; this should just extract the keys (see the sketch after this method).

head_index = entities_start_index[head_entities]
head_label = entities_labels[head_entities]
head_label_repr = self.entity_emb(head_label)

tail_index = entities_start_index[tail_entities]
tail_label = entities_labels[tail_entities]
tail_label_repr = self.entity_emb(tail_label)

head_repr = torch.cat(
(hidden_states[b][head_index], head_label_repr),
dim=-1,
)
Reviewer comment on lines +1586 to +1589: Fits in one line.

Suggested change:
-        head_repr = torch.cat(
-            (hidden_states[b][head_index], head_label_repr),
-            dim=-1,
-        )
+        head_repr = torch.cat((hidden_states[b][head_index], head_label_repr), dim=-1)

tail_repr = torch.cat(
(hidden_states[b][tail_index], tail_label_repr),
dim=-1,
)
Reviewer comment on lines +1590 to +1593:

Suggested change:
-        tail_repr = torch.cat(
-            (hidden_states[b][tail_index], tail_label_repr),
-            dim=-1,
-        )
+        tail_repr = torch.cat((hidden_states[b][tail_index], tail_label_repr), dim=-1)

heads = self.ffnn_head(head_repr)
tails = self.ffnn_tail(tail_repr)
logits = self.rel_classifier(heads, tails)
loss += self.loss_fct(logits, relation_labels)
pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
all_pred_relations.append(pred_relations)
return loss, all_pred_relations
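
A sketch of the tensor handling the reviewer asks for on lines +1573 to +1577, assuming the caller already supplies tensors on the correct device:

```python
# hypothetical: extract the keys directly instead of rebuilding tensors
head_entities = relations[b]["head"]         # LongTensor of head entity indices
tail_entities = relations[b]["tail"]         # LongTensor of tail entity indices
relation_labels = relations[b]["label"]      # LongTensor of 0/1 relation labels
entities_start_index = entities[b]["start"]  # LongTensor of entity start positions
entities_labels = entities[b]["label"]       # LongTensor of entity labels
```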


@add_start_docstrings(
"""
LayoutLMv2 Model with a relation extraction head on top for key-value extraction tasks such as
[XFUND](https://github.com/doc-analysis/XFUND) (a bi-affine attention layer on top).
""",
LAYOUTLMV2_START_DOCSTRING,
)
class LayoutLMv2ForRelationExtraction(LayoutLMv2PreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.layoutlmv2 = LayoutLMv2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.extractor = LayoutLMv2RelationExtractionDecoder(config)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=RelationExtractionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.LongTensor] = None,
image: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
entities: Optional[dict] = None,
relations: Optional[dict] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
entities (`list[dict]`):
List of dictionaries (one per example in the batch). Each dictionary contains 3 keys: `start`, `end` and
`label`.
relations (`list[dict]`):
List of dictionaries (one per example in the batch). Each dictionary contains 4 keys: `start_index`,
`end_index`, `head` and `tail`.

Returns:

Example:

```python
>>> from transformers import LayoutLMv2Processor, LayoutLMv2ForRelationExtraction
>>> from PIL import Image
>>> from datasets import load_dataset

>>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
>>> model = LayoutLMv2ForRelationExtraction.from_pretrained("microsoft/layoutlmv2-base-uncased")

>>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
>>> image_path = dataset["test"][0]["file"]
>>> image = Image.open(image_path).convert("RGB")
>>> encoding = processor(image, return_tensors="pt")

>>> # instantiate relations as empty at inference time
>>> encoding["entities"] = [{"start": [0, 4], "end": [3, 6], "label": [2, 1]}]
>>> encoding["relations"] = [{"start_index": [], "end_index": [], "head": [], "tail": []}]
Reviewer comment on lines +1665 to +1666: Model can't take lists as inputs, as they then won't work with ONNX/distributed etc. This should all be tensors (see the sketch after this file's diff).


>>> outputs = model(**encoding)
>>> predicted_relations = outputs.pred_relations[0]
```
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.layoutlmv2(
input_ids=input_ids,
bbox=bbox,
image=image,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

seq_length = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1)
text_output = outputs[0][:, :seq_length]
text_output = self.dropout(text_output)

if entities is None or relations is None:
raise ValueError(
"You need to provide entities and relations. Instantiate relations with empty lists at inference time"
)
loss, pred_relations = self.extractor(text_output, entities, relations)

if not return_dict:
output = (loss, pred_relations, entities, relations) + outputs[2:]
return output

return RelationExtractionOutput(
loss=loss,
pred_relations=pred_relations,
entities=entities,
relations=relations,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
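
Following the review comment on lines +1665 to +1666, a hedged sketch of what tensorized `entities`/`relations` inputs might look like (field names kept from the PR; the padding-free single-example layout is an assumption, not part of this implementation):

```python
import torch

# illustrative only: one LongTensor per field instead of Python lists, so the
# inputs can pass through ONNX export and distributed collectives
encoding["entities"] = {
    "start": torch.tensor([[0, 4]]),  # (batch_size, num_entities)
    "end": torch.tensor([[3, 6]]),
    "label": torch.tensor([[2, 1]]),
}
encoding["relations"] = {
    "head": torch.zeros(1, 0, dtype=torch.long),  # empty at inference time
    "tail": torch.zeros(1, 0, dtype=torch.long),
}
```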
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -2628,6 +2628,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class LayoutLMv2ForRelationExtraction(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class LayoutLMv2ForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
