diff --git a/README.md b/README.md index 1e7ced945aa84a..79d236f8bbfa77 100644 --- a/README.md +++ b/README.md @@ -220,6 +220,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](https://huggingface.co/transformers/model_doc/longformer.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. +1. **[LUKE](https://huggingface.co/transformers/model_doc/luke.html)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. 1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](https://huggingface.co/transformers/model_doc/m2m_100.html)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
diff --git a/docs/source/community.md b/docs/source/community.md index e1b467863df15e..8ac15f4c889468 100644 --- a/docs/source/community.md +++ b/docs/source/community.md @@ -52,3 +52,6 @@ This page regroups resources around 🤗 Transformers developed by the community |[Fine-tune BART for summarization in two languages with Trainer class](https://github.com/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb) | How to fine-tune BART for summarization in two languages with Trainer class | [Eliza Szczechla](https://github.com/elsanns) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb)| |[Evaluate Big Bird on Trivia QA](https://github.com/patrickvonplaten/notebooks/blob/master/Evaluating_Big_Bird_on_TriviaQA.ipynb) | How to evaluate BigBird on long document question answering on Trivia QA | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Evaluating_Big_Bird_on_TriviaQA.ipynb)| | [Create video captions using Wav2Vec2](https://github.com/Muennighoff/ytclipcc/blob/main/wav2vec_youtube_captions.ipynb) | How to create YouTube captions from any video by transcribing the audio with Wav2Vec | [Niklas Muennighoff](https://github.com/Muennighoff) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Muennighoff/ytclipcc/blob/main/wav2vec_youtube_captions.ipynb) | +| [Evaluate LUKE on Open Entity, an entity typing dataset](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb) | How to evaluate *LukeForEntityClassification* on the Open Entity dataset | [Ikuya Yamada](https://github.com/ikuyamada) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb) | +| [Evaluate LUKE on TACRED, a relation extraction dataset](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_tacred.ipynb) | How to evaluate *LukeForEntityPairClassification* on the TACRED dataset | [Ikuya Yamada](https://github.com/ikuyamada) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_tacred.ipynb) | +| [Evaluate LUKE on CoNLL-2003, an important NER benchmark](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_conll_2003.ipynb) | How to evaluate *LukeForEntitySpanClassification* on the CoNLL-2003 dataset | [Ikuya Yamada](https://github.com/ikuyamada) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_conll_2003.ipynb) | diff --git a/docs/source/index.rst b/docs/source/index.rst index 25a2a380431e7a..b83a5f8e49ff6e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -170,80 +170,83 @@ and conversion utilities for the following models: <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. 29. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -30. 
:doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +30. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity + Representations with Entity-aware Self-attention <https://arxiv.org/abs/2010.01057>`__ by Ikuya Yamada, Akari Asai, + Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. +31. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__ by Hao Tan and Mohit Bansal. -31. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual +32. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation <https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -32. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by +33. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft Translator Team. -33. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +34. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -34. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible +35. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -35. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training +36. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -36. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training +37. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -37. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +38. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -38. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +39. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -39. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +40. 
:doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -40. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting +41. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -41. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient +42. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -42. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT +43. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -43. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper +44. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -44. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP +45. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -45. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +46. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -46. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +47. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -47. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL: +48. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -48. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16 +49. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -49. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +50. 
:doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -50. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model +51. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau. -51. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet: +52. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -52. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised +53. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -53. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive +54. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -54. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised +55. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -308,6 +311,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | LED | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| LUKE | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | LayoutLM | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -468,6 +473,7 @@ TensorFlow and/or Flax. model_doc/layoutlm model_doc/led model_doc/longformer + model_doc/luke model_doc/lxmert model_doc/marian model_doc/m2m_100 diff --git a/docs/source/model_doc/luke.rst b/docs/source/model_doc/luke.rst new file mode 100644 index 00000000000000..34af117de98aa1 --- /dev/null +++ b/docs/source/model_doc/luke.rst @@ -0,0 +1,159 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. 
+ +LUKE +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The LUKE model was proposed in `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention +<https://arxiv.org/abs/2010.01057>`_ by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda and Yuji Matsumoto. +It is based on RoBERTa and adds entity embeddings as well as an entity-aware self-attention mechanism, which helps +improve performance on various downstream tasks involving reasoning about entities such as named entity recognition, +extractive and cloze-style question answering, entity typing, and relation classification. + +The abstract from the paper is the following: + +*Entity representations are useful in natural language tasks involving entities. In this paper, we propose new +pretrained contextualized representations of words and entities based on the bidirectional transformer. The proposed +model treats words and entities in a given text as independent tokens, and outputs contextualized representations of +them. Our model is trained using a new pretraining task based on the masked language model of BERT. The task involves +predicting randomly masked words and entities in a large entity-annotated corpus retrieved from Wikipedia. We also +propose an entity-aware self-attention mechanism that is an extension of the self-attention mechanism of the +transformer, and considers the types of tokens (words or entities) when computing attention scores. The proposed model +achieves impressive empirical performance on a wide range of entity-related tasks. In particular, it obtains +state-of-the-art results on five well-known datasets: Open Entity (entity typing), TACRED (relation classification), +CoNLL-2003 (named entity recognition), ReCoRD (cloze-style question answering), and SQuAD 1.1 (extractive question +answering).* + +Tips: + +- This implementation is the same as :class:`~transformers.RobertaModel` with the addition of entity embeddings as well + as an entity-aware self-attention mechanism, which improves performance on tasks involving reasoning about entities. +- LUKE treats entities as input tokens; therefore, it takes :obj:`entity_ids`, :obj:`entity_attention_mask`, + :obj:`entity_token_type_ids` and :obj:`entity_position_ids` as extra input. You can obtain those using + :class:`~transformers.LukeTokenizer`. +- :class:`~transformers.LukeTokenizer` takes :obj:`entities` and :obj:`entity_spans` (character-based start and end + positions of the entities in the input text) as extra input. :obj:`entities` typically consist of [MASK] entities or + Wikipedia entities. The two use cases are briefly described as follows: + + - *Inputting [MASK] entities to compute entity representations*: The [MASK] entity is used to mask entities to be + predicted during pretraining. When LUKE receives the [MASK] entity, it tries to predict the original entity by + gathering information about the entity from the input text. Therefore, the [MASK] entity can be used to address + downstream tasks that require entity information from the text, such as entity typing, relation classification, and + named entity recognition. 
+ - *Inputting Wikipedia entities to compute knowledge-enhanced token representations*: LUKE learns rich information + (or knowledge) about Wikipedia entities during pretraining and stores the information in its entity embeddings. By + using Wikipedia entities as input tokens, LUKE outputs token representations enriched by the information stored in + the embeddings of these entities. This is particularly effective for tasks requiring real-world knowledge, such as + question answering. + +- There are three head models for the former use case ([MASK] entities): + + - :class:`~transformers.LukeForEntityClassification`, for tasks that classify a single entity in an input text, such as + entity typing, e.g. the `Open Entity dataset <https://www.cs.utexas.edu/~eunsol/html_pages/open_entity.html>`__. + This model places a linear head on top of the output entity representation. + - :class:`~transformers.LukeForEntityPairClassification`, for tasks that classify the relationship between two entities, + such as relation classification, e.g. the `TACRED dataset <https://nlp.stanford.edu/projects/tacred/>`__. This + model places a linear head on top of the concatenated output representation of the pair of given entities. + - :class:`~transformers.LukeForEntitySpanClassification`, for tasks that classify a sequence of entity spans, such as + named entity recognition (NER). This model places a linear head on top of the output entity representations. You + can address NER using this model by inputting all possible entity spans in the text to the model. + + :class:`~transformers.LukeTokenizer` has a ``task`` argument, which lets you easily create inputs for these + head models by specifying ``task="entity_classification"``, ``task="entity_pair_classification"``, or + ``task="entity_span_classification"``. Please refer to the example code of each head model. + + There are also three notebooks available, which showcase how you can reproduce the results reported in the paper with + the HuggingFace implementation of LUKE. They can be found `here + <https://github.com/studio-ousia/luke/tree/master/notebooks>`__. + +Example: + +.. code-block:: + + >>> from transformers import LukeTokenizer, LukeModel, LukeForEntityPairClassification + + >>> model = LukeModel.from_pretrained("studio-ousia/luke-base") + >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base") + + # Example 1: Computing the contextualized entity representation corresponding to the entity mention "Beyoncé" + >>> text = "Beyoncé lives in Los Angeles."
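+ # note: entity spans are character-based and end-exclusive, e.g. text[0:7] == "Beyoncé"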
+ >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé" + >>> inputs = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**inputs) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + + # Example 2: Inputting Wikipedia entities to obtain enriched contextualized representations + >>> entities = ["Beyoncé", "Los Angeles"] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles" + >>> entity_spans = [(0, 7), (17, 28)] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles" + >>> inputs = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**inputs) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + + # Example 3: Classifying the relationship between two entities using the LukeForEntityPairClassification head model + >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred") + >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred") + >>> entity_spans = [(0, 7), (17, 28)] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles" + >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> predicted_class_idx = int(logits[0].argmax()) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + +This model was contributed by `ikuyamada <https://huggingface.co/ikuyamada>`__ and `nielsr +<https://huggingface.co/nielsr>`__. The original code can be found `here <https://github.com/studio-ousia/luke>`__. + + +LukeConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LukeConfig + :members: + + +LukeTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LukeTokenizer + :members: __call__, save_vocabulary + + +LukeModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LukeModel + :members: forward + + +LukeForEntityClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LukeForEntityClassification + :members: forward + + +LukeForEntityPairClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LukeForEntityPairClassification + :members: forward + + +LukeForEntitySpanClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.LukeForEntitySpanClassification + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3e72488be2ab44..2a60516c803f9d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -189,6 +189,7 @@ "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"], "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"], "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"], + "models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"], "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"], "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], "models.marian": ["MarianConfig"], @@ -442,8 +443,8 @@ ] _import_structure["generation_utils"] = ["top_k_top_p_filtering"] _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] - # PyTorch models structure + # PyTorch models structure _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -751,6 +752,16 @@ "LongformerSelfAttention", ] ) + _import_structure["models.luke"].extend( + [ + "LUKE_PRETRAINED_MODEL_ARCHIVE_LIST", + "LukeForEntityClassification", + "LukeForEntityPairClassification", + "LukeForEntitySpanClassification", + "LukeModel", + "LukePreTrainedModel", + ] + ) _import_structure["models.lxmert"].extend( [ "LxmertEncoder", @@ -1540,6 +1551,7 @@ from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer + from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config from .models.marian import MarianConfig @@ -2020,6 +2032,14 @@ LongformerModel, LongformerSelfAttention, ) + from .models.luke import ( + LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, + LukeForEntityClassification, + LukeForEntityPairClassification, + LukeForEntitySpanClassification, + LukeModel, + LukePreTrainedModel, + ) from .models.lxmert import ( LxmertEncoder, LxmertForPreTraining, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 54f1e1021781da..b1ee27e7257a1b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -48,6 +48,7 @@ layoutlm, led, longformer, + luke, lxmert, m2m_100, marian, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 08003a90780432..f343348a7c7cd1 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -47,6 +47,7 @@ from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig +from ..luke.configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig from ..lxmert.configuration_lxmert import 
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config from ..marian.configuration_marian import MarianConfig @@ -86,6 +87,7 @@ for pretrained_map in [ # Add archive maps here DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, + LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -138,6 +140,7 @@ [ # Add configs here ("deit", DeiTConfig), + ("luke", LukeConfig), ("gpt_neo", GPTNeoConfig), ("big_bird", BigBirdConfig), ("speech_to_text", Speech2TextConfig), @@ -196,6 +199,7 @@ [ # Add full (and cased) model names here ("deit", "DeiT"), + ("luke", "LUKE"), ("gpt_neo", "GPT Neo"), ("big_bird", "BigBird"), ("speech_to_text", "Speech2Text"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index eca851bbfd7834..22028d173bdf03 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -166,6 +166,7 @@ LongformerForTokenClassification, LongformerModel, ) +from ..luke.modeling_luke import LukeModel from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel @@ -308,6 +309,7 @@ LayoutLMConfig, LEDConfig, LongformerConfig, + LukeConfig, LxmertConfig, M2M100Config, MarianConfig, @@ -343,6 +345,7 @@ [ # Base model mapping (DeiTConfig, DeiTModel), + (LukeConfig, LukeModel), (GPTNeoConfig, GPTNeoModel), (BigBirdConfig, BigBirdModel), (Speech2TextConfig, Speech2TextModel), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 13089e21171c0e..7b13e64b68a2e1 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -41,6 +41,7 @@ from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer from ..led.tokenization_led import LEDTokenizer from ..longformer.tokenization_longformer import LongformerTokenizer +from ..luke.tokenization_luke import LukeTokenizer from ..lxmert.tokenization_lxmert import LxmertTokenizer from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer from ..mpnet.tokenization_mpnet import MPNetTokenizer @@ -81,6 +82,7 @@ LayoutLMConfig, LEDConfig, LongformerConfig, + LukeConfig, LxmertConfig, M2M100Config, MarianConfig, @@ -232,7 +234,6 @@ (MarianConfig, (MarianTokenizer, None)), (BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)), (BlenderbotConfig, (BlenderbotTokenizer, None)), - (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)), (BartConfig, (BartTokenizer, BartTokenizerFast)), (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)), (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)), @@ -268,6 +269,7 @@ (IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)), (Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)), (GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)), + (LukeConfig, (LukeTokenizer, None)), ] ) diff --git a/src/transformers/models/luke/__init__.py b/src/transformers/models/luke/__init__.py new file mode 100644 index 00000000000000..4f5f3155581ab6 --- /dev/null +++ b/src/transformers/models/luke/__init__.py @@ -0,0 +1,70 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' 
imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available + + +_import_structure = { + "configuration_luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig"], + "tokenization_luke": ["LukeTokenizer"], +} + +if is_torch_available(): + _import_structure["modeling_luke"] = [ + "LUKE_PRETRAINED_MODEL_ARCHIVE_LIST", + "LukeForEntityClassification", + "LukeForEntityPairClassification", + "LukeForEntitySpanClassification", + "LukeModel", + "LukePreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig + from .tokenization_luke import LukeTokenizer + + if is_torch_available(): + from .modeling_luke import ( + LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, + LukeForEntityClassification, + LukeForEntityPairClassification, + LukeForEntitySpanClassification, + LukeModel, + LukePreTrainedModel, + ) + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py new file mode 100644 index 00000000000000..befd3e45e5de65 --- /dev/null +++ b/src/transformers/models/luke/configuration_luke.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright Studio Ousia and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" LUKE configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "studio-ousia/luke-base": "https://huggingface.co/studio-ousia/luke-base/resolve/main/config.json", + "studio-ousia/luke-large": "https://huggingface.co/studio-ousia/luke-large/resolve/main/config.json", +} + + +class LukeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.LukeModel`. 
It is used to + instantiate a LUKE model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50267): + Vocabulary size of the LUKE model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.LukeModel`. + entity_vocab_size (:obj:`int`, `optional`, defaults to 500000): + Entity vocabulary size of the LUKE model. Defines the number of different entities that can be represented + by the :obj:`entity_ids` passed when calling :class:`~transformers.LukeModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + entity_emb_size (:obj:`int`, `optional`, defaults to 256): + The number of dimensions of the entity embedding. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.LukeModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + use_entity_aware_attention (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep + Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.) + <https://arxiv.org/abs/2010.01057>`__. 
+ + Examples:: + + >>> from transformers import LukeConfig, LukeModel + + >>> # Initializing a LUKE configuration + >>> configuration = LukeConfig() + + >>> # Initializing a model from the configuration + >>> model = LukeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "luke" + + def __init__( + self, + vocab_size=50267, + entity_vocab_size=500000, + hidden_size=768, + entity_emb_size=256, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + gradient_checkpointing=False, + use_entity_aware_attention=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs + ): + """Constructs LukeConfig.""" + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.entity_vocab_size = entity_vocab_size + self.hidden_size = hidden_size + self.entity_emb_size = entity_emb_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.gradient_checkpointing = gradient_checkpointing + self.use_entity_aware_attention = use_entity_aware_attention diff --git a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000000..55e2aab4130ba0 --- /dev/null +++ b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,153 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
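+# Hypothetical example invocation (the paths below are placeholders, not shipped files): +# python convert_luke_original_pytorch_checkpoint_to_pytorch.py \ +#     --checkpoint_path luke_base/pytorch_model.bin \ +#     --metadata_path luke_base/metadata.json \ +#     --entity_vocab_path luke_base/entity_vocab.tsv \ +#     --pytorch_dump_folder_path converted_luke_base \ +#     --model_size base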
+"""Convert LUKE checkpoint.""" + +import argparse +import json +import os + +import torch + +from transformers import LukeConfig, LukeModel, LukeTokenizer, RobertaTokenizer +from transformers.tokenization_utils_base import AddedToken + + +@torch.no_grad() +def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size): + # Load configuration defined in the metadata file + with open(metadata_path) as metadata_file: + metadata = json.load(metadata_file) + config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"]) + + # Load in the weights from the checkpoint_path + state_dict = torch.load(checkpoint_path, map_location="cpu") + + # Load the entity vocab file + entity_vocab = load_entity_vocab(entity_vocab_path) + + tokenizer = RobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"]) + + # Add special tokens to the token vocabulary for downstream tasks + entity_token_1 = AddedToken("", lstrip=False, rstrip=False) + entity_token_2 = AddedToken("", lstrip=False, rstrip=False) + tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2])) + config.vocab_size += 2 + + print(f"Saving tokenizer to {pytorch_dump_folder_path}") + tokenizer.save_pretrained(pytorch_dump_folder_path) + with open(os.path.join(pytorch_dump_folder_path, LukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f: + json.dump(entity_vocab, f) + + tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path) + + # Initialize the embeddings of the special tokens + word_emb = state_dict["embeddings.word_embeddings.weight"] + ent_emb = word_emb[tokenizer.convert_tokens_to_ids(["@"])[0]].unsqueeze(0) + ent2_emb = word_emb[tokenizer.convert_tokens_to_ids(["#"])[0]].unsqueeze(0) + state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb]) + + # Initialize the query layers of the entity-aware self-attention mechanism + for layer_index in range(config.num_hidden_layers): + for matrix_name in ["query.weight", "query.bias"]: + prefix = f"encoder.layer.{layer_index}.attention.self." + state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name] + state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name] + state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name] + + # Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks + entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"] + entity_emb[entity_vocab["[MASK2]"]] = entity_emb[entity_vocab["[MASK]"]] + + model = LukeModel(config=config).eval() + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids" + assert all(key.startswith("entity_predictions") or key.startswith("lm_head") for key in unexpected_keys) + + # Check outputs + tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification") + + text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ." 
+ span = (39, 42) + encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt") + + outputs = model(**encoding) + + # Verify word hidden states + if model_size == "large": + expected_shape = torch.Size((1, 42, 1024)) + expected_slice = torch.tensor( + [[0.0133, 0.0865, 0.0095], [0.3093, -0.2576, -0.7418], [-0.1720, -0.2117, -0.2869]] + ) + else: # base + expected_shape = torch.Size((1, 42, 768)) + expected_slice = torch.tensor([[0.0037, 0.1368, -0.0091], [0.1099, 0.3329, -0.1095], [0.0765, 0.5335, 0.1179]]) + + assert outputs.last_hidden_state.shape == expected_shape + assert torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + + # Verify entity hidden states + if model_size == "large": + expected_shape = torch.Size((1, 1, 1024)) + expected_slice = torch.tensor([[0.0466, -0.0106, -0.0179]]) + else: # base + expected_shape = torch.Size((1, 1, 768)) + expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]]) + + assert outputs.entity_last_hidden_state.shape == expected_shape + assert torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + + # Finally, save our PyTorch model and tokenizer + print("Saving PyTorch model to {}".format(pytorch_dump_folder_path)) + model.save_pretrained(pytorch_dump_folder_path) + + +def load_entity_vocab(entity_vocab_path): + entity_vocab = {} + with open(entity_vocab_path, "r", encoding="utf-8") as f: + for (index, line) in enumerate(f): + title, _ = line.rstrip().split("\t") + entity_vocab[title] = index + + return entity_vocab + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--checkpoint_path", type=str, help="Path to a pytorch_model.bin file.") + parser.add_argument( + "--metadata_path", default=None, type=str, help="Path to a metadata.json file, defining the configuration." + ) + parser.add_argument( + "--entity_vocab_path", + default=None, + type=str, + help="Path to an entity_vocab.tsv file, containing the entity vocabulary.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to where to dump the output PyTorch model." + ) + parser.add_argument( + "--model_size", default="base", type=str, choices=["base", "large"], help="Size of the model to be converted." + ) + args = parser.parse_args() + convert_luke_checkpoint( + args.checkpoint_path, + args.metadata_path, + args.entity_vocab_path, + args.pytorch_dump_folder_path, + args.model_size, + ) diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py new file mode 100644 index 00000000000000..6db7bd62788aeb --- /dev/null +++ b/src/transformers/models/luke/modeling_luke.py @@ -0,0 +1,1367 @@ +# coding=utf-8 +# Copyright Studio Ousia and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LUKE model. 
""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward +from ...utils import logging +from .configuration_luke import LukeConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LukeConfig" +_TOKENIZER_FOR_DOC = "LukeTokenizer" + +LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "studio-ousia/luke-base", + "studio-ousia/luke-large", + # See all LUKE models at https://huggingface.co/models?filter=luke +] + + +@dataclass +class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling): + """ + Base class for outputs of the LUKE model. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + entity_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length, hidden_size)`): + Sequence of entity hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of + each layer plus the initial embedding outputs. + entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output + of each layer plus the initial entity embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length + entity_length, sequence_length + entity_length)`. Attentions weights after the attention + softmax, used to compute the weighted average in the self-attention heads. + """ + + entity_last_hidden_state: torch.FloatTensor = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BaseLukeModelOutput(BaseModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ entity_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length, hidden_size)`): + Sequence of entity hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output + of each layer plus the initial entity embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + entity_last_hidden_state: torch.FloatTensor = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class EntityClassificationOutput(ModelOutput): + """ + Outputs of entity classification models. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of + each layer plus the initial embedding outputs. + entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output + of each layer plus the initial entity embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class EntityPairClassificationOutput(ModelOutput): + """ + Outputs of entity pair classification models. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of + each layer plus the initial embedding outputs. + entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output + of each layer plus the initial entity embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class EntitySpanClassificationOutput(ModelOutput): + """ + Outputs of entity span classification models. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of + each layer plus the initial embedding outputs. + entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output + of each layer plus the initial entity embedding outputs. 
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class LukeEmbeddings(nn.Module):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # End copy
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+        )
+
+    def forward(
+        self,
+        input_ids=None,
+        token_type_ids=None,
+        position_ids=None,
+        inputs_embeds=None,
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        if token_type_ids is None:
+            # position_ids is always defined at this point; this module does not register a position_ids buffer,
+            # so take the device from the freshly computed position_ids
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
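+        Position ids then run from :obj:`padding_idx + 1` to :obj:`sequence_length + padding_idx`, which matches what
+        :obj:`create_position_ids_from_input_ids` produces for inputs that contain no padding.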
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class LukeEntityEmbeddings(nn.Module):
+    def __init__(self, config: LukeConfig):
+        super().__init__()
+        self.config = config
+
+        self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0)
+        if config.entity_emb_size != config.hidden_size:
+            self.entity_embedding_dense = nn.Linear(config.entity_emb_size, config.hidden_size, bias=False)
+
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(
+        self, entity_ids: torch.LongTensor, position_ids: torch.LongTensor, token_type_ids: torch.LongTensor = None
+    ):
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(entity_ids)
+
+        entity_embeddings = self.entity_embeddings(entity_ids)
+        if self.config.entity_emb_size != self.config.hidden_size:
+            entity_embeddings = self.entity_embedding_dense(entity_embeddings)
+
+        # each entity is associated with the word positions it spans (padded with -1); its position embedding
+        # is the average of the position embeddings at those valid positions
+        position_embeddings = self.position_embeddings(position_ids.clamp(min=0))
+        position_embedding_mask = (position_ids != -1).type_as(position_embeddings).unsqueeze(-1)
+        position_embeddings = position_embeddings * position_embedding_mask
+        position_embeddings = torch.sum(position_embeddings, dim=-2)
+        position_embeddings = position_embeddings / position_embedding_mask.sum(dim=-2).clamp(min=1e-7)
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = entity_embeddings + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class LukeSelfAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.use_entity_aware_attention = config.use_entity_aware_attention
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        if self.use_entity_aware_attention:
+            self.w2e_query = nn.Linear(config.hidden_size, self.all_head_size)
+            self.e2w_query = nn.Linear(config.hidden_size, self.all_head_size)
+            self.e2e_query = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        word_hidden_states,
+        entity_hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=False,
+    ):
+        word_size = word_hidden_states.size(1)
+
+        if entity_hidden_states is None:
+            concat_hidden_states = word_hidden_states
+        else:
+            concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1)
+
+        key_layer = self.transpose_for_scores(self.key(concat_hidden_states))
+        value_layer = self.transpose_for_scores(self.value(concat_hidden_states))
+
+        if self.use_entity_aware_attention and entity_hidden_states is not None:
+            # compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e)
+            # query layers
+            w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states))
+            w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states))
+            e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states))
+            e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states))
+
+            # compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above
+            w2w_key_layer = key_layer[:, :, :word_size, :]
+            e2w_key_layer = key_layer[:, :, :word_size, :]
+            w2e_key_layer = key_layer[:, :, word_size:, :]
+            e2e_key_layer = key_layer[:, :, word_size:, :]
+
+            # compute attention scores based on the dot product between the query and key vectors
+            w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2))
+            w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2))
+            e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2))
+            e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2))
+
+            # combine attention scores to create the final attention score matrix
+            word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3)
+            entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3)
+            attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2)
+
+        else:
+            query_layer = self.transpose_for_scores(self.query(concat_hidden_states))
+            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the LukeModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to
probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + output_word_hidden_states = context_layer[:, :word_size, :] + if entity_hidden_states is None: + output_entity_hidden_states = None + else: + output_entity_hidden_states = context_layer[:, word_size:, :] + + if output_attentions: + outputs = (output_word_hidden_states, output_entity_hidden_states, attention_probs) + else: + outputs = (output_word_hidden_states, output_entity_hidden_states) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class LukeSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LukeAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = LukeSelfAttention(config) + self.output = LukeSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + raise NotImplementedError("LUKE does not support the pruning of attention heads") + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + word_size = word_hidden_states.size(1) + self_outputs = self.self( + word_hidden_states, + entity_hidden_states, + attention_mask, + head_mask, + output_attentions, + ) + if entity_hidden_states is None: + concat_self_outputs = self_outputs[0] + concat_hidden_states = word_hidden_states + else: + concat_self_outputs = torch.cat(self_outputs[:2], dim=1) + concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) + + attention_output = self.output(concat_self_outputs, concat_hidden_states) + + word_attention_output = attention_output[:, :word_size, :] + if entity_hidden_states is None: + entity_attention_output = None + else: + entity_attention_output = attention_output[:, word_size:, :] + + # add attentions if we output them + outputs = (word_attention_output, entity_attention_output) + self_outputs[2:] + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class LukeIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from 
transformers.models.bert.modeling_bert.BertOutput +class LukeOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LukeLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = LukeAttention(config) + self.intermediate = LukeIntermediate(config) + self.output = LukeOutput(config) + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + word_size = word_hidden_states.size(1) + + self_attention_outputs = self.attention( + word_hidden_states, + entity_hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + if entity_hidden_states is None: + concat_attention_output = self_attention_outputs[0] + else: + concat_attention_output = torch.cat(self_attention_outputs[:2], dim=1) + + outputs = self_attention_outputs[2:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, concat_attention_output + ) + word_layer_output = layer_output[:, :word_size, :] + if entity_hidden_states is None: + entity_layer_output = None + else: + entity_layer_output = layer_output[:, word_size:, :] + + outputs = (word_layer_output, entity_layer_output) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class LukeEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + word_hidden_states, + entity_hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_word_hidden_states = () if output_hidden_states else None + all_entity_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_word_hidden_states = all_word_hidden_states + (word_hidden_states,) + all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + word_hidden_states, + entity_hidden_states, + attention_mask, + layer_head_mask, + ) + else: + layer_outputs = layer_module( + word_hidden_states, + entity_hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + ) + + word_hidden_states = layer_outputs[0] + + if 
entity_hidden_states is not None:
+                entity_hidden_states = layer_outputs[1]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_word_hidden_states = all_word_hidden_states + (word_hidden_states,)
+            all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    word_hidden_states,
+                    all_word_hidden_states,
+                    all_self_attentions,
+                    entity_hidden_states,
+                    all_entity_hidden_states,
+                ]
+                if v is not None
+            )
+        return BaseLukeModelOutput(
+            last_hidden_state=word_hidden_states,
+            hidden_states=all_word_hidden_states,
+            attentions=all_self_attentions,
+            entity_last_hidden_state=entity_hidden_states,
+            entity_hidden_states=all_entity_hidden_states,
+        )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class LukePooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class LukePreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = LukeConfig
+    base_model_prefix = "luke"
+
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            if module.embedding_dim == 1:  # embedding for bias parameters
+                module.weight.data.zero_()
+            else:
+                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+                if module.padding_idx is not None:
+                    module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+LUKE_START_DOCSTRING = r"""
+
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+    pruning heads etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+    Parameters:
+        config (:class:`~transformers.LukeConfig`): Model configuration class with all the parameters of the
+            model. Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+            weights.
+"""
+
+LUKE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using :class:`~transformers.LukeTokenizer`. See
+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+            details.
+
+            `What are input IDs? 
<../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + + entity_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`): + Indices of entity tokens in the entity vocabulary. + + Indices can be obtained using :class:`~transformers.LukeTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + entity_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length)`, `optional`): + Mask to avoid performing attention on padding entity token indices. Mask values selected in ``[0, 1]``: + + - 1 for entity tokens that are **not masked**, + - 0 for entity tokens that are **masked**. + + entity_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`, `optional`): + Segment token indices to indicate first and second portions of the entity token inputs. Indices are + selected in ``[0, 1]``: + + - 0 corresponds to a `portion A` entity token, + - 1 corresponds to a `portion B` entity token. + + entity_position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length, max_mention_length)`, `optional`): + Indices of positions of each input entity in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
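+
+    Note that the entity inputs (:obj:`entity_ids`, :obj:`entity_attention_mask`, :obj:`entity_token_type_ids` and
+    :obj:`entity_position_ids`) can be created with :class:`~transformers.LukeTokenizer` by passing :obj:`entity_spans`,
+    as shown in the usage examples of the model classes below.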
+""" + + +@add_start_docstrings( + "The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any specific head on top.", + LUKE_START_DOCSTRING, +) +class LukeModel(LukePreTrainedModel): + + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = LukeEmbeddings(config) + self.entity_embeddings = LukeEntityEmbeddings(config) + self.encoder = LukeEncoder(config) + + self.pooler = LukePooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def get_entity_embeddings(self): + return self.entity_embeddings.entity_embeddings + + def set_entity_embeddings(self, value): + self.entity_embeddings.entity_embeddings = value + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError("LUKE does not support the pruning of attention heads") + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BaseLukeModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + entity_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, + entity_position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + + Returns: + + Examples:: + + >>> from transformers import LukeTokenizer, LukeModel + + >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base") + >>> model = LukeModel.from_pretrained("studio-ousia/luke-base") + + # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé" + >>> text = "Beyoncé lives in Los Angeles." + >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé" + + >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**encoding) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + + # Input Wikipedia entities to obtain enriched contextualized representations of word tokens + >>> text = "Beyoncé lives in Los Angeles." 
+ >>> entities = ["Beyoncé", "Los Angeles"] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles" + >>> entity_spans = [(0, 7), (17, 28)] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles" + + >>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt") + >>> outputs = model(**encoding) + >>> word_last_hidden_state = outputs.last_hidden_state + >>> entity_last_hidden_state = outputs.entity_last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + if entity_ids is not None: + entity_seq_length = entity_ids.size(1) + if entity_attention_mask is None: + entity_attention_mask = torch.ones((batch_size, entity_seq_length), device=device) + if entity_token_type_ids is None: + entity_token_type_ids = torch.zeros((batch_size, entity_seq_length), dtype=torch.long, device=device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # First, compute word embeddings + word_embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + + # Second, compute extended attention mask + extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask) + + # Third, compute entity embeddings and concatenate with word embeddings + if entity_ids is None: + entity_embedding_output = None + else: + entity_embedding_output = self.entity_embeddings(entity_ids, entity_position_ids, entity_token_type_ids) + + # Fourth, send embeddings through the model + encoder_outputs = self.encoder( + word_embedding_output, + entity_embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Fifth, get the output. 
LukeModel outputs the same as BertModel, namely sequence_output of shape (batch_size, seq_len, hidden_size)
+        sequence_output = encoder_outputs[0]
+
+        # Sixth, we compute the pooled_output based on the sequence_output
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseLukeModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            entity_last_hidden_state=encoder_outputs.entity_last_hidden_state,
+            entity_hidden_states=encoder_outputs.entity_hidden_states,
+        )
+
+    def get_extended_attention_mask(
+        self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]
+    ):
+        """
+        Makes a broadcastable attention mask so that padded and masked tokens are ignored.
+
+        Arguments:
+            word_attention_mask (:obj:`torch.LongTensor`):
+                Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
+            entity_attention_mask (:obj:`torch.LongTensor`, `optional`):
+                Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+        """
+        attention_mask = word_attention_mask
+        if entity_attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1)
+
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
+
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+        padding_idx: int
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
+    return incremental_indices.long() + padding_idx
+
+
+@add_start_docstrings(
+    """
+    The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
+    token) for entity classification tasks, such as Open Entity.
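+
+    The head is applied to the hidden state of the first entity token, i.e. the :obj:`[MASK]` entity that
+    :class:`~transformers.LukeTokenizer` creates for the provided entity span.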
+ """, + LUKE_START_DOCSTRING, +) +class LukeForEntityClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.num_labels = config.num_labels + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=EntityClassificationOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + entity_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, + entity_position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)` or :obj:`(batch_size, num_labels)`, `optional`): + Labels for computing the classification loss. If the shape is :obj:`(batch_size,)`, the cross entropy loss + is used for the single-label classification. In this case, labels should contain the indices that should be + in :obj:`[0, ..., config.num_labels - 1]`. If the shape is :obj:`(batch_size, num_labels)`, the binary + cross entropy loss is used for the multi-label classification. In this case, labels should only contain + ``[0, 1]``, where 0 and 1 indicate false and true, respectively. + + Returns: + + Examples:: + + >>> from transformers import LukeTokenizer, LukeForEntityClassification + + >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity") + >>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity") + + >>> text = "Beyoncé lives in Los Angeles." + >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé" + >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + feature_vector = outputs.entity_last_hidden_state[:, 0, :] + feature_vector = self.dropout(feature_vector) + logits = self.classifier(feature_vector) + + loss = None + if labels is not None: + # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary + # cross entropy is used otherwise. 
+ if labels.ndim == 1: + loss = F.cross_entropy(logits, labels) + else: + loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) + + if not return_dict: + output = ( + logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ) + return ((loss,) + output) if loss is not None else output + + return EntityClassificationOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LUKE model with a classification head on top (a linear layer on top of the hidden states of the two entity + tokens) for entity pair classification tasks, such as TACRED. + """, + LUKE_START_DOCSTRING, +) +class LukeForEntityPairClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.num_labels = config.num_labels + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels, False) + + self.init_weights() + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=EntityPairClassificationOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + entity_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, + entity_position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)` or :obj:`(batch_size, num_labels)`, `optional`): + Labels for computing the classification loss. If the shape is :obj:`(batch_size,)`, the cross entropy loss + is used for the single-label classification. In this case, labels should contain the indices that should be + in :obj:`[0, ..., config.num_labels - 1]`. If the shape is :obj:`(batch_size, num_labels)`, the binary + cross entropy loss is used for the multi-label classification. In this case, labels should only contain + ``[0, 1]``, where 0 and 1 indicate false and true, respectively. + + Returns: + + Examples:: + + >>> from transformers import LukeTokenizer, LukeForEntityPairClassification + + >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred") + >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred") + + >>> text = "Beyoncé lives in Los Angeles." 
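+            # the first entity span is used as the head entity and the second as the tail entity of the relation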
+            >>> entity_spans = [(0, 7), (17, 28)]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
+            >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> logits = outputs.logits
+            >>> predicted_class_idx = logits.argmax(-1).item()
+            >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.luke(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        feature_vector = torch.cat(
+            [outputs.entity_last_hidden_state[:, 0, :], outputs.entity_last_hidden_state[:, 1, :]], dim=1
+        )
+        feature_vector = self.dropout(feature_vector)
+        logits = self.classifier(feature_vector)
+
+        loss = None
+        if labels is not None:
+            # When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. The binary
+            # cross entropy is used otherwise.
+            if labels.ndim == 1:
+                loss = F.cross_entropy(logits, labels)
+            else:
+                loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
+
+        if not return_dict:
+            output = (
+                logits,
+                outputs.hidden_states,
+                outputs.entity_hidden_states,
+                outputs.attentions,
+            )
+            return ((loss,) + output) if loss is not None else output
+
+        return EntityPairClassificationOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            entity_hidden_states=outputs.entity_hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The LUKE model with a span classification head on top (a linear layer on top of the hidden states output) for tasks
+    such as named entity recognition.
+    """,
+    LUKE_START_DOCSTRING,
+)
+class LukeForEntitySpanClassification(LukePreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.luke = LukeModel(config)
+
+        self.num_labels = config.num_labels
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=EntitySpanClassificationOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        entity_ids=None,
+        entity_attention_mask=None,
+        entity_token_type_ids=None,
+        entity_position_ids=None,
+        entity_start_positions=None,
+        entity_end_positions=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        entity_start_positions (:obj:`torch.LongTensor`):
+            The start positions of entities in the word token sequence.
+
+        entity_end_positions (:obj:`torch.LongTensor`):
+            The end positions of entities in the word token sequence.
+
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)` or :obj:`(batch_size, entity_length, num_labels)`, `optional`):
+            Labels for computing the classification loss. 
If the shape is :obj:`(batch_size, entity_length)`, the cross
+            entropy loss is used for the single-label classification. In this case, labels should contain the indices
+            that should be in :obj:`[0, ..., config.num_labels - 1]`. If the shape is :obj:`(batch_size, entity_length,
+            num_labels)`, the binary cross entropy loss is used for the multi-label classification. In this case,
+            labels should only contain ``[0, 1]``, where 0 and 1 indicate false and true, respectively.
+
+        Returns:
+
+        Examples::
+
+            >>> from transformers import LukeTokenizer, LukeForEntitySpanClassification
+
+            >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
+            >>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
+
+            >>> text = "Beyoncé lives in Los Angeles"
+
+            # List all possible entity spans in the text
+            >>> word_start_positions = [0, 8, 14, 17, 21]  # character-based start positions of word tokens
+            >>> word_end_positions = [7, 13, 16, 20, 28]  # character-based end positions of word tokens
+            >>> entity_spans = []
+            >>> for i, start_pos in enumerate(word_start_positions):
+            ...     for end_pos in word_end_positions[i:]:
+            ...         entity_spans.append((start_pos, end_pos))
+
+            >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> logits = outputs.logits
+            >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()  # one prediction per candidate span
+            >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
+            ...     if predicted_class_idx != 0:  # index 0 is used as the non-entity class
+            ...         print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.luke(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+        hidden_size = outputs.last_hidden_state.size(-1)
+
+        entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
+        start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
+        entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
+        end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
+        feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
+
+        feature_vector = self.dropout(feature_vector)
+        logits = self.classifier(feature_vector)
+
+        loss = None
+        if labels is not None:
+            # When the number of dimensions of `labels` is 2, cross entropy is used as the loss function. The binary
+            # cross entropy is used otherwise.
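+            # `logits` has shape (batch_size, entity_length, num_labels); both branches below flatten the batch
+            # and entity dimensions before computing the loss.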
+ if labels.ndim == 2: + loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + else: + loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) + + if not return_dict: + output = ( + logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ) + return ((loss,) + output) if loss is not None else output + + return EntitySpanClassificationOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py new file mode 100644 index 00000000000000..3fe2665dc54458 --- /dev/null +++ b/src/transformers/models/luke/tokenization_luke.py @@ -0,0 +1,1531 @@ +# coding=utf-8 +# Copyright Studio-Ouisa and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for LUKE.""" + +import itertools +import json +import os +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ... import RobertaTokenizer +from ...file_utils import add_end_docstrings, is_tf_available, is_torch_available +from ...tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + PaddingStrategy, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, + _is_tensorflow, + _is_torch, + to_py_obj, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) + +EntitySpan = Tuple[int, int] +EntitySpanInput = List[EntitySpan] +Entity = str +EntityInput = List[Entity] + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "entity_vocab_file": "entity_vocab.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "studio-ousia/luke-base": "https://huggingface.co/studio-ousia/luke-base/resolve/main/vocab.json", + "studio-ousia/luke-large": "https://huggingface.co/studio-ousia/luke-large/resolve/main/vocab.json", + }, + "merges_file": { + "studio-ousia/luke-base": "https://huggingface.co/studio-ousia/luke-base/resolve/main/merges.txt", + "studio-ousia/luke-large": "https://huggingface.co/studio-ousia/luke-large/resolve/main/merges.txt", + }, + "entity_vocab_file": { + "studio-ousia/luke-base": "https://huggingface.co/studio-ousia/luke-base/resolve/main/entity_vocab.json", + "studio-ousia/luke-large": "https://huggingface.co/studio-ousia/luke-large/resolve/main/entity_vocab.json", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "studio-ousia/luke-base": 512, + "studio-ousia/luke-large": 512, +} + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (:obj:`bool`, `optional`): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are token type IDs? 
<../glossary.html#token-type-ids>`__ + return_attention_mask (:obj:`bool`, `optional`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are attention masks? <../glossary.html#attention-mask>`__ + return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return overflowing token sequences. + return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return :obj:`(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from + :class:`~transformers.PreTrainedTokenizerFast`, if using Python's tokenizer, this method will raise + :obj:`NotImplementedError`. + return_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return the lengths of the encoded inputs. + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to print more information and warnings. + **kwargs: passed to the :obj:`self.tokenize()` method + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + + - **token_type_ids** -- List of token type ids to be fed to a model (when :obj:`return_token_type_ids=True` + or if `"token_type_ids"` is in :obj:`self.model_input_names`). + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + - **entity_ids** -- List of entity ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + + - **entity_position_ids** -- List of entity positions in the input sequence to be fed to a model. + + - **entity_token_type_ids** -- List of entity token type ids to be fed to a model (when + :obj:`return_token_type_ids=True` or if `"entity_token_type_ids"` is in :obj:`self.model_input_names`). + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + + - **entity_attention_mask** -- List of indices specifying which entities should be attended to by the model + (when :obj:`return_attention_mask=True` or if `"entity_attention_mask"` is in + :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + - **entity_start_positions** -- List of the start positions of entities in the word token sequence (when + :obj:`task="entity_span_classification"`). + - **entity_end_positions** -- List of the end positions of entities in the word token sequence (when + :obj:`task="entity_span_classification"`). + - **overflowing_tokens** -- List of overflowing tokens sequences (when a :obj:`max_length` is specified and + :obj:`return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a :obj:`max_length` is specified and + :obj:`return_overflowing_tokens=True`). 
+        - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+          regular sequence tokens (when :obj:`add_special_tokens=True` and :obj:`return_special_tokens_mask=True`).
+        - **length** -- The length of the inputs (when :obj:`return_length=True`)
+
+"""
+
+
+class LukeTokenizer(RobertaTokenizer):
+    r"""
+    Construct a LUKE tokenizer.
+
+    This tokenizer inherits from :class:`~transformers.RobertaTokenizer` which contains most of the main methods. Users
+    should refer to this superclass for more information regarding those methods. Compared to
+    :class:`~transformers.RobertaTokenizer`, :class:`~transformers.LukeTokenizer` also creates entity sequences, namely
+    :obj:`entity_ids`, :obj:`entity_attention_mask`, :obj:`entity_token_type_ids`, and :obj:`entity_position_ids` to be
+    used by the LUKE model.
+
+    Args:
+        vocab_file (:obj:`str`):
+            Path to the vocabulary file.
+        merges_file (:obj:`str`):
+            Path to the merges file.
+        entity_vocab_file (:obj:`str`):
+            Path to the entity vocabulary file.
+        task (:obj:`str`, `optional`):
+            Task for which you want to prepare sequences. One of :obj:`"entity_classification"`,
+            :obj:`"entity_pair_classification"`, or :obj:`"entity_span_classification"`. If you specify this argument,
+            the entity sequence is automatically created based on the given entity span(s).
+        max_entity_length (:obj:`int`, `optional`, defaults to 32):
+            The maximum length of :obj:`entity_ids`.
+        max_mention_length (:obj:`int`, `optional`, defaults to 30):
+            The maximum number of tokens inside an entity span.
+        entity_token_1 (:obj:`str`, `optional`, defaults to :obj:`<ent>`):
+            The special token used to represent an entity span in a word token sequence. This token is only used when
+            ``task`` is set to :obj:`"entity_classification"` or :obj:`"entity_pair_classification"`.
+        entity_token_2 (:obj:`str`, `optional`, defaults to :obj:`<ent2>`):
+            The special token used to represent an entity span in a word token sequence. This token is only used when
+            ``task`` is set to :obj:`"entity_pair_classification"`.
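+
+    Example (usage mirroring the model docstrings in ``modeling_luke.py``)::
+
+        >>> from transformers import LukeTokenizer
+
+        >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+        >>> text = "Beyoncé lives in Los Angeles."
+        >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"
+        >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
+        # in addition to the usual word-level inputs, the encoding now also contains entity_ids,
+        # entity_attention_mask, entity_token_type_ids and entity_position_ids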
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + merges_file, + entity_vocab_file, + task=None, + max_entity_length=32, + max_mention_length=30, + entity_token_1="", + entity_token_2="", + **kwargs + ): + # we add 2 special tokens for downstream tasks + # for more information about lstrip and rstrip, see https://github.com/huggingface/transformers/pull/2778 + entity_token_1 = ( + AddedToken(entity_token_1, lstrip=False, rstrip=False) + if isinstance(entity_token_1, str) + else entity_token_1 + ) + entity_token_2 = ( + AddedToken(entity_token_2, lstrip=False, rstrip=False) + if isinstance(entity_token_2, str) + else entity_token_2 + ) + kwargs["additional_special_tokens"] = [entity_token_1, entity_token_2] + kwargs["additional_special_tokens"] += kwargs.get("additional_special_tokens", []) + + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + task=task, + max_entity_length=32, + max_mention_length=30, + entity_token_1="", + entity_token_2="", + **kwargs, + ) + + with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle: + self.entity_vocab = json.load(entity_vocab_handle) + + self.task = task + if task is None or task == "entity_span_classification": + self.max_entity_length = max_entity_length + elif task == "entity_classification": + self.max_entity_length = 1 + elif task == "entity_pair_classification": + self.max_entity_length = 2 + else: + raise ValueError( + f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification', 'entity_span_classification'] only." + ) + + self.max_mention_length = max_mention_length + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, List[TextInput]], + text_pair: Optional[Union[TextInput, List[TextInput]]] = None, + entity_spans: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None, + entity_spans_pair: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None, + entities: Optional[Union[EntityInput, List[EntityInput]]] = None, + entities_pair: Optional[Union[EntityInput, List[EntityInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences, depending on the task you want to prepare them for. + + Args: + text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. 
+ text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + entity_spans (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify + :obj:`"entity_classification"` or :obj:`"entity_pair_classification"` as the ``task`` argument in the + constructor, the length of each sequence must be 1 or 2, respectively. If you specify ``entities``, the + length of each sequence must be equal to the length of each sequence of ``entities``. + entity_spans_pair (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify the + ``task`` argument in the constructor, this argument is ignored. If you specify ``entities_pair``, the + length of each sequence must be equal to the length of each sequence of ``entities_pair``. + entities (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the ``task`` argument in the constructor. The length + of each sequence must be equal to the length of each sequence of ``entity_spans``. If you specify + ``entity_spans`` without specifying this argument, the entity sequence or the batch of entity sequences + is automatically constructed by filling it with the [MASK] entity. + entities_pair (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the ``task`` argument in the constructor. The length + of each sequence must be equal to the length of each sequence of ``entity_spans_pair``. If you specify + ``entity_spans_pair`` without specifying this argument, the entity sequence or the batch of entity + sequences is automatically constructed by filling it with the [MASK] entity. + max_entity_length (:obj:`int`, `optional`): + The maximum length of :obj:`entity_ids`. + """ + # Input type checking for clearer error + is_valid_single_text = isinstance(text, str) + is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) + assert ( + is_valid_single_text or is_valid_batch_text + ), "text input must be of type `str` (single example) or `List[str]` (batch)." + + is_valid_single_text_pair = isinstance(text_pair, str) + is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and ( + len(text_pair) == 0 or isinstance(text_pair[0], str) + ) + assert ( + text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair + ), "text_pair input must be of type `str` (single example) or `List[str]` (batch)." 
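+
+        # Below, batched inputs are dispatched to `batch_encode_plus` and single examples to `encode_plus`.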
+ + is_batched = bool(isinstance(text, (list, tuple))) + + if is_batched: + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + if entities is None: + batch_entities_or_entities_pairs = None + else: + batch_entities_or_entities_pairs = ( + list(zip(entities, entities_pair)) if entities_pair is not None else entities + ) + + if entity_spans is None: + batch_entity_spans_or_entity_spans_pairs = None + else: + batch_entity_spans_or_entity_spans_pairs = ( + list(zip(entity_spans, entity_spans_pair)) if entity_spans_pair is not None else entity_spans + ) + + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, + batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + entities=entities, + entities_pair=entities_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. + + .. warning:: This method is deprecated, ``__call__`` should be used instead. + + Args: + text (:obj:`str`): + The first sequence to be encoded. 
Each sequence must be a string.
+            text_pair (:obj:`str`):
+                The second sequence to be encoded. Each sequence must be a string.
+            entity_spans (:obj:`List[Tuple[int, int]]`, `optional`):
+                The first sequence of entity spans to be encoded. The sequence consists of tuples each with two
+                integers denoting character-based start and end positions of entities. If you specify
+                :obj:`"entity_classification"` or :obj:`"entity_pair_classification"` as the ``task`` argument in the
+                constructor, the length of the sequence must be 1 or 2, respectively. If you specify ``entities``, the
+                length of the sequence must be equal to the length of ``entities``.
+            entity_spans_pair (:obj:`List[Tuple[int, int]]`, `optional`):
+                The second sequence of entity spans to be encoded. The sequence consists of tuples each with two
+                integers denoting character-based start and end positions of entities. If you specify the ``task``
+                argument in the constructor, this argument is ignored. If you specify ``entities_pair``, the length of
+                the sequence must be equal to the length of ``entities_pair``.
+            entities (:obj:`List[str]`, `optional`):
+                The first sequence of entities to be encoded. The sequence consists of strings representing entities,
+                i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument
+                is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be
+                equal to the length of ``entity_spans``. If you specify ``entity_spans`` without specifying this
+                argument, the entity sequence is automatically constructed by filling it with the [MASK] entity.
+            entities_pair (:obj:`List[str]`, `optional`):
+                The second sequence of entities to be encoded. The sequence consists of strings representing entities,
+                i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument
+                is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be
+                equal to the length of ``entity_spans_pair``. If you specify ``entity_spans_pair`` without specifying
+                this argument, the entity sequence is automatically constructed by filling it with the [MASK] entity.
+            max_entity_length (:obj:`int`, `optional`):
+                The maximum length of the entity sequence.
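+
+        Example (an illustrative sketch; the checkpoint name and the character offsets are assumptions)::
+
+            >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+            >>> text = "Beyoncé lives in Los Angeles."
+            >>> # one entity, "Los Angeles", at character positions 17 to 28
+            >>> encoding = tokenizer.encode_plus(text, entity_spans=[(17, 28)], entities=["Los Angeles"])
+            >>> # "entity_ids" holds the entity vocabulary id of "Los Angeles"
+            >>> # (or the [UNK] entity id if the title is absent from the entity vocabulary)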
+ """ + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + text_pair=text_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + entities=entities, + entities_pair=entities_pair, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + + # prepare_for_model will create the attention_mask and token_type_ids + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], + batch_entity_spans_or_entity_spans_pairs: Optional[ + Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. + + .. warning:: + This method is deprecated, ``__call__`` should be used instead. + + + Args: + batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of string or a list of pair + of string (see details in ``encode_plus``). + batch_entity_spans_or_entity_spans_pairs (:obj:`List[List[Tuple[int, int]]]`, + :obj:`List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]`, `optional`):: + Batch of entity span sequences or pairs of entity span sequences to be encoded (see details in + ``encode_plus``). + batch_entities_or_entities_pairs (:obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, + `optional`): + Batch of entity sequences or pairs of entity sequences to be encoded (see details in ``encode_plus``). 
+ max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, + batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], + batch_entity_spans_or_entity_spans_pairs: Optional[ + Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]] + ] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + # input_ids is a list of tuples (one for each example in the batch) + input_ids = [] + entity_ids = [] + entity_token_spans = [] + for index, text_or_text_pair in enumerate(batch_text_or_text_pairs): + if not isinstance(text_or_text_pair, (list, tuple)): + text, text_pair = text_or_text_pair, None + else: + text, text_pair = text_or_text_pair + + entities, entities_pair = None, None + if batch_entities_or_entities_pairs is not None: + entities_or_entities_pairs = batch_entities_or_entities_pairs[index] + if entities_or_entities_pairs: + if isinstance(entities_or_entities_pairs[0], str): + entities, entities_pair = entities_or_entities_pairs, None + else: + entities, entities_pair = entities_or_entities_pairs + + entity_spans, entity_spans_pair = None, None + if batch_entity_spans_or_entity_spans_pairs is not None: + entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index] + if entity_spans_or_entity_spans_pairs: + if isinstance(entity_spans_or_entity_spans_pairs[0][0], int): + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None + else: + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + input_ids.append((first_ids, second_ids)) + entity_ids.append((first_entity_ids, second_entity_ids)) + entity_token_spans.append((first_entity_token_spans, second_entity_token_spans)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + batch_entity_ids_pairs=entity_ids, + batch_entity_token_spans_pairs=entity_token_spans, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + def _create_input_sequence( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + **kwargs + ) -> Tuple[list, list, list, list, list, list]: + def get_input_ids(text): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + + def get_input_ids_and_entity_token_spans(text, entity_spans): + if entity_spans is None: + return get_input_ids(text), None + + cur = 0 + input_ids = [] + entity_token_spans = [None] * len(entity_spans) + + split_char_positions = sorted(frozenset(itertools.chain(*entity_spans))) + char_pos2token_pos = {} + + for split_char_position in split_char_positions: + orig_split_char_position = split_char_position + if ( + split_char_position > 0 and text[split_char_position - 1] == " " + ): # whitespace should be 
prepended to the following token + split_char_position -= 1 + if cur != split_char_position: + input_ids += get_input_ids(text[cur:split_char_position]) + cur = split_char_position + char_pos2token_pos[orig_split_char_position] = len(input_ids) + + input_ids += get_input_ids(text[cur:]) + + entity_token_spans = [ + (char_pos2token_pos[char_start], char_pos2token_pos[char_end]) for char_start, char_end in entity_spans + ] + + return input_ids, entity_token_spans + + first_ids, second_ids = None, None + first_entity_ids, second_entity_ids = None, None + first_entity_token_spans, second_entity_token_spans = None, None + + if self.task is None: + unk_entity_id = self.entity_vocab["[UNK]"] + mask_entity_id = self.entity_vocab["[MASK]"] + + if entity_spans is None: + first_ids = get_input_ids(text) + else: + assert isinstance(entity_spans, list) and ( + len(entity_spans) == 0 or isinstance(entity_spans[0], tuple) + ), "entity_spans should be given as a list of tuples containing the start and end character indices" + assert entities is None or ( + isinstance(entities, list) and (len(entities) == 0 or isinstance(entities[0], str)) + ), "If you specify entities, they should be given as a list of entity names" + assert entities is None or len(entities) == len( + entity_spans + ), "If you specify entities, entities and entity_spans must be the same length" + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + if entities is None: + first_entity_ids = [mask_entity_id] * len(entity_spans) + else: + first_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities] + + if text_pair is not None: + if entity_spans_pair is None: + second_ids = get_input_ids(text_pair) + else: + assert isinstance(entity_spans_pair, list) and ( + len(entity_spans_pair) == 0 or isinstance(entity_spans_pair[0], tuple) + ), "entity_spans_pair should be given as a list of tuples containing the start and end character indices" + assert entities_pair is None or ( + isinstance(entities_pair, list) + and (len(entities_pair) == 0 or isinstance(entities_pair[0], str)) + ), "If you specify entities_pair, they should be given as a list of entity names" + assert entities_pair is None or len(entities_pair) == len( + entity_spans_pair + ), "If you specify entities_pair, entities_pair and entity_spans_pair must be the same length" + + second_ids, second_entity_token_spans = get_input_ids_and_entity_token_spans( + text_pair, entity_spans_pair + ) + if entities_pair is None: + second_entity_ids = [mask_entity_id] * len(entity_spans_pair) + else: + second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair] + + elif self.task == "entity_classification": + assert ( + isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple) + ), "Entity spans should be a list containing a single tuple containing the start and end character indices of an entity" + + first_entity_ids = [self.entity_vocab["[MASK]"]] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + # add special tokens to input ids + entity_token_start, entity_token_end = first_entity_token_spans[0] + first_ids = ( + first_ids[:entity_token_end] + [self.additional_special_tokens_ids[0]] + first_ids[entity_token_end:] + ) + first_ids = ( + first_ids[:entity_token_start] + + [self.additional_special_tokens_ids[0]] + + first_ids[entity_token_start:] + ) + first_entity_token_spans = [(entity_token_start, 
entity_token_end + 2)] + + elif self.task == "entity_pair_classification": + assert ( + isinstance(entity_spans, list) + and len(entity_spans) == 2 + and isinstance(entity_spans[0], tuple) + and isinstance(entity_spans[1], tuple) + ), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity" + + head_span, tail_span = entity_spans + first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + head_token_span, tail_token_span = first_entity_token_spans + token_span_with_special_token_ids = [ + (head_token_span, self.additional_special_tokens_ids[0]), + (tail_token_span, self.additional_special_tokens_ids[1]), + ] + if head_token_span[0] < tail_token_span[0]: + first_entity_token_spans[0] = (head_token_span[0], head_token_span[1] + 2) + first_entity_token_spans[1] = (tail_token_span[0] + 2, tail_token_span[1] + 4) + token_span_with_special_token_ids = reversed(token_span_with_special_token_ids) + else: + first_entity_token_spans[0] = (head_token_span[0] + 2, head_token_span[1] + 4) + first_entity_token_spans[1] = (tail_token_span[0], tail_token_span[1] + 2) + + for (entity_token_start, entity_token_end), special_token_id in token_span_with_special_token_ids: + first_ids = first_ids[:entity_token_end] + [special_token_id] + first_ids[entity_token_end:] + first_ids = first_ids[:entity_token_start] + [special_token_id] + first_ids[entity_token_start:] + + elif self.task == "entity_span_classification": + mask_entity_id = self.entity_vocab["[MASK]"] + + assert isinstance(entity_spans, list) and isinstance( + entity_spans[0], tuple + ), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity" + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + first_entity_ids = [mask_entity_id] * len(entity_spans) + + else: + raise ValueError(f"Task {self.task} not supported") + + return ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_ids_pairs: List[Tuple[List[int], None]], + batch_entity_ids_pairs: List[Tuple[Optional[List[int]], Optional[List[int]]]], + batch_entity_token_spans_pairs: List[Tuple[Optional[List[Tuple[int, int]]], Optional[List[Tuple[int, int]]]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 
It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + batch_entity_ids_pairs: list of entity ids or entity ids pairs + batch_entity_token_spans_pairs: list of entity spans or entity spans pairs + max_entity_length: The maximum length of the entity sequence. + """ + + batch_outputs = {} + for input_ids, entity_ids, entity_token_span_pairs in zip( + batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs + ): + first_ids, second_ids = input_ids + first_entity_ids, second_entity_ids = entity_ids + first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs + outputs = self.prepare_for_model( + first_ids, + second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + entity_ids: Optional[List[int]] = None, + pair_entity_ids: Optional[List[int]] = None, + entity_token_spans: Optional[List[Tuple[int, int]]] = None, + pair_entity_token_spans: Optional[List[Tuple[int, int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + """ + Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids, + entity spans so that it can be used by the model. 
It adds special tokens, truncates sequences if overflowing + while taking into account the special tokens and manages a moving window (with user defined stride) for + overflowing tokens + + + Args: + ids (:obj:`List[int]`): + Tokenized input ids of the first sequence. + pair_ids (:obj:`List[int]`, `optional`): + Tokenized input ids of the second sequence. + entity_ids (:obj:`List[int]`, `optional`): + Entity ids of the first sequence. + pair_entity_ids (:obj:`List[int]`, `optional`): + Entity ids of the second sequence. + entity_token_spans (:obj:`List[Tuple[int, int]]`, `optional`): + Entity spans of the first sequence. + pair_entity_token_spans (:obj:`List[Tuple[int, int]]`, `optional`): + Entity spans of the second sequence. + max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + # Compute lengths + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned word encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length and max_entity_length + overflowing_tokens = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + # truncate words up to max_length + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + entity_token_offset = 1 # 1 * token + pair_entity_token_offset = len(ids) + 3 # 1 * token & 2 * tokens + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + entity_token_offset = 0 + pair_entity_token_offset = len(ids) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Set max entity length + if not max_entity_length: + max_entity_length = self.max_entity_length + + 
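+        # Entities whose token spans end beyond the (possibly truncated) word sequence
+        # can no longer be pointed at, so they are filtered out below and a warning is
+        # emitted; e.g. with len(ids) == 10, a span (8, 12) is dropped while (3, 7) is
+        # kept (values illustrative).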
if entity_ids is not None: + total_entity_len = 0 + num_invalid_entities = 0 + valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)] + valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)] + + total_entity_len += len(valid_entity_ids) + num_invalid_entities += len(entity_ids) - len(valid_entity_ids) + + valid_pair_entity_ids, valid_pair_entity_token_spans = None, None + if pair_entity_ids is not None: + valid_pair_entity_ids = [ + ent_id + for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans) + if span[1] <= len(pair_ids) + ] + valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)] + total_entity_len += len(valid_pair_entity_ids) + num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids) + + if num_invalid_entities != 0: + logger.warning( + f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens" + ) + + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length: + # truncate entities up to max_entity_length + valid_entity_ids, valid_pair_entity_ids, overflowing_entities = self.truncate_sequences( + valid_entity_ids, + pair_ids=valid_pair_entity_ids, + num_tokens_to_remove=total_entity_len - max_entity_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + valid_entity_token_spans = valid_entity_token_spans[: len(valid_entity_ids)] + if valid_pair_entity_token_spans is not None: + valid_pair_entity_token_spans = valid_pair_entity_token_spans[: len(valid_pair_entity_ids)] + + if return_overflowing_tokens: + encoded_inputs["overflowing_entities"] = overflowing_entities + encoded_inputs["num_truncated_entities"] = total_entity_len - max_entity_length + + final_entity_ids = valid_entity_ids + valid_pair_entity_ids if valid_pair_entity_ids else valid_entity_ids + encoded_inputs["entity_ids"] = list(final_entity_ids) + entity_position_ids = [] + entity_start_positions = [] + entity_end_positions = [] + for (token_spans, offset) in ( + (valid_entity_token_spans, entity_token_offset), + (valid_pair_entity_token_spans, pair_entity_token_offset), + ): + if token_spans is not None: + for start, end in token_spans: + start += offset + end += offset + position_ids = list(range(start, end))[: self.max_mention_length] + position_ids += [-1] * (self.max_mention_length - end + start) + entity_position_ids.append(position_ids) + entity_start_positions.append(start) + entity_end_positions.append(end - 1) + + encoded_inputs["entity_position_ids"] = entity_position_ids + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = entity_start_positions + encoded_inputs["entity_end_positions"] = entity_end_positions + + if return_token_type_ids: + encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"]) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + # To do: add padding of entities + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + 
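+        # When return_tensors is set, prepend_batch_axis=True (used by _encode_plus for
+        # single examples) adds the leading batch dimension so shapes match the batched path.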
batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    def pad(
+        self,
+        encoded_inputs: Union[
+            BatchEncoding,
+            List[BatchEncoding],
+            Dict[str, EncodedInput],
+            Dict[str, List[EncodedInput]],
+            List[Dict[str, EncodedInput]],
+        ],
+        padding: Union[bool, str, PaddingStrategy] = True,
+        max_length: Optional[int] = None,
+        max_entity_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        verbose: bool = True,
+    ) -> BatchEncoding:
+        """
+        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
+        in the batch. The padding side (left/right) and the padding token ids are defined at the tokenizer level (with
+        ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``).
+
+        .. note::
+
+            If the ``encoded_inputs`` passed are a dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors,
+            the result will use the same type unless you provide a different tensor type with ``return_tensors``. In
+            the case of PyTorch tensors, you will lose the specific device of your tensors however.
+
+        Args:
+            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]]` or :obj:`List[Dict[str, List[int]]]`):
+                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
+                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
+                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
+                well as in a PyTorch Dataloader collate function. Instead of :obj:`List[int]` you can have tensors
+                (numpy arrays, PyTorch tensors or TensorFlow tensors), see the note above for the return type.
+            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+                  single sequence is provided).
+                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided.
+                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+                  different lengths).
+            max_length (:obj:`int`, `optional`):
+                Maximum length of the returned list and optionally padding length (see above).
+            max_entity_length (:obj:`int`, `optional`):
+                The maximum length of the entity sequence.
+            pad_to_multiple_of (:obj:`int`, `optional`):
+                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+                the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+            return_attention_mask (:obj:`bool`, `optional`):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. `What are
+                attention masks?
<../glossary.html#attention-mask>`__
+            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
+                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
+                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
+            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether or not to print more information and warnings.
+        """
+        # If we have a list of dicts, let's convert it into a dict of lists
+        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
+
+        # The model's main input name, usually `input_ids`, has to be passed for padding
+        if self.model_input_names[0] not in encoded_inputs:
+            raise ValueError(
+                "You should supply an encoding or a list of encodings to this method "
+                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
+            )
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if not required_input:
+            if return_attention_mask:
+                encoded_inputs["attention_mask"] = []
+            return encoded_inputs
+
+        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
+        # and rebuild them afterwards if no return_tensors is specified
+        # Note that we lose the specific device the tensor may be on for PyTorch
+
+        first_element = required_input[0]
+        if isinstance(first_element, (list, tuple)):
+            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
+            index = 0
+            while len(required_input[index]) == 0:
+                index += 1
+            if index < len(required_input):
+                first_element = required_input[index][0]
+        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
+        if not isinstance(first_element, (int, list, tuple)):
+            if is_tf_available() and _is_tensorflow(first_element):
+                return_tensors = "tf" if return_tensors is None else return_tensors
+            elif is_torch_available() and _is_torch(first_element):
+                return_tensors = "pt" if return_tensors is None else return_tensors
+            elif isinstance(first_element, np.ndarray):
+                return_tensors = "np" if return_tensors is None else return_tensors
+            else:
+                raise ValueError(
+                    f"type of {first_element} unknown: {type(first_element)}. "
+                    f"Should be one of a python, numpy, pytorch or tensorflow object."
+ ) + + for key, value in encoded_inputs.items(): + encoded_inputs[key] = to_py_obj(value) + + # Convert padding_strategy in PaddingStrategy + padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, verbose=verbose + ) + + if max_entity_length is None: + max_entity_length = self.max_entity_length + + required_input = encoded_inputs[self.model_input_names[0]] + if required_input and not isinstance(required_input[0], (list, tuple)): + encoded_inputs = self._pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + return BatchEncoding(encoded_inputs, tensor_type=return_tensors) + + batch_size = len(required_input) + assert all( + len(v) == batch_size for v in encoded_inputs.values() + ), "Some items in the output dictionary have a different batch size than others." + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + max_entity_length = ( + max(len(inputs) for inputs in encoded_inputs["entity_ids"]) if "entity_ids" in encoded_inputs else 0 + ) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) + outputs = self._pad( + inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + + Args: + encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + max_entity_length: The maximum length of the entity sequence. + padding_strategy: PaddingStrategy to use for padding. + + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). 
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + entities_provided = bool("entity_ids" in encoded_inputs) + + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) + if entities_provided: + max_entity_length = len(encoded_inputs["entity_ids"]) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + if ( + entities_provided + and max_entity_length is not None + and pad_to_multiple_of is not None + and (max_entity_length % pad_to_multiple_of != 0) + ): + max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ( + len(encoded_inputs["input_ids"]) != max_length + or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length) + ) + + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + if entities_provided: + entity_difference = max_entity_length - len(encoded_inputs["entity_ids"]) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference + if entities_provided: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + [ + 0 + ] * entity_difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference + if entities_provided: + encoded_inputs["entity_token_type_ids"] = ( + encoded_inputs["entity_token_type_ids"] + [0] * entity_difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + if entities_provided: + encoded_inputs["entity_ids"] = encoded_inputs["entity_ids"] + [0] * entity_difference + encoded_inputs["entity_position_ids"] = ( + encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference + ) + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = ( + encoded_inputs["entity_start_positions"] + [0] * entity_difference + ) + encoded_inputs["entity_end_positions"] = ( + encoded_inputs["entity_end_positions"] + [0] * entity_difference + ) + + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"]) + if entities_provided: + encoded_inputs["entity_attention_mask"] = [0] * entity_difference + [1] * len( + encoded_inputs["entity_ids"] + ) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"] + if entities_provided: + encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[ + "entity_token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] + if entities_provided: + encoded_inputs["entity_ids"] = 
[0] * entity_difference + encoded_inputs["entity_ids"] + encoded_inputs["entity_position_ids"] = [ + [-1] * self.max_mention_length + ] * entity_difference + encoded_inputs["entity_position_ids"] + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = [0] * entity_difference + encoded_inputs[ + "entity_start_positions" + ] + encoded_inputs["entity_end_positions"] = [0] * entity_difference + encoded_inputs[ + "entity_end_positions" + ] + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + if entities_provided: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + + return encoded_inputs + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + vocab_file, merge_file = super().save_vocabulary(save_directory, filename_prefix) + + entity_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"] + ) + + with open(entity_vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.entity_vocab, ensure_ascii=False)) + + return vocab_file, merge_file, entity_vocab_file diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2f160a881bb161..56fc19f7116ad4 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -801,7 +801,9 @@ def __init__(self, verbose=True, **kwargs): if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == "additional_special_tokens": assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" - assert all(isinstance(t, str) for t in value), "One of the tokens is not a string" + assert all( + isinstance(t, (str, AddedToken)) for t in value + ), "One of the tokens is not a string or an AddedToken" setattr(self, key, value) elif isinstance(value, (str, AddedToken)): setattr(self, key, value) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2a24b845748a67..47c80380a83254 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1739,6 +1739,42 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class LukeForEntityClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukeForEntityPairClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukeForEntitySpanClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukeModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukePreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LxmertEncoder: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_luke.py b/tests/test_modeling_luke.py new file mode 100644 index 00000000000000..ab4879a716b605 --- /dev/null +++ b/tests/test_modeling_luke.py @@ -0,0 +1,609 @@ +# coding=utf-8 +# 
Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch LUKE model. """ + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + LukeConfig, + LukeForEntityClassification, + LukeForEntityPairClassification, + LukeForEntitySpanClassification, + LukeModel, + LukeTokenizer, + ) + from transformers.models.luke.modeling_luke import LUKE_PRETRAINED_MODEL_ARCHIVE_LIST + + +class LukeModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + entity_length=3, + mention_length=5, + use_attention_mask=True, + use_token_type_ids=True, + use_entity_ids=True, + use_entity_attention_mask=True, + use_entity_token_type_ids=True, + use_entity_position_ids=True, + use_labels=True, + vocab_size=99, + entity_vocab_size=10, + entity_emb_size=6, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_entity_classification_labels=9, + num_entity_pair_classification_labels=6, + num_entity_span_classification_labels=4, + use_entity_aware_attention=True, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.entity_length = entity_length + self.mention_length = mention_length + self.use_attention_mask = use_attention_mask + self.use_token_type_ids = use_token_type_ids + self.use_entity_ids = use_entity_ids + self.use_entity_attention_mask = use_entity_attention_mask + self.use_entity_token_type_ids = use_entity_token_type_ids + self.use_entity_position_ids = use_entity_position_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.entity_vocab_size = entity_vocab_size + self.entity_emb_size = entity_emb_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_entity_classification_labels = num_entity_classification_labels + self.num_entity_pair_classification_labels = num_entity_pair_classification_labels + self.num_entity_span_classification_labels = 
num_entity_span_classification_labels + self.scope = scope + self.use_entity_aware_attention = use_entity_aware_attention + + self.encoder_seq_length = seq_length + self.key_length = seq_length + self.num_hidden_states_types = 2 # hidden_states and entity_hidden_states + + def prepare_config_and_inputs(self): + # prepare words + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + # prepare entities + entity_ids = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size) + + entity_attention_mask = None + if self.use_entity_attention_mask: + entity_attention_mask = random_attention_mask([self.batch_size, self.entity_length]) + + entity_token_type_ids = None + if self.use_token_type_ids: + entity_token_type_ids = ids_tensor([self.batch_size, self.entity_length], self.type_vocab_size) + + entity_position_ids = None + if self.use_entity_position_ids: + entity_position_ids = ids_tensor( + [self.batch_size, self.entity_length, self.mention_length], self.mention_length + ) + + sequence_labels = None + entity_classification_labels = None + entity_pair_classification_labels = None + entity_span_classification_labels = None + + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels) + entity_pair_classification_labels = ids_tensor( + [self.batch_size], self.num_entity_pair_classification_labels + ) + entity_span_classification_labels = ids_tensor( + [self.batch_size, self.entity_length], self.num_entity_span_classification_labels + ) + + config = LukeConfig( + vocab_size=self.vocab_size, + entity_vocab_size=self.entity_vocab_size, + entity_emb_size=self.entity_emb_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + use_entity_aware_attention=self.use_entity_aware_attention, + ) + + return ( + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ): + model = LukeModel(config=config) + model.to(torch_device) + model.eval() + # test with words + entities + result = model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + 
+            entity_position_ids=entity_position_ids,
+        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(
+            result.entity_last_hidden_state.shape, (self.batch_size, self.entity_length, self.hidden_size)
+        )
+
+        # test with words only
+        result = model(input_ids, token_type_ids=token_type_ids)
+        result = model(input_ids)
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_for_entity_classification(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        config.num_labels = self.num_entity_classification_labels
+        model = LukeForEntityClassification(config)
+        model.to(torch_device)
+        model.eval()
+
+        result = model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            labels=entity_classification_labels,
+        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_entity_classification_labels))
+
+    def create_and_check_for_entity_pair_classification(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        config.num_labels = self.num_entity_pair_classification_labels
+        model = LukeForEntityPairClassification(config)
+        model.to(torch_device)
+        model.eval()
+
+        result = model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            labels=entity_pair_classification_labels,
+        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_entity_pair_classification_labels))
+
+    def create_and_check_for_entity_span_classification(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        config.num_labels = self.num_entity_span_classification_labels
+        model = LukeForEntitySpanClassification(config)
+        model.to(torch_device)
+        model.eval()
+
+        entity_start_positions = ids_tensor([self.batch_size, self.entity_length], self.seq_length)
+        entity_end_positions = ids_tensor([self.batch_size, self.entity_length], self.seq_length)
+
+        result = model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            entity_start_positions=entity_start_positions,
+            entity_end_positions=entity_end_positions,
+            labels=entity_span_classification_labels,
+        )
+        self.parent.assertEqual(
+            result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
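+        # repackage the tester's tuple into the (config, inputs_dict) format the common ModelTesterMixin tests expect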
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            entity_ids,
+            entity_attention_mask,
+            entity_token_type_ids,
+            entity_position_ids,
+            sequence_labels,
+            entity_classification_labels,
+            entity_pair_classification_labels,
+            entity_span_classification_labels,
+        ) = config_and_inputs
+        inputs_dict = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": attention_mask,
+            "entity_ids": entity_ids,
+            "entity_token_type_ids": entity_token_type_ids,
+            "entity_attention_mask": entity_attention_mask,
+            "entity_position_ids": entity_position_ids,
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class LukeModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            LukeModel,
+            LukeForEntityClassification,
+            LukeForEntityPairClassification,
+            LukeForEntitySpanClassification,
+        )
+        if is_torch_available()
+        else ()
+    )
+    test_pruning = False
+    test_torchscript = False
+    test_resize_embeddings = True
+    test_head_masking = True
+
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+        if model_class == LukeForEntitySpanClassification:
+            inputs_dict["entity_start_positions"] = torch.zeros(
+                (self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
+            )
+            inputs_dict["entity_end_positions"] = torch.ones(
+                (self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
+            )
+
+        if return_labels:
+            if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
+                inputs_dict["labels"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+            elif model_class == LukeForEntitySpanClassification:
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.entity_length),
+                    dtype=torch.long,
+                    device=torch_device,
+                )
+        return inputs_dict
+
+    def setUp(self):
+        self.model_tester = LukeModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=LukeConfig, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_name in LUKE_PRETRAINED_MODEL_ARCHIVE_LIST:
+            model = LukeModel.from_pretrained(model_name)
+            self.assertIsNotNone(model)
+
+    def test_for_entity_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
+
+    def test_for_entity_pair_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_entity_pair_classification(*config_and_inputs)
+
+    def test_for_entity_span_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_entity_span_classification(*config_and_inputs)
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        seq_length = self.model_tester.seq_length
+        entity_length = self.model_tester.entity_length
+        key_length = seq_length + entity_length
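+        # LUKE attends over the concatenation of word and entity tokens, so attention keys span seq_length + entity_length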
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_length + entity_length, key_length],
+            )
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            added_hidden_states = self.model_tester.num_hidden_states_types
+            self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+            self_attentions = outputs.attentions
+
+            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_length + entity_length, key_length],
+            )
+
+    def test_entity_hidden_states_output(self):
+        def check_hidden_states_output(inputs_dict, config, model_class):
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            entity_hidden_states = outputs.entity_hidden_states
+
+            expected_num_layers = getattr(
+                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+            )
+            self.assertEqual(len(entity_hidden_states), expected_num_layers)
+
+            entity_length = self.model_tester.entity_length
+
+            self.assertListEqual(
+                list(entity_hidden_states[0].shape[-2:]),
+                [entity_length, self.model_tester.hidden_size],
+            )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_hidden_states"] = True
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+            # check that output_hidden_states also work using config
+            del inputs_dict["output_hidden_states"]
+            config.output_hidden_states = True
+
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+    def test_retain_grad_entity_hidden_states(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.output_hidden_states = True
+        config.output_attentions = True
+
+        # no need to test all models as different heads yield the same functionality
+        model_class = self.all_model_classes[0]
+        model = model_class(config)
+        model.to(torch_device)
+
+        inputs = self._prepare_for_class(inputs_dict, model_class)
+
+        outputs = model(**inputs)
+
+        output = outputs[0]
+
+        entity_hidden_states = outputs.entity_hidden_states[0]
+        entity_hidden_states.retain_grad()
+
+        output.flatten()[0].backward(retain_graph=True)
+
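+        # the backward pass should have populated gradients on the retained entity hidden states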
+        self.assertIsNotNone(entity_hidden_states.grad)
+
+
+@require_torch
+class LukeModelIntegrationTests(unittest.TestCase):
+    @slow
+    def test_inference_base_model(self):
+        model = LukeModel.from_pretrained("studio-ousia/luke-base").eval()
+        model.to(torch_device)
+
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
+        text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+        span = (39, 42)
+        encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
+
+        # move all values to device
+        for key, value in encoding.items():
+            encoding[key] = encoding[key].to(torch_device)
+
+        outputs = model(**encoding)
+
+        # Verify word hidden states
+        expected_shape = torch.Size((1, 42, 768))
+        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+        expected_slice = torch.tensor(
+            [[0.0037, 0.1368, -0.0091], [0.1099, 0.3329, -0.1095], [0.0765, 0.5335, 0.1179]]
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+        # Verify entity hidden states
+        expected_shape = torch.Size((1, 1, 768))
+        self.assertEqual(outputs.entity_last_hidden_state.shape, expected_shape)
+
+        expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]]).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_large_model(self):
+        model = LukeModel.from_pretrained("studio-ousia/luke-large").eval()
+        model.to(torch_device)
+
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large", task="entity_classification")
+        text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+        span = (39, 42)
+        encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
+
+        # move all values to device
+        for key, value in encoding.items():
+            encoding[key] = encoding[key].to(torch_device)
+
+        outputs = model(**encoding)
+
+        # Verify word hidden states
+        expected_shape = torch.Size((1, 42, 1024))
+        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+        expected_slice = torch.tensor(
+            [[0.0133, 0.0865, 0.0095], [0.3093, -0.2576, -0.7418], [-0.1720, -0.2117, -0.2869]]
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+        # Verify entity hidden states
+        expected_shape = torch.Size((1, 1, 1024))
+        self.assertEqual(outputs.entity_last_hidden_state.shape, expected_shape)
+
+        expected_slice = torch.tensor([[0.0466, -0.0106, -0.0179]]).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
diff --git a/tests/test_tokenization_luke.py b/tests/test_tokenization_luke.py
new file mode 100644
index 00000000000000..ee5af69eef1261
--- /dev/null
+++ b/tests/test_tokenization_luke.py
@@ -0,0 +1,575 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers import AddedToken, LukeTokenizer
+from transformers.testing_utils import require_torch, slow
+
+from .test_tokenization_common import TokenizerTesterMixin
+
+
+class Luke(TokenizerTesterMixin, unittest.TestCase):
+    tokenizer_class = LukeTokenizer
+    from_pretrained_kwargs = {"cls_token": "<s>"}
+
+    def setUp(self):
+        super().setUp()
+
+        self.special_tokens_map = {"entity_token_1": "<ent>", "entity_token_2": "<ent2>"}
+
+    def get_tokenizer(self, task=None, **kwargs):
+        kwargs.update(self.special_tokens_map)
+        return self.tokenizer_class.from_pretrained("studio-ousia/luke-base", task=task, **kwargs)
+
+    def get_input_output_texts(self, tokenizer):
+        input_text = "lower newer"
+        output_text = "lower newer"
+        return input_text, output_text
+
+    def test_full_tokenizer(self):
+        tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-base")
+        text = "lower newer"
+        bpe_tokens = ["lower", "\u0120newer"]
+        tokens = tokenizer.tokenize(text)  # , add_prefix_space=True)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + [tokenizer.unk_token]
+        input_bpe_tokens = [29668, 13964, 3]
+        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+    def luke_dict_integration_testing(self):
+        tokenizer = self.get_tokenizer()
+
+        self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [0, 31414, 232, 328, 2])
+        self.assertListEqual(
+            tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
+            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
+        )
+
+    @slow
+    def test_sequence_builders(self):
+        tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-large")
+
+        text = tokenizer.encode("sequence builders", add_special_tokens=False)
+        text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
+
+        encoded_text_from_decode = tokenizer.encode(
+            "sequence builders", add_special_tokens=True, add_prefix_space=False
+        )
+        encoded_pair_from_decode = tokenizer.encode(
+            "sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False
+        )
+
+        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
+        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
+        assert encoded_sentence == encoded_text_from_decode
+        assert encoded_pair == encoded_pair_from_decode
+
+    def test_space_encoding(self):
+        tokenizer = self.get_tokenizer()
+
+        sequence = "Encode this sequence."
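+        # the byte-level BPE represents a leading space with the "\u0120" (Ġ) marker; look it up in the byte encoder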
+        space_encoding = tokenizer.byte_encoder[" ".encode("utf-8")[0]]
+
+        # Testing encoder arguments
+        encoded = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=False)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
+        self.assertNotEqual(first_char, space_encoding)
+
+        encoded = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
+        self.assertEqual(first_char, space_encoding)
+
+        tokenizer.add_special_tokens({"bos_token": "<s>"})
+        encoded = tokenizer.encode(sequence, add_special_tokens=True)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
+        self.assertNotEqual(first_char, space_encoding)
+
+        # Testing spaces after special tokens
+        mask = "<mask>"
+        tokenizer.add_special_tokens(
+            {"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
+        )  # mask token has a left space
+        mask_ind = tokenizer.convert_tokens_to_ids(mask)
+
+        sequence = "Encode <mask> sequence"
+        sequence_nospace = "Encode <mask>sequence"
+
+        encoded = tokenizer.encode(sequence)
+        mask_loc = encoded.index(mask_ind)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
+        self.assertEqual(first_char, space_encoding)
+
+        encoded = tokenizer.encode(sequence_nospace)
+        mask_loc = encoded.index(mask_ind)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
+        self.assertNotEqual(first_char, space_encoding)
+
+    def test_pretokenized_inputs(self):
+        pass
+
+    def test_embeded_special_tokens(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                sentence = "A, <mask> AllenNLP sentence."
+                tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
+                tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
+
+                # token_type_ids should put 0 everywhere
+                self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
+
+                # attention_mask should put 1 everywhere, so sum over length should be 1
+                self.assertEqual(
+                    sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
+                    sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
+                )
+
+                tokens_p_str = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
+
+                # Rust correctly handles the space before the mask while python doesn't
+                self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
+
+                self.assertSequenceEqual(
+                    tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
+                )
+
+
+@require_torch
+class LukeTokenizerIntegrationTests(unittest.TestCase):
+    tokenizer_class = LukeTokenizer
+    from_pretrained_kwargs = {"cls_token": "<s>"}
+
+    def setUp(self):
+        super().setUp()
+
+    def test_single_text_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True)
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
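+        # character-level spans: (9, 21) -> "Ana Ivanovic", (30, 38) -> "Thursday", (39, 42) -> "she"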
+        entities = ["Ana Ivanovic", "Thursday", "Dummy Entity"]
+        spans = [(9, 21), (30, 38), (39, 42)]
+
+        encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True)
+
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed Ana Ivanovic said on Thursday she could hardly believe her luck.</s>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][3:6], spaces_between_special_tokens=False), " Ana Ivanovic"
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][8:9], spaces_between_special_tokens=False), " Thursday"
+        )
+        self.assertEqual(tokenizer.decode(encoding["input_ids"][9:10], spaces_between_special_tokens=False), " she")
+
+        self.assertEqual(
+            encoding["entity_ids"],
+            [
+                tokenizer.entity_vocab["Ana Ivanovic"],
+                tokenizer.entity_vocab["Thursday"],
+                tokenizer.entity_vocab["[UNK]"],
+            ],
+        )
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+            ]
+        )
+        # fmt: on
+
+    def test_single_text_only_entity_spans_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True)
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
+        spans = [(9, 21), (30, 38), (39, 42)]
+
+        encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
+
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed Ana Ivanovic said on Thursday she could hardly believe her luck.</s>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][3:6], spaces_between_special_tokens=False), " Ana Ivanovic"
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][8:9], spaces_between_special_tokens=False), " Thursday"
+        )
+        self.assertEqual(tokenizer.decode(encoding["input_ids"][9:10], spaces_between_special_tokens=False), " she")
+
+        mask_id = tokenizer.entity_vocab["[MASK]"]
+        self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id])
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+            ]
+        )
+        # fmt: on
+
+    def test_single_text_padding_pytorch_tensors(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True)
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
+        entities = ["Ana Ivanovic", "Thursday", "Dummy Entity"]
+        spans = [(9, 21), (30, 38), (39, 42)]
+
+        encoding = tokenizer(
+            sentence,
+            entities=entities,
+            entity_spans=spans,
+            return_token_type_ids=True,
+            padding="max_length",
+            max_length=30,
+            max_entity_length=16,
+            return_tensors="pt",
+        )
+
+        # test words
+        self.assertEqual(encoding["input_ids"].shape, (1, 30))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 30))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
+
+        # test entities
+        self.assertEqual(encoding["entity_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
+
+    def test_text_pair_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True)
+        sentence = "Top seed Ana Ivanovic said on Thursday"
+        sentence_pair = "She could hardly believe her luck."
+        entities = ["Ana Ivanovic", "Thursday"]
+        entities_pair = ["Dummy Entity"]
+        spans = [(9, 21), (30, 38)]
+        spans_pair = [(0, 3)]
+
+        encoding = tokenizer(
+            sentence,
+            sentence_pair,
+            entities=entities,
+            entities_pair=entities_pair,
+            entity_spans=spans,
+            entity_spans_pair=spans_pair,
+            return_token_type_ids=True,
+        )
+
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed Ana Ivanovic said on Thursday</s></s>She could hardly believe her luck.</s>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][3:6], spaces_between_special_tokens=False), " Ana Ivanovic"
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][8:9], spaces_between_special_tokens=False), " Thursday"
+        )
+        self.assertEqual(tokenizer.decode(encoding["input_ids"][11:12], spaces_between_special_tokens=False), "She")
+
+        self.assertEqual(
+            encoding["entity_ids"],
+            [
+                tokenizer.entity_vocab["Ana Ivanovic"],
+                tokenizer.entity_vocab["Thursday"],
+                tokenizer.entity_vocab["[UNK]"],
+            ],
+        )
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+            ]
+        )
+        # fmt: on
+
+    def test_text_pair_only_entity_spans_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True)
+        sentence = "Top seed Ana Ivanovic said on Thursday"
+        sentence_pair = "She could hardly believe her luck."
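+        # spans index into their own sequence: (9, 21) and (30, 38) refer to `sentence`, (0, 3) to "She" in `sentence_pair`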
+        spans = [(9, 21), (30, 38)]
+        spans_pair = [(0, 3)]
+
+        encoding = tokenizer(
+            sentence,
+            sentence_pair,
+            entity_spans=spans,
+            entity_spans_pair=spans_pair,
+            return_token_type_ids=True,
+        )
+
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed Ana Ivanovic said on Thursday</s></s>She could hardly believe her luck.</s>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][3:6], spaces_between_special_tokens=False), " Ana Ivanovic"
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][8:9], spaces_between_special_tokens=False), " Thursday"
+        )
+        self.assertEqual(tokenizer.decode(encoding["input_ids"][11:12], spaces_between_special_tokens=False), "She")
+
+        mask_id = tokenizer.entity_vocab["[MASK]"]
+        self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id])
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+            ]
+        )
+        # fmt: on
+
+    def test_text_pair_padding_pytorch_tensors(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True)
+        sentence = "Top seed Ana Ivanovic said on Thursday"
+        sentence_pair = "She could hardly believe her luck."
+        entities = ["Ana Ivanovic", "Thursday"]
+        entities_pair = ["Dummy Entity"]
+        spans = [(9, 21), (30, 38)]
+        spans_pair = [(0, 3)]
+
+        encoding = tokenizer(
+            sentence,
+            sentence_pair,
+            entities=entities,
+            entities_pair=entities_pair,
+            entity_spans=spans,
+            entity_spans_pair=spans_pair,
+            return_token_type_ids=True,
+            padding="max_length",
+            max_length=30,
+            max_entity_length=16,
+            return_tensors="pt",
+        )
+
+        # test words
+        self.assertEqual(encoding["input_ids"].shape, (1, 30))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 30))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
+
+        # test entities
+        self.assertEqual(encoding["entity_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
+
+    def test_entity_classification_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
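+        # character-level span of "she" in the sentence above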
+        span = (39, 42)
+
+        encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True)
+
+        # test words
+        self.assertEqual(len(encoding["input_ids"]), 42)
+        self.assertEqual(len(encoding["attention_mask"]), 42)
+        self.assertEqual(len(encoding["token_type_ids"]), 42)
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed Ana Ivanovic said on Thursday<ent> she<ent> could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.</s>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][9:12], spaces_between_special_tokens=False), "<ent> she<ent>"
+        )
+
+        # test entities
+        # the single <ent> span is encoded as the [MASK] entity (id 2)
+        self.assertEqual(encoding["entity_ids"], [2])
+        self.assertEqual(encoding["entity_attention_mask"], [1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
+            ]
+        )
+        # fmt: on
+
+    def test_entity_classification_padding_pytorch_tensors(self):
+        tokenizer = LukeTokenizer.from_pretrained(
+            "studio-ousia/luke-base", task="entity_classification", return_token_type_ids=True
+        )
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+        # entity information
+        span = (39, 42)
+
+        encoding = tokenizer(
+            sentence, entity_spans=[span], return_token_type_ids=True, padding="max_length", return_tensors="pt"
+        )
+
+        # test words
+        self.assertEqual(encoding["input_ids"].shape, (1, 512))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 512))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 512))
+
+        # test entities
+        self.assertEqual(encoding["entity_ids"].shape, (1, 1))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 1))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 1))
+        self.assertEqual(
+            encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
+        )
+
+    def test_entity_pair_classification_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained(
+            "studio-ousia/luke-base", task="entity_pair_classification", return_token_type_ids=True
+        )
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
+        # head and tail information
+        spans = [(9, 21), (39, 42)]
+
+        encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
+
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed<ent> Ana Ivanovic<ent> said on Thursday<ent2> she<ent2> could hardly believe her luck.</s>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][3:8], spaces_between_special_tokens=False),
+            "<ent> Ana Ivanovic<ent>",
+        )
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"][11:14], spaces_between_special_tokens=False), "<ent2> she<ent2>"
+        )
+
+        # head and tail spans map to the [MASK] and [MASK2] entities (ids 2 and 3)
+        self.assertEqual(encoding["entity_ids"], [2, 3])
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+            ]
+        )
+        # fmt: on
+
+    def test_entity_pair_classification_padding_pytorch_tensors(self):
+        tokenizer = LukeTokenizer.from_pretrained(
+            "studio-ousia/luke-base", task="entity_pair_classification", return_token_type_ids=True
+        )
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
+        # head and tail information
+        spans = [(9, 21), (39, 42)]
+
+        encoding = tokenizer(
+            sentence,
+            entity_spans=spans,
+            return_token_type_ids=True,
+            padding="max_length",
+            max_length=30,
+            return_tensors="pt",
+        )
+
+        # test words
+        self.assertEqual(encoding["input_ids"].shape, (1, 30))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 30))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
+
+        # test entities
+        self.assertEqual(encoding["entity_ids"].shape, (1, 2))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 2))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 2))
+        self.assertEqual(
+            encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
+        )
+
+    def test_entity_span_classification_no_padding_or_truncation(self):
+        tokenizer = LukeTokenizer.from_pretrained(
+            "studio-ousia/luke-base", task="entity_span_classification", return_token_type_ids=True
+        )
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
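+        # character-level spans: (0, 8) -> "Top seed", (9, 21) -> "Ana Ivanovic", (39, 42) -> "she"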
+        spans = [(0, 8), (9, 21), (39, 42)]
+
+        encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
+
+        self.assertEqual(
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s>Top seed Ana Ivanovic said on Thursday she could hardly believe her luck.</s>",
+        )
+
+        # every span is encoded as the [MASK] entity (id 2)
+        self.assertEqual(encoding["entity_ids"], [2, 2, 2])
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
+        # fmt: off
+        self.assertEqual(
+            encoding["entity_position_ids"],
+            [
+                [1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+            ]
+        )
+        # fmt: on
+        self.assertEqual(encoding["entity_start_positions"], [1, 3, 9])
+        self.assertEqual(encoding["entity_end_positions"], [2, 5, 9])
+
+    def test_entity_span_classification_padding_pytorch_tensors(self):
+        tokenizer = LukeTokenizer.from_pretrained(
+            "studio-ousia/luke-base", task="entity_span_classification", return_token_type_ids=True
+        )
+        sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck."
+        spans = [(0, 8), (9, 21), (39, 42)]
+
+        encoding = tokenizer(
+            sentence,
+            entity_spans=spans,
+            return_token_type_ids=True,
+            padding="max_length",
+            max_length=30,
+            max_entity_length=16,
+            return_tensors="pt",
+        )
+
+        # test words
+        self.assertEqual(encoding["input_ids"].shape, (1, 30))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 30))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
+
+        # test entities
+        self.assertEqual(encoding["entity_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
+        self.assertEqual(encoding["entity_start_positions"].shape, (1, 16))
+        self.assertEqual(encoding["entity_end_positions"].shape, (1, 16))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index bd6c9af45b23e2..019a30893db5c5 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -89,6 +89,9 @@
     "DPRSpanPredictor",
     "FlaubertForQuestionAnswering",
     "GPT2DoubleHeadsModel",
+    "LukeForEntityClassification",
+    "LukeForEntityPairClassification",
+    "LukeForEntitySpanClassification",
     "OpenAIGPTDoubleHeadsModel",
     "RagModel",
     "RagSequenceForGeneration",