Add LayoutLMv2ForRelationExtraction #19120
@@ -14,7 +14,9 @@
# limitations under the License.
""" PyTorch LayoutLMv2 model."""

import copy
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
@@ -33,6 +35,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, torch_int_div
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_detectron2_available,
@@ -61,6 +64,41 @@
]


@dataclass
class RelationExtractionOutput(ModelOutput):
    """
    Class for outputs of [`LayoutLMv2ForRelationExtraction`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Classification (or regression if config.num_labels==1) loss.

Review comment: Reading the code, it's always a classification loss, so this needs to be adapted.

        entities (`list[dict]`):
            List of dictionaries (one per example in the batch). Each dictionary contains 3 keys: `start`, `end` and
            `label`.
        relations (`list[dict]`):
            List of dictionaries (one per example in the batch). Each dictionary contains 4 keys: `start_index`,
            `end_index`, `head` and `tail`.
        pred_relations (`list[dict]`):
            List of dictionaries (one per example in the batch). Each dictionary contains 7 keys: `head`, `head_id`,
            `head_type`, `tail`, `tail_id`, `tail_type` and `type`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    pred_relations: dict = None
    entities: dict = None
    relations: dict = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class LayoutLMv2Embeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""
@@ -511,7 +549,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):

    def _init_weights(self, module):
        """Initialize the weights"""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, (nn.Linear, nn.Bilinear)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -1424,3 +1462,249 @@ def forward(
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class BiaffineAttention(torch.nn.Module):
    """Implements a biaffine attention operator for binary relation classification.

    PyTorch implementation of the biaffine attention operator from "End-to-end neural relation extraction using deep
    biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used as a classifier for binary relation
    classification.

    Args:
        in_features (int):
            The size of the feature dimension of the inputs.
        out_features (int):
            The size of the feature dimension of the output.

    Shape:
        - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of additional
          dimensions.
        - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of additional
          dimensions.
        - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number of additional
          dimensions.

Review comment (on lines +1481 to +1486): No one-letter variables please. Especially not in the documentation which we don't mind being long. Here just say

    """
    def __init__(self, in_features, out_features):
        super(BiaffineAttention, self).__init__()

Review comment (suggested change): Python 2 was dead two years ago...
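The suggestion block itself did not survive extraction; it presumably reduces to the zero-argument form:

```python
# suggested change (reconstruction): Python 3 no longer needs the explicit class/instance
super().__init__()
```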
        self.in_features = in_features
        self.out_features = out_features

        self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
        self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

    def forward(self, x_1, x_2):
        return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))
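Not part of the PR, but a quick shape check of the operator defined above (it assumes the `BiaffineAttention` class is in scope):

```python
import torch

attention = BiaffineAttention(in_features=64, out_features=2)

x_1 = torch.randn(4, 10, 64)  # (batch, num_candidates, in_features)
x_2 = torch.randn(4, 10, 64)

logits = attention(x_1, x_2)
print(logits.shape)  # torch.Size([4, 10, 2])
```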

class LayoutLMv2RelationExtractionDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True)
        projection = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
        )
        self.ffnn_head = copy.deepcopy(projection)
        self.ffnn_tail = copy.deepcopy(projection)
        self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
        self.loss_fct = CrossEntropyLoss()

Review comment: No need to hard-code this here. Just use it when necessary in the forward.
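A sketch of what the reviewer is asking for, with illustrative tensors standing in for the `logits` and `relation_labels` built later in `forward`:

```python
import torch
from torch.nn import CrossEntropyLoss

# Instead of storing self.loss_fct on the module in __init__, instantiate
# (or call) the loss where it is actually used inside forward():
logits = torch.randn(5, 2)                      # (num_candidate_relations, num_classes)
relation_labels = torch.tensor([1, 0, 0, 1, 0])

loss = CrossEntropyLoss()(logits, relation_labels)
```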

    def build_relation(self, relations, entities):

Review comment: This looks like a preprocessing method. It shouldn't be part of the model.
        batch_size = len(relations)
        new_relations = []
        for b in range(batch_size):

Review comment: b -> idx might be clearer.
            if len(entities[b]["start"]) <= 2:
                entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
            all_possible_relations = set(
                [
                    (i, j)
                    for i in range(len(entities[b]["label"]))
                    for j in range(len(entities[b]["label"]))
                    if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
                ]
            )
            if len(all_possible_relations) == 0:
                all_possible_relations = set([(0, 1)])
            positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
            negative_relations = all_possible_relations - positive_relations
            positive_relations = set([i for i in positive_relations if i in all_possible_relations])
            reordered_relations = list(positive_relations) + list(negative_relations)
            relation_per_doc = {"head": [], "tail": [], "label": []}
            relation_per_doc["head"] = [i[0] for i in reordered_relations]
            relation_per_doc["tail"] = [i[1] for i in reordered_relations]
            relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
                len(reordered_relations) - len(positive_relations)
            )
            assert len(relation_per_doc["head"]) != 0

Review comment: No new assert in the code base please. Use a test and raise the appropriate error.
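A sketch of the `assert`-free variant the reviewer asks for (error type and message are illustrative):

```python
# Reviewer's suggestion, sketched: raise an explicit error instead of asserting.
if len(relation_per_doc["head"]) == 0:
    raise ValueError("No candidate relations could be built for this document.")
```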
            new_relations.append(relation_per_doc)
        return new_relations, entities
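To make the candidate construction above concrete, here is a standalone trace of the same set logic on a toy example (in the code above, label `1` entities are candidate heads and label `2` entities candidate tails, i.e. the question and answer entities in XFUND):

```python
entities = {"start": [0, 4, 10], "end": [3, 6, 12], "label": [1, 2, 2]}
relations = {"head": [0], "tail": [1]}  # one annotated (gold) relation

all_possible_relations = {
    (i, j)
    for i in range(len(entities["label"]))
    for j in range(len(entities["label"]))
    if entities["label"][i] == 1 and entities["label"][j] == 2
}
positive_relations = set(zip(relations["head"], relations["tail"])) & all_possible_relations
negative_relations = all_possible_relations - positive_relations

print(sorted(all_possible_relations))  # [(0, 1), (0, 2)]
print(sorted(positive_relations))      # [(0, 1)]
print(sorted(negative_relations))      # [(0, 2)]
```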

    def get_predicted_relations(self, logits, relations, entities):
        pred_relations = []
        for i, pred_label in enumerate(logits.argmax(-1)):
            if pred_label != 1:
                continue
            rel = {}
            rel["head_id"] = relations["head"][i]
            rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
            rel["head_type"] = entities["label"][rel["head_id"]]

            rel["tail_id"] = relations["tail"][i]
            rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
            rel["tail_type"] = entities["label"][rel["tail_id"]]
            rel["type"] = 1
            pred_relations.append(rel)
        return pred_relations
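For reference, a single entry of the returned `pred_relations` list has this shape (values are illustrative, following the code above):

```python
rel = {
    "head_id": 0,
    "head": (0, 3),      # (start, end) indices of the head entity
    "head_type": 1,
    "tail_id": 1,
    "tail": (4, 6),      # (start, end) indices of the tail entity
    "tail_type": 2,
    "type": 1,           # only candidates predicted as class 1 are kept
}
```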

    def forward(self, hidden_states, entities, relations):
        batch_size, max_n_words, context_dim = hidden_states.size()
        device = hidden_states.device
        relations, entities = self.build_relation(relations, entities)
        loss = 0
        all_pred_relations = []
        for b in range(batch_size):

Review comment: Same here, consider b -> idx.

            head_entities = torch.tensor(relations[b]["head"], device=device)
            tail_entities = torch.tensor(relations[b]["tail"], device=device)
            relation_labels = torch.tensor(relations[b]["label"], device=device)
            entities_start_index = torch.tensor(entities[b]["start"], device=device)
            entities_labels = torch.tensor(entities[b]["label"], device=device)

Review comment (on lines +1573 to +1577): It's up to the user to provide tensors as inputs, and on the right device; this should just extract the keys.
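A sketch of the simplification the reviewer suggests (it assumes the caller already provides tensors on the correct device, which is not what the current PR does):

```python
# Simply pull the keys out of the per-example dictionaries instead of
# rebuilding tensors inside the loop.
head_entities = relations[b]["head"]
tail_entities = relations[b]["tail"]
relation_labels = relations[b]["label"]
entities_start_index = entities[b]["start"]
entities_labels = entities[b]["label"]
```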
            head_index = entities_start_index[head_entities]
            head_label = entities_labels[head_entities]
            head_label_repr = self.entity_emb(head_label)

            tail_index = entities_start_index[tail_entities]
            tail_label = entities_labels[tail_entities]
            tail_label_repr = self.entity_emb(tail_label)

            head_repr = torch.cat(
                (hidden_states[b][head_index], head_label_repr),
                dim=-1,
            )

Review comment (on lines +1586 to +1589, suggested change): Fits in one line.

            tail_repr = torch.cat(
                (hidden_states[b][tail_index], tail_label_repr),
                dim=-1,
            )

Review comment (on lines +1590 to +1593): Suggested change (presumably the same one-line collapse as above).
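The suggestion blocks themselves were lost in extraction, but given the comments they presumably just collapse each call onto a single line:

```python
head_repr = torch.cat((hidden_states[b][head_index], head_label_repr), dim=-1)
tail_repr = torch.cat((hidden_states[b][tail_index], tail_label_repr), dim=-1)
```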
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            logits = self.rel_classifier(heads, tails)
            loss += self.loss_fct(logits, relation_labels)
            pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
            all_pred_relations.append(pred_relations)
        return loss, all_pred_relations

@add_start_docstrings(
    """
    LayoutLMv2 Model with a relation extraction head on top for key-value extraction tasks such as
    [XFUND](https://github.com/doc-analysis/XFUND) (a bi-affine attention layer on top).
    """,
    LAYOUTLMV2_START_DOCSTRING,
)
class LayoutLMv2ForRelationExtraction(LayoutLMv2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.layoutlmv2 = LayoutLMv2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.extractor = LayoutLMv2RelationExtractionDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()
    @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RelationExtractionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        image: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        entities: Optional[dict] = None,
        relations: Optional[dict] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        entities (`list[dict]`):
            List of dictionaries (one per example in the batch). Each dictionary contains 3 keys: `start`, `end` and
            `label`.
        relations (`list[dict]`):
            List of dictionaries (one per example in the batch). Each dictionary contains 4 keys: `start_index`,
            `end_index`, `head` and `tail`.

        Returns:

        Example:

        ```python
        >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForRelationExtraction
        >>> from PIL import Image
        >>> from datasets import load_dataset

        >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
        >>> model = LayoutLMv2ForRelationExtraction.from_pretrained("microsoft/layoutlmv2-base-uncased")

        >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
        >>> image_path = dataset["test"][0]["file"]
        >>> image = Image.open(image_path).convert("RGB")
        >>> encoding = processor(image, return_tensors="pt")

        >>> # instantiate relations as empty at inference time
        >>> encoding["entities"] = [{"start": [0, 4], "end": [3, 6], "label": [2, 1]}]
        >>> encoding["relations"] = [{"start_index": [], "end_index": [], "head": [], "tail": []}]

        >>> outputs = model(**encoding)
        >>> predicted_relations = outputs.pred_relations[0]
        ```
        """

Review comment (on lines +1665 to +1666, the `entities`/`relations` lines of the example above): Model can't take lists as inputs as they then won't work with ONNX/distributed etc. This should all be tensors.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlmv2(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        seq_length = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1)
        text_output = outputs[0][:, :seq_length]
        text_output = self.dropout(text_output)

        if entities is None or relations is None:
            raise ValueError(
                "You need to provide entities and relations. Instantiate relations with empty lists at inference time"
            )
        loss, pred_relations = self.extractor(text_output, entities, relations)

        if not return_dict:
            output = (loss, pred_relations, entities, relations) + outputs[2:]
            return output

        return RelationExtractionOutput(
            loss=loss,
            pred_relations=pred_relations,
            entities=entities,
            relations=relations,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Review comment: The result of PyTorch loss functions is a 0-d tensor, so this is not accurate.
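A standalone check of that point (not from the PR):

```python
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(3, 2)
labels = torch.tensor([0, 1, 1])
loss = CrossEntropyLoss()(logits, labels)

print(loss.shape)  # torch.Size([]) -- a 0-d tensor, not shape (1,)
print(loss.dim())  # 0
```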