Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addressing comments on the pull request #18

Merged
merged 5 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions docs/source/model_doc/luke.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
..
..
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
Expand Down Expand Up @@ -38,13 +38,39 @@ answering).*

Tips:


- This implementation is the same as :class:`~transformers.RobertaModel` with the addition of entity embeddings as well
as an entity- aware self-attention mechanism, which improves performance on tasks involving reasoning about entities.
as an entity-aware self-attention mechanism, which improves performance on tasks involving reasoning about entities.
- LUKE adds :obj:`entity_ids`, :obj:`entity_attention_mask`, :obj:`entity_token_type_ids` and
:obj:`entity_position_ids` as extra input to the model. You can obtain those using :class:`LukeTokenizer`.
:obj:`entity_position_ids` as extra input to the model. Input entities can be special entities (e.g., [MASK]) or
Wikipedia entities (e.g., New York City). You can obtain those using :class:`~transformers.LukeTokenizer`.

Example:

.. code-block::

>>> from transformers import LukeTokenizer, LukeModel

>>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
>>> model = LukeModel.from_pretrained("studio-ousia/luke-base")

# Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"
>>> text = "Beyoncé lives in New York."
>>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
>>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state

# Input Wikipedia entities to obtain enriched contextualized representations.
>>> text = "Beyoncé lives in New York."
>>> entities = ["Beyoncé", "New York City"] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "New York"
>>> entity_spans = [(0, 7), (17, 25)] # character-based entity spans corresponding to "Beyoncé" and "New York"
>>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state

The original code can be found `here <https://github.com/studio-ousia/luke>`_.
The original code can be found `here <https://github.com/studio-ousia/luke>`__.


LukeConfig
Expand Down
3 changes: 0 additions & 3 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):

List options


Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:


- A string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
Expand Down Expand Up @@ -391,7 +389,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the ``return_unused_kwargs`` keyword parameter.


Examples::

>>> from transformers import AutoConfig
Expand Down
82 changes: 71 additions & 11 deletions src/transformers/models/luke/configuration_luke.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
""" LUKE configuration """

from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig
from ...configuration_utils import PretrainedConfig


logger = logging.get_logger(__name__)
Expand All @@ -26,49 +26,109 @@
}


class LukeConfig(RobertaConfig):
class LukeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.LukeModel`. It is used to
instantiate a LUKE model according to the specified arguments, defining the model architecture. Configuration
objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the
documentation from :class:`~transformers.PretrainedConfig` for more information. The
:class:`~transformers.LukeConfig` class directly inherits :class:`~transformers.RobertaConfig`. It reuses the same
defaults. Please check the parent class for more information.
instantiate a LUKE model according to the specified arguments, defining the model architecture.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
entity_vocab_size (:obj:`int`, `optional`, defaults to 500000):
Entity vocabulary size of the LUKE model. Defines the number of different entities that can be represented
by the :obj:`entity_ids` passed when calling :class:`~transformers.LukeModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
entity_emb_size (:obj:`int`, `optional`, defaults to 256):
The number of dimensions of the entity embedding.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`):
Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.)
<https://arxiv.org/abs/2010.01057>`__.

Examples::
>>> from transformers import LukeConfig, LukeModel

>>> # Initializing a LUKE configuration
>>> configuration = LukeConfig()

>>> # Initializing a model from the configuration
>>> model = LukeModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "luke"

def __init__(
self,
vocab_size: int = 50267,
entity_vocab_size: int = 500000,
entity_emb_size: int = 256,
vocab_size=50267,
entity_vocab_size=500000,
hidden_size=768,
entity_emb_size=256,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
gradient_checkpointing=False,
use_entity_aware_attention=True,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**kwargs
):
"""Constructs LukeConfig."""
super(LukeConfig, self).__init__(vocab_size=vocab_size, **kwargs)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

self.vocab_size = vocab_size
self.entity_vocab_size = entity_vocab_size
self.hidden_size = hidden_size
self.entity_emb_size = entity_emb_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.use_entity_aware_attention = use_entity_aware_attention
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
config.vocab_size += 2

print("Saving tokenizer to {}".format(pytorch_dump_folder_path))
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
tokenizer.save_pretrained(pytorch_dump_folder_path)
with open(os.path.join(pytorch_dump_folder_path, LukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
json.dump(entity_vocab, f)
Expand All @@ -61,7 +61,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
# Initialize the query layers of the entity-aware self-attention mechanism
for layer_index in range(config.num_hidden_layers):
for matrix_name in ["query.weight", "query.bias"]:
prefix = "encoder.layer." + str(layer_index) + ".attention.self."
prefix = f"encoder.layer.{layer_index}.attention.self."
state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]
Expand Down
48 changes: 33 additions & 15 deletions src/transformers/models/luke/modeling_luke.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):

class LukeEntityEmbeddings(nn.Module):
def __init__(self, config: LukeConfig):
super(LukeEntityEmbeddings, self).__init__()
super().__init__()
self.config = config

self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0)
Expand Down Expand Up @@ -366,21 +366,26 @@ def forward(
value_layer = self.transpose_for_scores(self.value(concat_hidden_states))

if self.use_entity_aware_attention and entity_hidden_states is not None:
# compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e)
# query layers
w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states))
w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states))
e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states))
e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states))

# compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above
w2w_key_layer = key_layer[:, :, :word_size, :]
e2w_key_layer = key_layer[:, :, :word_size, :]
w2e_key_layer = key_layer[:, :, word_size:, :]
e2e_key_layer = key_layer[:, :, word_size:, :]

# compute attention scores based on the dot product between the query and key vectors
w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2))
w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2))
e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2))
e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2))

# combine attention scores to create the final attention score matrix
word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3)
entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3)
attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2)
Expand Down Expand Up @@ -825,7 +830,7 @@ def set_entity_embeddings(self, value):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError("LUKE does not support the pruning of attention heads")

@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BaseLukeModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
Expand Down Expand Up @@ -860,7 +865,7 @@ def forward(

>>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.word_last_hidden_state
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state

# Input Wikipedia entities to obtain enriched contextualized representations.
Expand All @@ -870,7 +875,7 @@ def forward(

>>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.word_last_hidden_state
>>> word_last_hidden_state = outputs.last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand Down Expand Up @@ -919,7 +924,7 @@ def forward(
)

# Second, compute extended attention mask
extended_attention_mask = self._compute_extended_attention_mask(attention_mask, entity_attention_mask)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask)

# Third, compute entity embeddings and concatenate with word embeddings
if entity_ids is None:
Expand All @@ -945,10 +950,7 @@ def forward(
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

if not return_dict:
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseLukeModelOutputWithPooling(
last_hidden_state=sequence_output,
Expand All @@ -959,9 +961,19 @@ def forward(
entity_hidden_states=encoder_outputs.entity_hidden_states,
)

def _compute_extended_attention_mask(
self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]
):
def get_extended_attention_mask(self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]):
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

Arguments:
word_attention_mask (:obj:`torch.LongTensor`):
Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
entity_attention_mask (:obj:`torch.LongTensor`, `optional`):
Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.

Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
attention_mask = word_attention_mask
if entity_attention_mask is not None:
attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1)
Expand All @@ -971,7 +983,7 @@ def _compute_extended_attention_mask(
elif attention_mask.dim() == 2:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError("Wrong shape for attention_mask (shape {})".format(attention_mask.shape))
raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")

extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
Expand Down Expand Up @@ -1079,6 +1091,8 @@ def forward(

loss = None
if labels is not None:
# When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
if labels.ndim == 1:
loss = F.cross_entropy(logits, labels)
else:
Expand Down Expand Up @@ -1186,6 +1200,8 @@ def forward(

loss = None
if labels is not None:
# When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
if labels.ndim == 1:
loss = F.cross_entropy(logits, labels)
else:
Expand Down Expand Up @@ -1250,10 +1266,10 @@ def forward(
return_dict=None,
):
r"""
entity_start_positions:
entity_start_positions (:obj:`torch.LongTensor`):
The start positions of entities in the word token sequence.

entity_end_positions:
entity_end_positions (:obj:`torch.LongTensor`):
The end positions of entities in the word token sequence.

labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)` or :obj:`(batch_size, entity_length, num_labels), `optional`):
Expand Down Expand Up @@ -1308,6 +1324,8 @@ def forward(

loss = None
if labels is not None:
# When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
if labels.ndim == 2:
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
else:
Expand Down
Loading