diff --git a/docs/source/model_doc/luke.rst b/docs/source/model_doc/luke.rst index b3190ea6532d..5b2b424f5ff6 100644 --- a/docs/source/model_doc/luke.rst +++ b/docs/source/model_doc/luke.rst @@ -1,4 +1,4 @@ -.. +.. Copyright 2021 The HuggingFace Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with @@ -38,13 +38,39 @@ answering).* Tips: - - This implementation is the same as :class:`~transformers.RobertaModel` with the addition of entity embeddings as well - as an entity- aware self-attention mechanism, which improves performance on tasks involving reasoning about entities. + as an entity-aware self-attention mechanism, which improves performance on tasks involving reasoning about entities. - LUKE adds :obj:`entity_ids`, :obj:`entity_attention_mask`, :obj:`entity_token_type_ids` and - :obj:`entity_position_ids` as extra input to the model. You can obtain those using :class:`LukeTokenizer`. + :obj:`entity_position_ids` as extra input to the model. Input entities can be special entities (e.g., [MASK]) or + Wikipedia entities (e.g., New York City). You can obtain those using :class:`~transformers.LukeTokenizer`. + +Example: + +.. code-block:: + + >>> from transformers import LukeTokenizer, LukeModel + + >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base") + >>> model = LukeModel.from_pretrained("studio-ousia/luke-base") + + # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé" + >>> text = "Beyoncé lives in New York." + >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé" + >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**encoding) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + + # Input Wikipedia entities to obtain enriched contextualized representations. + >>> text = "Beyoncé lives in New York." + >>> entities = ["Beyoncé", "New York City"] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "New York" + >>> entity_spans = [(0, 7), (17, 25)] # character-based entity spans corresponding to "Beyoncé" and "New York" + >>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**encoding) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state -The original code can be found `here `_. +The original code can be found `here `__. LukeConfig diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 8d8faa066806..45f7838f1231 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -350,12 +350,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): List options - Args: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): Can be either: - - A string, the `model id` of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. @@ -391,7 +389,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): values. 
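Example of this override behavior (a minimal sketch; the checkpoint name below is purely illustrative)::

    >>> from transformers import AutoConfig
    >>> # ``output_attentions`` is a regular configuration attribute, so the keyword
    >>> # argument below overrides the value stored in the downloaded configuration
    >>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True)
    >>> assert config.output_attentions is True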
Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the ``return_unused_kwargs`` keyword parameter. - Examples:: >>> from transformers import AutoConfig diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py index 1a8ab38ea28b..4862b336d056 100644 --- a/src/transformers/models/luke/configuration_luke.py +++ b/src/transformers/models/luke/configuration_luke.py @@ -15,7 +15,7 @@ """ LUKE configuration """ from ...utils import logging -from ..roberta.configuration_roberta import RobertaConfig +from ...configuration_utils import PretrainedConfig logger = logging.get_logger(__name__) @@ -26,22 +26,51 @@ } -class LukeConfig(RobertaConfig): +class LukeConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a :class:`~transformers.LukeModel`. It is used to - instantiate a LUKE model according to the specified arguments, defining the model architecture. Configuration - objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the - documentation from :class:`~transformers.PretrainedConfig` for more information. The - :class:`~transformers.LukeConfig` class directly inherits :class:`~transformers.RobertaConfig`. It reuses the same - defaults. Please check the parent class for more information. + instantiate a LUKE model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. Args: + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or + :class:`~transformers.TFBertModel`. entity_vocab_size (:obj:`int`, `optional`, defaults to 500000): Entity vocabulary size of the LUKE model. Defines the number of different entities that can be represented by the :obj:`entity_ids` passed when calling :class:`~transformers.LukeModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. entity_emb_size (:obj:`int`, `optional`, defaults to 256): The number of dimensions of the entity embedding. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. 
+ max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or + :class:`~transformers.TFBertModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`): Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.) @@ -49,10 +78,13 @@ class LukeConfig(RobertaConfig): Examples:: >>> from transformers import LukeConfig, LukeModel + >>> # Initializing a LUKE configuration >>> configuration = LukeConfig() + >>> # Initializing a model from the configuration >>> model = LukeModel(configuration) + >>> # Accessing the model configuration >>> configuration = model.config """ @@ -60,15 +92,43 @@ class LukeConfig(RobertaConfig): def __init__( self, - vocab_size: int = 50267, - entity_vocab_size: int = 500000, - entity_emb_size: int = 256, + vocab_size=50267, + entity_vocab_size=500000, + hidden_size=768, + entity_emb_size=256, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + gradient_checkpointing=False, use_entity_aware_attention=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, **kwargs ): """Constructs LukeConfig.""" - super(LukeConfig, self).__init__(vocab_size=vocab_size, **kwargs) + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.vocab_size = vocab_size self.entity_vocab_size = entity_vocab_size + self.hidden_size = hidden_size self.entity_emb_size = entity_emb_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.gradient_checkpointing = gradient_checkpointing self.use_entity_aware_attention = use_entity_aware_attention diff --git a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py index d42ce4ec848a..55e2aab4130b 100644 --- a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py @@ -45,7 +45,7 @@ def convert_luke_checkpoint(checkpoint_path, 
metadata_path, entity_vocab_path, p tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2])) config.vocab_size += 2 - print("Saving tokenizer to {}".format(pytorch_dump_folder_path)) + print(f"Saving tokenizer to {pytorch_dump_folder_path}") tokenizer.save_pretrained(pytorch_dump_folder_path) with open(os.path.join(pytorch_dump_folder_path, LukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f: json.dump(entity_vocab, f) @@ -61,7 +61,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p # Initialize the query layers of the entity-aware self-attention mechanism for layer_index in range(config.num_hidden_layers): for matrix_name in ["query.weight", "query.bias"]: - prefix = "encoder.layer." + str(layer_index) + ".attention.self." + prefix = f"encoder.layer.{layer_index}.attention.self." state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name] state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name] state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name] diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 15b43b3a729b..805d3ba46465 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -279,7 +279,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): class LukeEntityEmbeddings(nn.Module): def __init__(self, config: LukeConfig): - super(LukeEntityEmbeddings, self).__init__() + super().__init__() self.config = config self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0) @@ -366,21 +366,26 @@ def forward( value_layer = self.transpose_for_scores(self.value(concat_hidden_states)) if self.use_entity_aware_attention and entity_hidden_states is not None: + # compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e) + # query layers w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states)) w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states)) e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states)) e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states)) + # compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above w2w_key_layer = key_layer[:, :, :word_size, :] e2w_key_layer = key_layer[:, :, :word_size, :] w2e_key_layer = key_layer[:, :, word_size:, :] e2e_key_layer = key_layer[:, :, word_size:, :] + # compute attention scores based on the dot product between the query and key vectors w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2)) w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2)) e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2)) e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2)) + # combine attention scores to create the final attention score matrix word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3) entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3) attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2) @@ -825,7 +830,7 @@ def set_entity_embeddings(self, value): def _prune_heads(self, heads_to_prune): raise NotImplementedError("LUKE does not support the pruning of 
attention heads") - @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BaseLukeModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -860,7 +865,7 @@ def forward( >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") >>> outputs = model(**encoding) - >>> word_last_hidden_state = outputs.word_last_hidden_state + >>> word_last_hidden_state = outputs.last_hidden_state >>> entity_last_hidden_state = outputs.entity_last_hidden_state # Input Wikipedia entities to obtain enriched contextualized representations. @@ -870,7 +875,7 @@ def forward( >>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") >>> outputs = model(**encoding) - >>> word_last_hidden_state = outputs.word_last_hidden_state + >>> word_last_hidden_state = outputs.last_hidden_state >>> entity_last_hidden_state = outputs.entity_last_hidden_state """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -919,7 +924,7 @@ def forward( ) # Second, compute extended attention mask - extended_attention_mask = self._compute_extended_attention_mask(attention_mask, entity_attention_mask) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask) # Third, compute entity embeddings and concatenate with word embeddings if entity_ids is None: @@ -945,10 +950,7 @@ def forward( pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - return ( - sequence_output, - pooled_output, - ) + encoder_outputs[1:] + return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseLukeModelOutputWithPooling( last_hidden_state=sequence_output, @@ -959,9 +961,19 @@ def forward( entity_hidden_states=encoder_outputs.entity_hidden_states, ) - def _compute_extended_attention_mask( - self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor] - ): + def get_extended_attention_mask(self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]): + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + word_attention_mask (:obj:`torch.LongTensor`): + Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore. + entity_attention_mask (:obj:`torch.LongTensor`, `optional`): + Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. 
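Example (an illustrative sketch with hypothetical mask tensors, showing the values this method produces)::

    >>> import torch
    >>> word_attention_mask = torch.tensor([[1, 1, 1, 0]])  # last word position is padding
    >>> entity_attention_mask = torch.tensor([[1, 0]])      # last entity position is padding
    >>> attention_mask = torch.cat([word_attention_mask, entity_attention_mask], dim=-1)
    >>> extended_attention_mask = attention_mask[:, None, None, :].to(torch.float32)
    >>> extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    >>> extended_attention_mask.shape
    torch.Size([1, 1, 1, 6])
    >>> # positions to attend to are mapped to 0.0 and padded positions to -10000.0, so the
    >>> # padded positions effectively vanish once the mask is added to the attention scores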
+ """ attention_mask = word_attention_mask if entity_attention_mask is not None: attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1) @@ -971,7 +983,7 @@ def _compute_extended_attention_mask( elif attention_mask.dim() == 2: extended_attention_mask = attention_mask[:, None, None, :] else: - raise ValueError("Wrong shape for attention_mask (shape {})".format(attention_mask.shape)) + raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})") extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 @@ -1079,6 +1091,8 @@ def forward( loss = None if labels is not None: + # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary + # cross entropy is used otherwise. if labels.ndim == 1: loss = F.cross_entropy(logits, labels) else: @@ -1186,6 +1200,8 @@ def forward( loss = None if labels is not None: + # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary + # cross entropy is used otherwise. if labels.ndim == 1: loss = F.cross_entropy(logits, labels) else: @@ -1250,10 +1266,10 @@ def forward( return_dict=None, ): r""" - entity_start_positions: + entity_start_positions (:obj:`torch.LongTensor`): The start positions of entities in the word token sequence. - entity_end_positions: + entity_end_positions (:obj:`torch.LongTensor`): The end positions of entities in the word token sequence. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)` or :obj:`(batch_size, entity_length, num_labels), `optional`): @@ -1308,6 +1324,8 @@ def forward( loss = None if labels is not None: + # When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary + # cross entropy is used otherwise. if labels.ndim == 2: loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) else: diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index eb8a085df230..9c5620a6a8bf 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -25,7 +25,6 @@ from ...file_utils import add_end_docstrings, is_tf_available, is_torch_available from ...tokenization_utils_base import ( ENCODE_KWARGS_DOCSTRING, - ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, AddedToken, BatchEncoding, EncodedInput, @@ -43,6 +42,11 @@ logger = logging.get_logger(__name__) +EntitySpan = Tuple[int, int] +EntitySpanInput = List[EntitySpan] +Entity = str +EntityInput = List[Entity] + VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", "merges_file": "merges.txt", @@ -69,6 +73,76 @@ "studio-ousia/luke-large": 512, } +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (:obj:`bool`, `optional`): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + return_attention_mask (:obj:`bool`, `optional`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are attention masks? 
<../glossary.html#attention-mask>`__ + return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return overflowing token sequences. + return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return :obj:`(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from + :class:`~transformers.PreTrainedTokenizerFast`, if using Python's tokenizer, this method will raise + :obj:`NotImplementedError`. + return_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return the lengths of the encoded inputs. + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to print more information and warnings. + **kwargs: passed to the :obj:`self.tokenize()` method + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + + - **token_type_ids** -- List of token type ids to be fed to a model (when :obj:`return_token_type_ids=True` + or if `"token_type_ids"` is in :obj:`self.model_input_names`). + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + - **entity_ids** -- List of entity ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + + - **entity_position_ids** -- List of entity positions in the input sequence to be fed to a model. + + - **entity_token_type_ids** -- List of entity token type ids to be fed to a model (when + :obj:`return_token_type_ids=True` or if `"entity_token_type_ids"` is in :obj:`self.model_input_names`). + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + + - **entity_attention_mask** -- List of indices specifying which entities should be attended to by the model + (when :obj:`return_attention_mask=True` or if `"entity_attention_mask"` is in + :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + - **overflowing_tokens** -- List of overflowing tokens sequences (when a :obj:`max_length` is specified and + :obj:`return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a :obj:`max_length` is specified and + :obj:`return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when :obj:`add_special_tokens=True` and :obj:`return_special_tokens_mask=True`). + - **length** -- The length of the inputs (when :obj:`return_length=True`) +""" + class LukeTokenizer(RobertaTokenizer): r""" @@ -88,19 +162,19 @@ class LukeTokenizer(RobertaTokenizer): entity_vocab_file (:obj:`str`): Path to the entity vocabulary file. task (:obj:`str`, `optional`): - Task for which you want to prepare sequences. One of "entity_classification" or - "entity_pair_classification". If you specify this argument, the entity sequence is automatically created - based on the given entity spans. 
+ Task for which you want to prepare sequences. One of obj:`"entity_classification"` or + obj:`"entity_pair_classification"`. If you specify this argument, the entity sequence is automatically + created based on the given entity spans. max_entity_length (:obj:`int`, `optional`, defaults to 32): The maximum length of :obj:`entity_ids`. max_mention_length (:obj:`int`, `optional`, defaults to 30): The maximum number of tokens inside an entity span. entity_token_1 (:obj:`str`, `optional`, defaults to :obj:``): The special token representing an entity span. This token is only used when ``task`` is set to - "entity_classification" or "entity_pair_classification". + obj:`"entity_classification"` or obj:`"entity_pair_classification"`. entity_token_2 (:obj:`str`, `optional`, defaults to :obj:``): The special token representing an entity span. This token is only used when ``task`` is set to - "entity_pair_classification". + obj:`"entity_pair_classification"`. """ vocab_files_names = VOCAB_FILES_NAMES @@ -131,7 +205,9 @@ def __init__( if isinstance(entity_token_2, str) else entity_token_2 ) - kwargs["additional_special_tokens"] = [entity_token_1, entity_token_2] + kwargs.get("additional_special_tokens", []) + kwargs["additional_special_tokens"] = [entity_token_1, entity_token_2] + kwargs.get( + "additional_special_tokens", [] + ) super().__init__( vocab_file=vocab_file, @@ -155,7 +231,9 @@ def __init__( elif task == "entity_pair_classification": self.max_entity_length = 2 else: - raise ValueError(f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification'] only.") + raise ValueError( + f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification'] only." + ) self.max_mention_length = max_mention_length @@ -164,10 +242,10 @@ def __call__( self, text: Union[TextInput, List[TextInput]], text_pair: Optional[Union[TextInput, List[TextInput]]] = None, - entity_spans: Optional[Union[List[Tuple[int, int]], List[List[Tuple[int, int]]]]] = None, - entity_spans_pair: Optional[Union[List[Tuple[int, int]], List[List[Tuple[int, int]]]]] = None, - entities: Optional[Union[List[str], List[List[str]]]] = None, - entities_pair: Optional[Union[List[str], List[List[str]]]] = None, + entity_spans: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None, + entity_spans_pair: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None, + entities: Optional[Union[EntityInput, List[EntityInput]]] = None, + entities_pair: Optional[Union[EntityInput, List[EntityInput]]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, @@ -197,43 +275,47 @@ def __call__( text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this tokenizer does not support tokenization based on pretokenized strings. - entity_spans (:obj:`List[Tuple]`, :obj:`List[List[Tuple]]`, `optional`): + entity_spans (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`): The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each - with two integers denoting start and end positions of entities. If you specify "entity_classification" - or "entity_pair_classification" as the `task` argument in the constructor, the length of each sequence - must be 1 or 2, respectively. 
If you specify `entities`, the length of each sequence must be equal to - the length of each sequence of `entities`. - entity_spans_pair (:obj:`List[Tuple]`, :obj:`List[List[Tuple]]`, `optional`): + with two integers denoting start and end positions of entities. If you specify + obj:`"entity_classification"` or obj:`"entity_pair_classification"` as the ``task`` argument in the + constructor, the length of each sequence must be 1 or 2, respectively. If you specify ``entities``, the + length of each sequence must be equal to the length of each sequence of ``entities``. + entity_spans_pair (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`): The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each - with two integers denoting start and end positions of entities. If you specify the `task` argument in - the constructor, this argument is ignored. If you specify `entities_pair`, the length of each sequence - must be equal to the length of each sequence of `entities_pair`. + with two integers denoting start and end positions of entities. If you specify the ``task`` argument in + the constructor, this argument is ignored. If you specify ``entities_pair``, the length of each + sequence must be equal to the length of each sequence of ``entities_pair``. entities (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New - York). This argument is ignored if you specify the `task` argument in the constructor. The length of - each sequence must be equal to the length of each sequence of `entity_spans`. If you specify - `entity_spans` without specifying this argument, the entity sequence or the batch of entity sequences + York). This argument is ignored if you specify the ``task`` argument in the constructor. The length of + each sequence must be equal to the length of each sequence of ``entity_spans``. If you specify + ``entity_spans`` without specifying this argument, the entity sequence or the batch of entity sequences is automatically constructed by filling it with the [MASK] special entities. entities_pair (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New - York). This argument is ignored if you specify the `task` argument in the constructor. The length of - each sequence must be equal to the length of each sequence of `entity_spans_pair`. If you specify - `entity_spans_pair` without specifying this argument, the entity sequence or the batch of entity + York). This argument is ignored if you specify the ``task`` argument in the constructor. The length of + each sequence must be equal to the length of each sequence of ``entity_spans_pair``. If you specify + ``entity_spans_pair`` without specifying this argument, the entity sequence or the batch of entity sequences is automatically constructed by filling it with the [MASK] special entities. max_entity_length (:obj:`int`, `optional`): The maximum length of :obj:`entity_ids`. 
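                Example (a hedged sketch; the checkpoint name and character spans are illustrative only)::

                    >>> from transformers import LukeTokenizer
                    >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
                    >>> text = "Beyoncé lives in New York."
                    >>> entity_spans = [(0, 7), (17, 25)]  # character-based spans for "Beyoncé" and "New York"
                    >>> # ``entities`` is omitted here, so both spans are paired with the [MASK] special entity
                    >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
                    >>> # the encoding carries entity_ids, entity_position_ids and entity_attention_mask
                    >>> # alongside the usual input_ids and attention_mask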
""" # Input type checking for clearer error - assert isinstance(text, str) or ( - isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) + is_valid_single_text = isinstance(text, str) + is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) + assert ( + is_valid_single_text or is_valid_batch_text ), "text input must be of type `str` (single example) or `List[str]` (batch)." + is_valid_single_text_pair = isinstance(text_pair, str) + is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and ( + len(text_pair) == 0 or isinstance(text_pair[0], str) + ) assert ( - text_pair is None - or isinstance(text_pair, str) - or (isinstance(text_pair, (list, tuple)) and (len(text_pair) == 0 or isinstance(text_pair[0], str))) + text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair ), "text_pair input must be of type `str` (single example) or `List[str]` (batch)." is_batched = bool(isinstance(text, (list, tuple))) @@ -308,10 +390,10 @@ def encode_plus( self, text: Union[TextInput], text_pair: Optional[Union[TextInput]] = None, - entity_spans: Optional[List[Tuple[int, int]]] = None, - entity_spans_pair: Optional[List[Tuple[int, int]]] = None, - entities: Optional[List[str]] = None, - entities_pair: Optional[List[str]] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, @@ -331,8 +413,8 @@ def encode_plus( **kwargs ) -> BatchEncoding: """ - Tokenize and prepare for the model a sequence or a pair of sequences. - + Tokenize and prepare for the model a sequence or a pair of sequences. + .. warning:: This method is deprecated, ``__call__`` should be used instead. Args: @@ -340,29 +422,30 @@ def encode_plus( The first sequence to be encoded. Each sequence must be a string. text_pair (:obj:`str`): The second sequence to be encoded. Each sequence must be a string. - entity_spans (:obj:`List[Tuple]`, obj:`List[List[Tuple]]`, `optional`):: + entity_spans (:obj:`List[Tuple[int, int]]`, obj:`List[List[Tuple[int, int]]]`, `optional`):: The first sequence of entity spans to be encoded. The sequence consists of tuples each with two - integers denoting start and end positions of entities. If you specify "entity_classification" or - "entity_pair_classification" as the `task` argument in the constructor, the length of each sequence - must be 1 or 2, respectively. If you specify `entities`, the length of the sequence must be equal to - the length of `entities`. - entity_spans_pair (:obj:`List[Tuple]`, obj:`List[List[Tuple]]`, `optional`):: + integers denoting start and end positions of entities. If you specify obj:`"entity_classification"` or + obj:`"entity_pair_classification"` as the ``task`` argument in the constructor, the length of each + sequence must be 1 or 2, respectively. If you specify ``entities``, the length of the sequence must be + equal to the length of ``entities``. + entity_spans_pair (:obj:`List[Tuple[int, int]]`, obj:`List[List[Tuple[int, int]]]`, `optional`):: The second sequence of entity spans to be encoded. The sequence consists of tuples each with two - integers denoting start and end positions of entities. If you specify the `task` argument in the - constructor, this argument is ignored. 
If you specify `entities_pair`, the length of the sequence must - be equal to the length of `entities_pair`. + integers denoting start and end positions of entities. If you specify the ``task`` argument in the + constructor, this argument is ignored. If you specify ``entities_pair``, the length of the sequence + must be equal to the length of ``entities_pair``. entities (:obj:`List[str]` `optional`):: The first sequence of entities to be encoded. The sequence consists of strings representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New York). This argument is - ignored if you specify the `task` argument in the constructor. The length of the sequence must be equal - to the length of `entity_spans`. If you specify `entity_spans` without specifying this argument, the - entity sequence is automatically constructed by filling it with the [MASK] special entities. + ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be + equal to the length of ``entity_spans``. If you specify ``entity_spans`` without specifying this + argument, the entity sequence is automatically constructed by filling it with the [MASK] special + entities. entities_pair (:obj:`List[str]`, obj:`List[List[str]]`, `optional`):: The second sequence of entities to be encoded. The sequence consists of strings representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New York). This argument is - ignored if you specify the `task` argument in the constructor. The length of the sequence must be equal - to the length of `entity_spans_pair`. If you specify `entity_spans_pair` without specifying this - argument, the entity sequence is automatically constructed by filling it with the [MASK] special + ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be + equal to the length of ``entity_spans_pair``. If you specify ``entity_spans_pair`` without specifying + this argument, the entity sequence is automatically constructed by filling it with the [MASK] special entities. max_entity_length (:obj:`int`, `optional`): The maximum length of the entity sequence. 
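                Example (a hedged sketch of the task-specific usage; the checkpoint name and spans are illustrative only)::

                    >>> from transformers import LukeTokenizer
                    >>> # with ``task="entity_pair_classification"`` exactly two entity spans are expected, and the
                    >>> # entity sequence is created automatically from the two special entity tokens
                    >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")
                    >>> text = "Beyoncé lives in New York."
                    >>> entity_spans = [(0, 7), (17, 25)]  # head and tail entity mentions
                    >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")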
@@ -407,10 +490,10 @@ def _encode_plus( self, text: Union[TextInput], text_pair: Optional[Union[TextInput]] = None, - entity_spans: Optional[List[Tuple[int, int]]] = None, - entity_spans_pair: Optional[List[Tuple[int, int]]] = None, - entities: Optional[List[str]] = None, - entities_pair: Optional[List[str]] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, @@ -489,9 +572,11 @@ def batch_encode_plus( self, batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], batch_entity_spans_or_entity_spans_pairs: Optional[ - Union[List[List[Tuple[int, int]]], List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]] + Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]] ] = None, - batch_entities_or_entities_pairs: Optional[Union[List[List[str]], List[Tuple[List[str], List[str]]]]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, @@ -569,9 +654,11 @@ def _batch_encode_plus( self, batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], batch_entity_spans_or_entity_spans_pairs: Optional[ - Union[List[List[Tuple[int, int]]], List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]] + Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]] ] = None, - batch_entities_or_entities_pairs: Optional[Union[List[List[str]], List[Tuple[List[str], List[str]]]]] = None, add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, @@ -674,10 +761,10 @@ def _create_input_sequence( self, text: Union[TextInput], text_pair: Optional[Union[TextInput]] = None, - entities: Optional[List[str]] = None, - entities_pair: Optional[List[str]] = None, - entity_spans: Optional[List[Tuple[int, int]]] = None, - entity_spans_pair: Optional[List[Tuple[int, int]]] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, **kwargs ) -> Tuple[list, list, list, list, list, list]: def get_input_ids(text): @@ -1055,8 +1142,7 @@ def prepare_for_model( if num_invalid_entities != 0: logger.warning( - "%d entities are ignored because their entity spans are invalid due to the truncation of input tokens", - num_invalid_entities, + f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens" ) if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length: diff --git a/tests/test_tokenization_luke.py b/tests/test_tokenization_luke.py index 1146cb4f3299..32e2abf076e9 100644 --- a/tests/test_tokenization_luke.py +++ b/tests/test_tokenization_luke.py @@ -190,107 +190,16 @@ def test_single_text_no_padding_or_truncation(self): ) 
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1]) self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0]) + # fmt: off self.assertEqual( encoding["entity_position_ids"], [ - [ - 3, - 4, - 5, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 9, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - ], + [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + ] ) + # fmt: on def test_single_text_only_entity_spans_no_padding_or_truncation(self): tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) @@ -315,107 +224,16 @@ def test_single_text_only_entity_spans_no_padding_or_truncation(self): self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id]) self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1]) self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0]) + # fmt: off self.assertEqual( encoding["entity_position_ids"], [ - [ - 3, - 4, - 5, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 9, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - ], + [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, ], + [9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, ] + ] ) + # fmt: on def test_single_text_padding_pytorch_tensors(self): tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) @@ -486,107 +304,16 @@ def test_text_pair_no_padding_or_truncation(self): ) self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1]) self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0]) + # fmt: off self.assertEqual( encoding["entity_position_ids"], [ - [ - 3, - 4, - 5, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 11, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, 
- -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - ], + [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + ] ) + # fmt: on def test_text_pair_only_entity_spans_no_padding_or_truncation(self): tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) @@ -619,107 +346,16 @@ def test_text_pair_only_entity_spans_no_padding_or_truncation(self): self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id]) self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1]) self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0]) + # fmt: off self.assertEqual( encoding["entity_position_ids"], [ - [ - 3, - 4, - 5, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 11, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - ], + [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + ] ) + # fmt: on def test_text_pair_padding_pytorch_tensors(self): tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) @@ -778,43 +414,14 @@ def test_entity_classification_no_padding_or_truncation(self): self.assertEqual(encoding["entity_ids"], [2]) self.assertEqual(encoding["entity_attention_mask"], [1]) self.assertEqual(encoding["entity_token_type_ids"], [0]) + # fmt: off self.assertEqual( encoding["entity_position_ids"], [ - [ - 9, - 10, - 11, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ], + [9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] + ] ) + # fmt: on def test_entity_classification_padding_pytorch_tensors(self): tokenizer = LukeTokenizer.from_pretrained( @@ -866,75 +473,15 @@ def test_entity_pair_classification_no_padding_or_truncation(self): self.assertEqual(encoding["entity_ids"], [2, 3]) self.assertEqual(encoding["entity_attention_mask"], [1, 1]) self.assertEqual(encoding["entity_token_type_ids"], [0, 0]) + # fmt: off self.assertEqual( encoding["entity_position_ids"], [ - [ - 3, - 4, - 5, - 6, - 7, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - [ - 11, - 12, - 13, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ], - ], + [3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + ] ) + # fmt: on def test_entity_pair_classification_padding_pytorch_tensors(self): tokenizer = LukeTokenizer.from_pretrained(