-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add class LayoutLMv2ForRelationExtraction #15173
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,7 +13,7 @@ | |||||||||||||
# limitations under the License. | ||||||||||||||
|
||||||||||||||
from dataclasses import dataclass | ||||||||||||||
from typing import Optional, Tuple | ||||||||||||||
from typing import Optional, Tuple, Dict | ||||||||||||||
|
||||||||||||||
import torch | ||||||||||||||
|
||||||||||||||
|
@@ -812,3 +812,41 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): | |||||||||||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None | ||||||||||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None | ||||||||||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None | ||||||||||||||
|
||||||||||||||
@dataclass | ||||||||||||||
class ReOutput(ModelOutput): | ||||||||||||||
""" | ||||||||||||||
Base class for outputs of relation extraction models. | ||||||||||||||
Comment on lines
+817
to
+819
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Let's give it a clearer name. Also, this model output class can be placed at the top of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw, can you use this extensive docstring: https://github.com/R0bk/transformers/blob/9c0e0ba9ccc0d32b795c2c0e0130931b92230292/src/transformers/models/layoutlmv2/outputs_layoutlmv2.py#L26-L74 Thanks! |
||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : | ||||||||||||||
Classification loss. | ||||||||||||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 2)`): | ||||||||||||||
Classification scores (before SoftMax). | ||||||||||||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): | ||||||||||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of | ||||||||||||||
shape `(batch_size, sequence_length, hidden_size)`. | ||||||||||||||
|
||||||||||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | ||||||||||||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | ||||||||||||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | ||||||||||||||
sequence_length)`. | ||||||||||||||
|
||||||||||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | ||||||||||||||
heads. | ||||||||||||||
entities: (Dict), the content of dict is {"start":[], "end":[], "label":[]}. | ||||||||||||||
"start"/"end" correspond to the start/end index of the text token among all tokens of the image. | ||||||||||||||
"label" is in [0, 1, 2], which corresponds to ["HEADER", "QUESTION", "ANSWER"]. | ||||||||||||||
relations: (Dict), the content of dict is {"head":[], "tail":[], "start_index":[], "end_index":[]}. | ||||||||||||||
"head"/"tail" correspond to the entity index among all entities of the image. | ||||||||||||||
"start_index"/"end_index" are the min/max values of the "head" and "tail" entities' "start" and "end". | ||||||||||||||
pred_relations: (Dict), the content of dict is {"head_id": int, "head": tuple(int, int), "head_type":int, "tail": tuple(int, int), | ||||||||||||||
"tail_id":int, "tail_type":int, "type":int} | ||||||||||||||
""" | ||||||||||||||
loss: Optional[torch.FloatTensor] = None | ||||||||||||||
logits: torch.FloatTensor = None | ||||||||||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None | ||||||||||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None | ||||||||||||||
entities: Optional[Dict] = None | ||||||||||||||
relations: Optional[Dict] = None | ||||||||||||||
pred_relations: Optional[Dict] = None |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -36,10 +36,12 @@ | |||||||
QuestionAnsweringModelOutput, | ||||||||
SequenceClassifierOutput, | ||||||||
TokenClassifierOutput, | ||||||||
ReOutput, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This model output class can be defined within the modeling file itself. |
||||||||
) | ||||||||
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward | ||||||||
from ...utils import logging | ||||||||
from .configuration_layoutlmv2 import LayoutLMv2Config | ||||||||
from .re import REDecoder | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can include the decoder inside Our philosophy is a single model = a single script. |
||||||||
|
||||||||
|
||||||||
# soft dependency | ||||||||
|
@@ -1226,6 +1228,86 @@ def forward( | |||||||
) | ||||||||
|
||||||||
|
||||||||
class LayoutLMv2ForRelationExtraction(LayoutLMv2PreTrainedModel): | ||||||||
def __init__(self, config): | ||||||||
super().__init__(config) | ||||||||
self.layoutlmv2 = LayoutLMv2Model(config) | ||||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||||||
self.extractor = REDecoder(config) | ||||||||
# Initialize weights and apply final processing | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
self.post_init() | ||||||||
|
||||||||
def get_input_embeddings(self):
    """Return the word-embedding module held by the `layoutlmv2` backbone's embeddings."""
    backbone_embeddings = self.layoutlmv2.embeddings
    return backbone_embeddings.word_embeddings
|
||||||||
@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) | ||||||||
@replace_return_docstrings(output_type=ReOutput, config_class=_CONFIG_FOR_DOC) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
def forward(
    self,
    input_ids,
    bbox,
    entities,
    relations,
    image=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
):
    r"""
    entities (`List[Dict]`, one dict per document in the batch):
        Each dict has the form `{"start": [], "end": [], "label": []}`. `"start"`/`"end"` are the start/end
        indices of each entity's text tokens among all tokens of the document, and `"label"` is in
        `[0, 1, 2]`, which corresponds to `["HEADER", "QUESTION", "ANSWER"]`.
    relations (`List[Dict]`, one dict per document in the batch):
        Each dict has the form `{"head": [], "tail": [], "start_index": [], "end_index": []}`.
        `"head"`/`"tail"` are entity indices among all entities of the document, and
        `"start_index"`/`"end_index"` are the min/max values of the head and tail entities' `"start"` and
        `"end"`.

    Returns:

    Examples:

    ```python
    >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForRelationExtraction
    >>> from PIL import Image
    >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
    >>> model = LayoutLMv2ForRelationExtraction.from_pretrained("microsoft/layoutlmv2-base-uncased")

    >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    >>> words = ["hello", "world"]
    >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # make sure to normalize your bounding boxes

    >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
    >>> print('encoding image: ', encoding['image'].shape)
    >>> # encoding['relations'] and encoding['entities'] are taken from the result of ner;
    >>> # both are lists holding one dict per document in the batch
    >>> encoding['relations'] = [{"end_index": [], "head": [], "start_index": [], "tail": []}]
    >>> encoding['entities'] = [{"start": [], "end": [], "label": []}]
    >>> outputs = model(**encoding)
    >>> print(outputs)
    ```"""

    outputs = self.layoutlmv2(
        input_ids=input_ids,
        bbox=bbox,
        image=image,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
    )

    # Only the first `seq_length` positions of the backbone output feed the relation head;
    # the remaining positions are not used by this head.
    seq_length = input_ids.size(1)
    sequence_output = self.dropout(outputs[0][:, :seq_length])
    loss, pred_relations = self.extractor(sequence_output, entities, relations)

    # NOTE(review): `hidden_states` receives the single tensor `outputs[0]`, not a per-layer
    # tuple, and `logits` is left unset -- confirm this is what consumers of the output expect.
    return ReOutput(
        loss=loss,
        entities=entities,
        relations=relations,
        pred_relations=pred_relations,
        hidden_states=outputs[0],
    )
|
||||||||
|
||||||||
@add_start_docstrings( | ||||||||
""" | ||||||||
LayoutLMv2 Model with a span classification head on top for extractive question-answering tasks such as | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,154 @@ | ||||||
import copy | ||||||
|
||||||
import torch | ||||||
from torch import nn | ||||||
from torch.nn import CrossEntropyLoss | ||||||
|
||||||
|
||||||
class BiaffineAttention(torch.nn.Module): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
We usually add a model-specific prefix to each class. |
||||||
"""Implements a biaffine attention operator for binary relation classification. | ||||||
|
||||||
PyTorch implementation of the biaffine attention operator from "End-to-end neural relation | ||||||
extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used | ||||||
as a classifier for binary relation classification. | ||||||
|
||||||
Args: | ||||||
in_features (int): The size of the feature dimension of the inputs. | ||||||
out_features (int): The size of the feature dimension of the output. | ||||||
|
||||||
Shape: | ||||||
- x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of | ||||||
additional dimensions. | ||||||
- x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of | ||||||
additional dimensions. | ||||||
- Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number | ||||||
of additional dimensions. | ||||||
|
||||||
Examples: | ||||||
>>> batch_size, in_features, out_features = 32, 100, 4 | ||||||
>>> biaffine_attention = BiaffineAttention(in_features, out_features) | ||||||
>>> x_1 = torch.randn(batch_size, in_features) | ||||||
>>> x_2 = torch.randn(batch_size, in_features) | ||||||
>>> output = biaffine_attention(x_1, x_2) | ||||||
>>> print(output.size()) | ||||||
torch.Size([32, 4]) | ||||||
""" | ||||||
|
||||||
def __init__(self, in_features, out_features):
    """Create the bilinear and linear sub-layers of the biaffine operator.

    Args:
        in_features (int): The size of the feature dimension of each of the two inputs.
        out_features (int): The size of the feature dimension of the output.
    """
    super().__init__()

    self.in_features = in_features
    self.out_features = out_features

    # Bilinear interaction between the two inputs; created without a bias term.
    self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
    # Linear term applied to the concatenation of both inputs, hence 2 * in_features.
    self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

    self.reset_parameters()
|
||||||
def forward(self, x_1, x_2):
    """Return `bilinear(x_1, x_2) + linear(concat(x_1, x_2))`, concatenating on the last dim."""
    bilinear_term = self.bilinear(x_1, x_2)
    pair_features = torch.cat((x_1, x_2), dim=-1)
    linear_term = self.linear(pair_features)
    return bilinear_term + linear_term
|
||||||
def reset_parameters(self):
    """Re-initialize both sub-layers, the bilinear one first and then the linear one."""
    for layer in (self.bilinear, self.linear):
        layer.reset_parameters()
|
||||||
|
||||||
class REDecoder(nn.Module): | ||||||
def __init__(self, config): | ||||||
super().__init__() | ||||||
self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for butting in, but do you know of any reason to hard code 3 for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fwiw I'm guessing the LayoutXLM authors used 3 because they were dealing with forms that only had 3 semantic entity classes (headers, keys, and values). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes in that case, you can replace it with |
||||||
projection = nn.Sequential( | ||||||
nn.Linear(config.hidden_size * 2, config.hidden_size), | ||||||
nn.ReLU(), | ||||||
nn.Dropout(config.hidden_dropout_prob), | ||||||
nn.Linear(config.hidden_size, config.hidden_size // 2), | ||||||
nn.ReLU(), | ||||||
nn.Dropout(config.hidden_dropout_prob), | ||||||
) | ||||||
self.ffnn_head = copy.deepcopy(projection) | ||||||
self.ffnn_tail = copy.deepcopy(projection) | ||||||
self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2) | ||||||
self.loss_fct = CrossEntropyLoss() | ||||||
|
||||||
def build_relation(self, relations, entities):
    """Build the labelled candidate (head, tail) entity pairs for every document.

    For each document, every pair `(i, j)` where entity `i` has label 1 and entity `j`
    has label 2 is a candidate. Candidates that appear in the gold `relations` get
    label 1 and are listed first; the remaining candidates get label 0. Gold pairs
    that are not candidates are dropped.

    Args:
        relations: one dict per document with parallel `"head"` / `"tail"` lists of
            entity indices (the gold relations).
        entities: one dict per document with `"start"`, `"end"` and `"label"` lists.
            NOTE: this list is modified in place -- a document with two or fewer
            entities is replaced by a fixed two-entity placeholder.

    Returns:
        A tuple `(new_relations, entities)`. `new_relations` holds one dict per
        document with parallel `"head"`, `"tail"` and `"label"` lists; `entities`
        is the (possibly modified) input list.
    """
    new_relations = []
    for b in range(len(relations)):
        # Too few entities: substitute a placeholder so entity indices 0 and 1 exist.
        if len(entities[b]["start"]) <= 2:
            entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
        labels = entities[b]["label"]
        candidates = {
            (i, j)
            for i, head_label in enumerate(labels)
            for j, tail_label in enumerate(labels)
            if head_label == 1 and tail_label == 2
        }
        if not candidates:
            # Fall back to a single dummy pair so every document yields at least one row.
            candidates = {(0, 1)}
        gold = set(zip(relations[b]["head"], relations[b]["tail"]))
        negatives = candidates - gold
        positives = {pair for pair in gold if pair in candidates}
        ordered = list(positives) + list(negatives)
        relation_per_doc = {
            "head": [head for head, _ in ordered],
            "tail": [tail for _, tail in ordered],
            "label": [1] * len(positives) + [0] * len(negatives),
        }
        # Internal invariant: `candidates` is never empty, so neither is `ordered`.
        assert len(relation_per_doc["head"]) != 0
        new_relations.append(relation_per_doc)
    return new_relations, entities
|
||||||
def get_predicted_relations(self, logits, relations, entities):
    """Decode the candidate pairs whose predicted class is 1.

    Args:
        logits: per-candidate class scores; the predicted class of each row is the
            argmax over the last dimension.
        relations: dict with `"head"` / `"tail"` lists of entity indices, aligned
            with the rows of `logits`.
        entities: dict with `"start"`, `"end"` and `"label"` lists indexed by
            entity index.

    Returns:
        A list with one dict per candidate predicted as class 1, holding the keys
        `"head_id"`, `"head"`, `"head_type"`, `"tail_id"`, `"tail"`, `"tail_type"`
        and `"type"` (always 1).
    """
    starts, ends, labels = entities["start"], entities["end"], entities["label"]
    pred_relations = []
    # Convert all predictions to Python ints in one go rather than comparing
    # 0-dim tensors one at a time inside the loop.
    for i, pred_label in enumerate(logits.argmax(-1).tolist()):
        if pred_label != 1:
            continue
        head_id = relations["head"][i]
        tail_id = relations["tail"][i]
        pred_relations.append(
            {
                "head_id": head_id,
                "head": (starts[head_id], ends[head_id]),
                "head_type": labels[head_id],
                "tail_id": tail_id,
                "tail": (starts[tail_id], ends[tail_id]),
                "tail_type": labels[tail_id],
                "type": 1,
            }
        )
    return pred_relations
|
||||||
def forward(self, hidden_states, entities, relations):
    """Score candidate relations for each document and accumulate the loss.

    Each entity is represented by the hidden state at its `"start"` token index
    concatenated with the embedding of its label. Head and tail representations go
    through `self.ffnn_head` / `self.ffnn_tail` and the pair is scored by
    `self.rel_classifier`.

    Args:
        hidden_states: tensor indexed as `hidden_states[b][token_index]`; the first
            dimension is the batch.
        entities: one dict per document with `"start"` and `"label"` lists (also
            passed on to `self.build_relation` and `self.get_predicted_relations`).
        relations: one dict per document with the gold relations, as expected by
            `self.build_relation`.

    Returns:
        A tuple `(loss, all_pred_relations)`: `loss` is the sum over documents of
        `self.loss_fct(logits, labels)` (it is not averaged over the batch), and
        `all_pred_relations` holds one list of predicted relations per document.
    """
    device = hidden_states.device
    relations, entities = self.build_relation(relations, entities)
    loss = 0
    all_pred_relations = []
    for b in range(hidden_states.size(0)):
        doc_states = hidden_states[b]
        doc_relations = relations[b]
        doc_entities = entities[b]

        head_entities = torch.tensor(doc_relations["head"], device=device)
        tail_entities = torch.tensor(doc_relations["tail"], device=device)
        relation_labels = torch.tensor(doc_relations["label"], device=device)
        entities_start_index = torch.tensor(doc_entities["start"], device=device)
        entities_labels = torch.tensor(doc_entities["label"], device=device)

        # Entity representation = hidden state of its start token + label embedding.
        head_repr = torch.cat(
            (
                doc_states[entities_start_index[head_entities]],
                self.entity_emb(entities_labels[head_entities]),
            ),
            dim=-1,
        )
        tail_repr = torch.cat(
            (
                doc_states[entities_start_index[tail_entities]],
                self.entity_emb(entities_labels[tail_entities]),
            ),
            dim=-1,
        )

        heads = self.ffnn_head(head_repr)
        tails = self.ffnn_tail(tail_repr)
        logits = self.rel_classifier(heads, tails)
        loss += self.loss_fct(logits, relation_labels)
        all_pred_relations.append(self.get_predicted_relations(logits, doc_relations, doc_entities))
    return loss, all_pred_relations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think you need to change this line.