Merge pull request #18 from ikuyamada/adding_luke_v3
Addressing comments on the pull request
NielsRogge authored Apr 20, 2021
2 parents 4f4fcfa + 729dc61 commit 27d0beb
Showing 7 changed files with 323 additions and 589 deletions.
36 changes: 31 additions & 5 deletions docs/source/model_doc/luke.rst
@@ -1,4 +1,4 @@
..
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@@ -38,13 +38,39 @@ answering).*

Tips:


- This implementation is the same as :class:`~transformers.RobertaModel` with the addition of entity embeddings as well
as an entity- aware self-attention mechanism, which improves performance on tasks involving reasoning about entities.
as an entity-aware self-attention mechanism, which improves performance on tasks involving reasoning about entities.
- LUKE adds :obj:`entity_ids`, :obj:`entity_attention_mask`, :obj:`entity_token_type_ids` and
:obj:`entity_position_ids` as extra input to the model. You can obtain those using :class:`LukeTokenizer`.
:obj:`entity_position_ids` as extra input to the model. Input entities can be special entities (e.g., [MASK]) or
Wikipedia entities (e.g., New York City). You can obtain those using :class:`~transformers.LukeTokenizer`.

Example:

.. code-block::
>>> from transformers import LukeTokenizer, LukeModel
>>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
>>> model = LukeModel.from_pretrained("studio-ousia/luke-base")
# Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"
>>> text = "Beyoncé lives in New York."
>>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
>>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state
# Input Wikipedia entities to obtain enriched contextualized representations.
>>> text = "Beyoncé lives in New York."
>>> entities = ["Beyoncé", "New York City"] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "New York"
>>> entity_spans = [(0, 7), (17, 25)] # character-based entity spans corresponding to "Beyoncé" and "New York"
>>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state
The original code can be found `here <https://github.com/studio-ousia/luke>`_.
The original code can be found `here <https://github.com/studio-ousia/luke>`__.


LukeConfig
3 changes: 0 additions & 3 deletions src/transformers/models/auto/configuration_auto.py
@@ -350,12 +350,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:
- A string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
@@ -391,7 +389,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the ``return_unused_kwargs`` keyword parameter.
Examples::
>>> from transformers import AutoConfig
82 changes: 71 additions & 11 deletions src/transformers/models/luke/configuration_luke.py
@@ -15,7 +15,7 @@
""" LUKE configuration """

from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig
from ...configuration_utils import PretrainedConfig


logger = logging.get_logger(__name__)
@@ -26,49 +26,109 @@
}


class LukeConfig(RobertaConfig):
class LukeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.LukeModel`. It is used to
instantiate a LUKE model according to the specified arguments, defining the model architecture. Configuration
objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the
documentation from :class:`~transformers.PretrainedConfig` for more information. The
:class:`~transformers.LukeConfig` class directly inherits :class:`~transformers.RobertaConfig`. It reuses the same
defaults. Please check the parent class for more information.
instantiate a LUKE model according to the specified arguments, defining the model architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 50267):
Vocabulary size of the LUKE model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.LukeModel`.
entity_vocab_size (:obj:`int`, `optional`, defaults to 500000):
Entity vocabulary size of the LUKE model. Defines the number of different entities that can be represented
by the :obj:`entity_ids` passed when calling :class:`~transformers.LukeModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
entity_emb_size (:obj:`int`, `optional`, defaults to 256):
The number of dimensions of the entity embedding.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.LukeModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`):
Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.)
<https://arxiv.org/abs/2010.01057>`__.
Examples::
>>> from transformers import LukeConfig, LukeModel
>>> # Initializing a LUKE configuration
>>> configuration = LukeConfig()
>>> # Initializing a model from the configuration
>>> model = LukeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "luke"

def __init__(
self,
vocab_size: int = 50267,
entity_vocab_size: int = 500000,
entity_emb_size: int = 256,
vocab_size=50267,
entity_vocab_size=500000,
hidden_size=768,
entity_emb_size=256,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
gradient_checkpointing=False,
use_entity_aware_attention=True,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**kwargs
):
"""Constructs LukeConfig."""
super(LukeConfig, self).__init__(vocab_size=vocab_size, **kwargs)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

self.vocab_size = vocab_size
self.entity_vocab_size = entity_vocab_size
self.hidden_size = hidden_size
self.entity_emb_size = entity_emb_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.use_entity_aware_attention = use_entity_aware_attention
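
A minimal usage sketch for the configuration class above (the values are illustrative choices, not the pretrained defaults; only arguments documented in the docstring are used):

from transformers import LukeConfig, LukeModel

# Sketch: build a small LUKE configuration; the reduced entity vocabulary is
# an illustrative choice, not a pretrained default.
config = LukeConfig(
    entity_vocab_size=10000,         # entities representable via `entity_ids`
    entity_emb_size=256,             # dimensionality of the entity embeddings
    use_entity_aware_attention=True, # enable the entity-aware self-attention mechanism
)
model = LukeModel(config)                # randomly initialized weights
print(model.config.entity_vocab_size)    # 10000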
@@ -45,7 +45,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
config.vocab_size += 2

print("Saving tokenizer to {}".format(pytorch_dump_folder_path))
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
tokenizer.save_pretrained(pytorch_dump_folder_path)
with open(os.path.join(pytorch_dump_folder_path, LukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
json.dump(entity_vocab, f)
@@ -61,7 +61,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
# Initialize the query layers of the entity-aware self-attention mechanism
for layer_index in range(config.num_hidden_layers):
for matrix_name in ["query.weight", "query.bias"]:
prefix = "encoder.layer." + str(layer_index) + ".attention.self."
prefix = f"encoder.layer.{layer_index}.attention.self."
state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]
48 changes: 33 additions & 15 deletions src/transformers/models/luke/modeling_luke.py
@@ -279,7 +279,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):

class LukeEntityEmbeddings(nn.Module):
def __init__(self, config: LukeConfig):
super(LukeEntityEmbeddings, self).__init__()
super().__init__()
self.config = config

self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0)
@@ -366,21 +366,26 @@ def forward(
value_layer = self.transpose_for_scores(self.value(concat_hidden_states))

if self.use_entity_aware_attention and entity_hidden_states is not None:
# compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e)
# query layers
w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states))
w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states))
e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states))
e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states))

# compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above
w2w_key_layer = key_layer[:, :, :word_size, :]
e2w_key_layer = key_layer[:, :, :word_size, :]
w2e_key_layer = key_layer[:, :, word_size:, :]
e2e_key_layer = key_layer[:, :, word_size:, :]

# compute attention scores based on the dot product between the query and key vectors
w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2))
w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2))
e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2))
e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2))

# combine attention scores to create the final attention score matrix
word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3)
entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3)
attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2)
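
To make the shape bookkeeping above concrete, here is a small self-contained sketch (toy dimensions chosen arbitrarily) of how the four score matrices are combined over the concatenated word + entity sequence:

import torch

# Toy sizes: 4 word tokens, 2 entity tokens, 2 attention heads.
batch, heads, word_size, entity_size = 1, 2, 4, 2
w2w = torch.randn(batch, heads, word_size, word_size)
w2e = torch.randn(batch, heads, word_size, entity_size)
e2w = torch.randn(batch, heads, entity_size, word_size)
e2e = torch.randn(batch, heads, entity_size, entity_size)

word_attention_scores = torch.cat([w2w, w2e], dim=3)      # (1, 2, 4, 6)
entity_attention_scores = torch.cat([e2w, e2e], dim=3)    # (1, 2, 2, 6)
attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2)
print(attention_scores.shape)                             # torch.Size([1, 2, 6, 6])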
@@ -825,7 +830,7 @@ def set_entity_embeddings(self, value):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError("LUKE does not support the pruning of attention heads")

@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BaseLukeModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
@@ -860,7 +865,7 @@ def forward(
>>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.word_last_hidden_state
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state
# Input Wikipedia entities to obtain enriched contextualized representations.
@@ -870,7 +875,7 @@ def forward(
>>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.word_last_hidden_state
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -919,7 +924,7 @@ def forward(
)

# Second, compute extended attention mask
extended_attention_mask = self._compute_extended_attention_mask(attention_mask, entity_attention_mask)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask)

# Third, compute entity embeddings and concatenate with word embeddings
if entity_ids is None:
@@ -945,10 +950,7 @@ def forward(
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

if not return_dict:
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseLukeModelOutputWithPooling(
last_hidden_state=sequence_output,
@@ -959,9 +961,19 @@ def forward(
entity_hidden_states=encoder_outputs.entity_hidden_states,
)

def _compute_extended_attention_mask(
self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]
):
def get_extended_attention_mask(self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]):
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
word_attention_mask (:obj:`torch.LongTensor`):
Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
entity_attention_mask (:obj:`torch.LongTensor`, `optional`):
Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
Returns:
:obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
"""
attention_mask = word_attention_mask
if entity_attention_mask is not None:
attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1)
@@ -971,7 +983,7 @@ def _compute_extended_attention_mask(
elif attention_mask.dim() == 2:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError("Wrong shape for attention_mask (shape {})".format(attention_mask.shape))
raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")

extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
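
A tiny sketch of the additive mask conversion above (example values only): positions with mask value 1 map to 0.0 and are kept, positions with 0 map to -10000.0 and are effectively removed by the softmax:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # (batch_size, seq_len)
extended = attention_mask[:, None, None, :].float()    # (batch_size, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0                 # kept -> -0.0, masked -> -10000.0
print(extended.shape)                                  # torch.Size([1, 1, 1, 4])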
@@ -1079,6 +1091,8 @@ def forward(

loss = None
if labels is not None:
# When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. Binary
# cross entropy is used otherwise.
if labels.ndim == 1:
loss = F.cross_entropy(logits, labels)
else:
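
A hedged sketch of the label-shape dispatch described in the comment above; the collapsed branch is assumed to use binary cross entropy with logits, and `sketch_loss` is a hypothetical helper, not part of the model:

import torch
import torch.nn.functional as F

def sketch_loss(logits, labels):
    # Assumption: 1-D labels hold class indices (single-label classification),
    # higher-dimensional labels hold multi-hot targets (multi-label classification).
    if labels.ndim == 1:
        return F.cross_entropy(logits, labels)
    return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))

logits = torch.randn(3, 5)                               # (batch_size, num_labels)
print(sketch_loss(logits, torch.tensor([0, 2, 4])))      # single-label targets
print(sketch_loss(logits, torch.randint(0, 2, (3, 5))))  # multi-hot targets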
@@ -1186,6 +1200,8 @@ def forward(

loss = None
if labels is not None:
# When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. Binary
# cross entropy is used otherwise.
if labels.ndim == 1:
loss = F.cross_entropy(logits, labels)
else:
@@ -1250,10 +1266,10 @@ def forward(
return_dict=None,
):
r"""
entity_start_positions:
entity_start_positions (:obj:`torch.LongTensor`):
The start positions of entities in the word token sequence.
entity_end_positions:
entity_end_positions (:obj:`torch.LongTensor`):
The end positions of entities in the word token sequence.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)` or :obj:`(batch_size, entity_length, num_labels)`, `optional`):
@@ -1308,6 +1324,8 @@ def forward(

loss = None
if labels is not None:
# When the number of dimensions of `labels` is 2, cross entropy is used as the loss function. Binary
# cross entropy is used otherwise.
if labels.ndim == 2:
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
else: