From 71671e41f727a384339389952a7fb4768ce863e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E9=9F=B5?= Date: Mon, 17 Jan 2022 11:10:33 +0800 Subject: [PATCH] add LayoutLMv2ForRelationExtraction to transformers --- src/transformers/__init__.py | 5 +- src/transformers/modeling_outputs.py | 40 ++++- .../models/layoutlmv2/__init__.py | 2 + .../models/layoutlmv2/modeling_layoutlmv2.py | 82 ++++++++++ src/transformers/models/layoutlmv2/re.py | 154 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 12 ++ 6 files changed, 292 insertions(+), 3 deletions(-) create mode 100644 src/transformers/models/layoutlmv2/re.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6a63a49c6a04..1d28c67800c0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -22,8 +22,7 @@ # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names # in the namespace without actually importing anything (and especially none of the backends). -__version__ = "4.16.0.dev0" - +__version__ = "4.16.0" from typing import TYPE_CHECKING @@ -1012,6 +1011,7 @@ "LayoutLMv2ForQuestionAnswering", "LayoutLMv2ForSequenceClassification", "LayoutLMv2ForTokenClassification", + "LayoutLMv2ForRelationExtraction", "LayoutLMv2Model", "LayoutLMv2PreTrainedModel", ] @@ -2964,6 +2964,7 @@ LayoutLMv2ForQuestionAnswering, LayoutLMv2ForSequenceClassification, LayoutLMv2ForTokenClassification, + LayoutLMv2ForRelationExtraction, LayoutLMv2Model, LayoutLMv2PreTrainedModel, ) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index c7f4a27fb38c..6553b021cd70 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch @@ -812,3 +812,41 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class ReOutput(ModelOutput): + """ + Base class for outputs of relation extraction models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 2)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + entities: (Dict), the content of dict is {"start":[], "end":[], "label":[]}. + "start"/"end" correspond the start/end index of text token in all tokens of image + label is in [0, 1, 2], which correspond to ["HEADER", "QUESTION", "ANSWER"] + relations: (Dict), the content of dict is {"head":[], "tail":[], "start_index":[], "end_index":[]}. + "head"/"tail" correspond the entity index in all entities of image. + "start_index"/"end_index" is the min/max value of "head" and "tail" entity's "start" and "end". + pred_relations: (Dict), the content of dict is {"head_id": int, "head": tuple(int, int), "head_type":int, "tail": tuple(int, int) + "tail_id":int, "tail_typ":int, "type":int} + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + entities: Optional[Dict] = None + relations: Optional[Dict] = None + pred_relations: Optional[Dict] = None diff --git a/src/transformers/models/layoutlmv2/__init__.py b/src/transformers/models/layoutlmv2/__init__.py index c75d075e59bf..8a85a1b83d6b 100644 --- a/src/transformers/models/layoutlmv2/__init__.py +++ b/src/transformers/models/layoutlmv2/__init__.py @@ -39,6 +39,7 @@ "LayoutLMv2ForQuestionAnswering", "LayoutLMv2ForSequenceClassification", "LayoutLMv2ForTokenClassification", + "LayoutLMv2ForRelationExtraction", "LayoutLMv2Layer", "LayoutLMv2Model", "LayoutLMv2PreTrainedModel", @@ -61,6 +62,7 @@ LayoutLMv2ForQuestionAnswering, LayoutLMv2ForSequenceClassification, LayoutLMv2ForTokenClassification, + LayoutLMv2ForRelationExtraction, LayoutLMv2Layer, LayoutLMv2Model, LayoutLMv2PreTrainedModel, diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 8c5d95b76f36..e187c730bee8 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -36,10 +36,12 @@ QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, + ReOutput, ) from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...utils import logging from .configuration_layoutlmv2 import LayoutLMv2Config +from .re import REDecoder # soft dependency @@ -1226,6 +1228,86 @@ def forward( ) +class LayoutLMv2ForRelationExtraction(LayoutLMv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.layoutlmv2 = LayoutLMv2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.extractor = REDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.layoutlmv2.embeddings.word_embeddings + + @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=ReOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids, + bbox, + entities, + relations, + image=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + ): + r""" + entities: (Dict), the content of dict is {"start":[], "end":[], "label":[]}. + "start"/"end" correspond the start/end index of text token in all tokens of image + label is in [0, 1, 2], which correspond to ["HEADER", "QUESTION", "ANSWER"] + relations: (Dict), the content of dict is {"head":[], "tail":[], "start_index":[], "end_index":[]}. + "head"/"tail" correspond the entity index in all entities of image. + "start_index"/"end_index" is the min/max value of "head" and "tail" entity's "start" and "end". + Returns: + + Examples: + + ```python + >>> from transformers import LayoutLMv2Processor, LayoutLMv2Tokenizer, LayoutLMv2ForRelationExtraction + >>> from PIL import Image + >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr") + >>> model = LayoutLMv2ForRelationExtraction.from_pretrained("microsoft/layoutlmv2-base-uncased") + + >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB") + >>> words = ["hello", "world"] + >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] # make sure to normalize your bounding boxes + + >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt") + >>> print('encoding image: ', encoding['image'].shape) + >>> # encoding['relations'] and encoding['entities'] are taken from the result of ner + >>> encoding['relations'] = [{"end_index": [], "head":[], "start_index":[], "tail":[]}] + >>> encoding['entities'] = {"start": [], "end": [], "label": []} + >>> outputs = model(**encoding) + >>> print(outputs) + ```""" + + outputs = self.layoutlmv2( + input_ids=input_ids, + bbox=bbox, + image=image, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + ) + + seq_length = input_ids.size(1) + sequence_output, image_output = outputs[0][:, :seq_length], outputs[0][:, seq_length:] + sequence_output = self.dropout(sequence_output) + loss, pred_relations = self.extractor(sequence_output, entities, relations) + + return ReOutput( + loss=loss, + entities=entities, + relations=relations, + pred_relations=pred_relations, + hidden_states=outputs[0], + ) + + @add_start_docstrings( """ LayoutLMv2 Model with a span classification head on top for extractive question-answering tasks such as diff --git a/src/transformers/models/layoutlmv2/re.py b/src/transformers/models/layoutlmv2/re.py new file mode 100644 index 000000000000..dd31c57fa556 --- /dev/null +++ b/src/transformers/models/layoutlmv2/re.py @@ -0,0 +1,154 @@ +import copy + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + + +class BiaffineAttention(torch.nn.Module): + """Implements a biaffine attention operator for binary relation classification. + + PyTorch implementation of the biaffine attention operator from "End-to-end neural relation + extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used + as a classifier for binary relation classification. + + Args: + in_features (int): The size of the feature dimension of the inputs. + out_features (int): The size of the feature dimension of the output. + + Shape: + - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of + additional dimensisons. + - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of + additional dimensions. + - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number + of additional dimensions. + + Examples: + >>> batch_size, in_features, out_features = 32, 100, 4 + >>> biaffine_attention = BiaffineAttention(in_features, out_features) + >>> x_1 = torch.randn(batch_size, in_features) + >>> x_2 = torch.randn(batch_size, in_features) + >>> output = biaffine_attention(x_1, x_2) + >>> print(output.size()) + torch.Size([32, 4]) + """ + + def __init__(self, in_features, out_features): + super(BiaffineAttention, self).__init__() + + self.in_features = in_features + self.out_features = out_features + + self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False) + self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True) + + self.reset_parameters() + + def forward(self, x_1, x_2): + return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1)) + + def reset_parameters(self): + self.bilinear.reset_parameters() + self.linear.reset_parameters() + + +class REDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True) + projection = nn.Sequential( + nn.Linear(config.hidden_size * 2, config.hidden_size), + nn.ReLU(), + nn.Dropout(config.hidden_dropout_prob), + nn.Linear(config.hidden_size, config.hidden_size // 2), + nn.ReLU(), + nn.Dropout(config.hidden_dropout_prob), + ) + self.ffnn_head = copy.deepcopy(projection) + self.ffnn_tail = copy.deepcopy(projection) + self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2) + self.loss_fct = CrossEntropyLoss() + + def build_relation(self, relations, entities): + batch_size = len(relations) + new_relations = [] + for b in range(batch_size): + if len(entities[b]["start"]) <= 2: + entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]} + all_possible_relations = set( + [ + (i, j) + for i in range(len(entities[b]["label"])) + for j in range(len(entities[b]["label"])) + if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2 + ] + ) + if len(all_possible_relations) == 0: + all_possible_relations = set([(0, 1)]) + positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"]))) + negative_relations = all_possible_relations - positive_relations + positive_relations = set([i for i in positive_relations if i in all_possible_relations]) + reordered_relations = list(positive_relations) + list(negative_relations) + relation_per_doc = {"head": [], "tail": [], "label": []} + relation_per_doc["head"] = [i[0] for i in reordered_relations] + relation_per_doc["tail"] = [i[1] for i in reordered_relations] + relation_per_doc["label"] = [1] * len(positive_relations) + [0] * ( + len(reordered_relations) - len(positive_relations) + ) + assert len(relation_per_doc["head"]) != 0 + new_relations.append(relation_per_doc) + return new_relations, entities + + def get_predicted_relations(self, logits, relations, entities): + pred_relations = [] + for i, pred_label in enumerate(logits.argmax(-1)): + if pred_label != 1: + continue + rel = {} + rel["head_id"] = relations["head"][i] + rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]]) + rel["head_type"] = entities["label"][rel["head_id"]] + + rel["tail_id"] = relations["tail"][i] + rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]]) + rel["tail_type"] = entities["label"][rel["tail_id"]] + rel["type"] = 1 + pred_relations.append(rel) + return pred_relations + + def forward(self, hidden_states, entities, relations): + batch_size, max_n_words, context_dim = hidden_states.size() + device = hidden_states.device + relations, entities = self.build_relation(relations, entities) + loss = 0 + all_pred_relations = [] + for b in range(batch_size): + head_entities = torch.tensor(relations[b]["head"], device=device) + tail_entities = torch.tensor(relations[b]["tail"], device=device) + relation_labels = torch.tensor(relations[b]["label"], device=device) + entities_start_index = torch.tensor(entities[b]["start"], device=device) + entities_labels = torch.tensor(entities[b]["label"], device=device) + head_index = entities_start_index[head_entities] + head_label = entities_labels[head_entities] + head_label_repr = self.entity_emb(head_label) + + tail_index = entities_start_index[tail_entities] + tail_label = entities_labels[tail_entities] + tail_label_repr = self.entity_emb(tail_label) + + head_repr = torch.cat( + (hidden_states[b][head_index], head_label_repr), + dim=-1, + ) + tail_repr = torch.cat( + (hidden_states[b][tail_index], tail_label_repr), + dim=-1, + ) + heads = self.ffnn_head(head_repr) + tails = self.ffnn_tail(tail_repr) + logits = self.rel_classifier(heads, tails) + loss += self.loss_fct(logits, relation_labels) + pred_relations = self.get_predicted_relations(logits, relations[b], entities[b]) + all_pred_relations.append(pred_relations) + return loss, all_pred_relations diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index a5e576b88450..473c82a85aed 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2859,6 +2859,18 @@ def forward(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LayoutLMv2ForRelationExtraction: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LayoutLMv2Model: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"])