diff --git a/.circleci/config.yml b/.circleci/config.yml index ec9c5741fb24a1..9435b90f27862d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -277,7 +277,7 @@ jobs: - v0.4-custom_tokenizers-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[ja,testing,sentencepiece] + - run: pip install .[ja,testing,sentencepiece,jieba] - run: python -m unidic download - save_cache: key: v0.4-custom_tokenizers-{{ checksum "setup.py" }} diff --git a/README.md b/README.md index 372492d329e81b..18b2eff45b6cdf 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[BORT](https://huggingface.co/transformers/model_doc/bort.html)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry. 1. **[CamemBERT](https://huggingface.co/transformers/model_doc/camembert.html)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. 1. **[ConvBERT](https://huggingface.co/transformers/model_doc/convbert.html)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. +1. **[CPM](https://huggingface.co/transformers/model_doc/cpm.html)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. 1. **[CTRL](https://huggingface.co/transformers/model_doc/ctrl.html)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[DeBERTa](https://huggingface.co/transformers/model_doc/deberta.html)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/transformers/model_doc/deberta_v2.html)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. diff --git a/docs/source/index.rst b/docs/source/index.rst index 6bb157ce988982..ebf09989e682e3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -114,128 +114,133 @@ and conversion utilities for the following models: 11. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with Span-based Dynamic Convolution `__ by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. -12. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language +12. 
:doc:`CPM ` (from Tsinghua University) released with the paper `CPM: A Large-scale Generative + Chinese Pre-trained Language Model `__ by Zhengyan Zhang, Xu Han, Hao Zhou, Pei + Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, + Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, + Juanzi Li, Xiaoyan Zhu, Maosong Sun. +13. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language Model for Controllable Generation `__ by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. -13. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with +14. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -14. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT +15. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -15. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +16. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -16. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +17. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -17. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +18. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -18. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +19. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -19. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +20. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -20. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +21. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -21. 
:doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +22. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -22. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +23. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -23. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo +24. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. -24. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +25. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -25. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +26. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -26. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +27. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -27. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +28. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -28. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +29. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -29. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +30. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -30. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +31. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -31. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +32. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -32. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +33. 
:doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -33. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training +34. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -34. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training +35. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -35. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +36. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -36. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +37. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -37. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +38. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -38. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +39. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -39. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +40. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -40. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +41. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -41. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +42. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -42. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +43. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -43. 
:doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +44. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -44. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +45. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -45. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +46. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -46. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +47. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -47. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +48. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -48. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +49. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -49. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +50. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -50. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +51. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -51. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +52. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -52. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +53. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -437,6 +442,7 @@ TensorFlow and/or Flax. 
model_doc/bort model_doc/camembert model_doc/convbert + model_doc/cpm model_doc/ctrl model_doc/deberta model_doc/deberta_v2 diff --git a/docs/source/model_doc/cpm.rst b/docs/source/model_doc/cpm.rst new file mode 100644 index 00000000000000..e1380f4a933d4b --- /dev/null +++ b/docs/source/model_doc/cpm.rst @@ -0,0 +1,44 @@ +.. + Copyright 2020 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +CPM +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The CPM model was proposed in `CPM: A Large-scale Generative Chinese Pre-trained Language Model +`__ by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, +Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, +Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. + +The abstract from the paper is the following: + +*Pre-trained Language Models (PLMs) have proven to be beneficial for various downstream NLP tasks. Recently, GPT-3, +with 175 billion parameters and 570GB training data, drew a lot of attention due to the capacity of few-shot (even +zero-shot) learning. However, applying GPT-3 to address Chinese NLP tasks is still challenging, as the training corpus +of GPT-3 is primarily English, and the parameters are not publicly available. In this technical report, we release the +Chinese Pre-trained Language Model (CPM) with generative pre-training on large-scale Chinese training data. To the best +of our knowledge, CPM, with 2.6 billion parameters and 100GB Chinese training data, is the largest Chinese pre-trained +language model, which could facilitate several downstream Chinese NLP tasks, such as conversation, essay generation, +cloze test, and language understanding. Extensive experiments demonstrate that CPM achieves strong performance on many +NLP tasks in the settings of few-shot (even zero-shot) learning.* + +The original implementation can be found here: https://github.com/TsinghuaAI/CPM-Generate + +Note: We only have a tokenizer here, since the model architecture is the same as GPT-2. + +CpmTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.CpmTokenizer + :members: diff --git a/setup.py b/setup.py index 1b2ab5bf3187ce..c403f1f33af1b5 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,7 @@ "isort>=5.5.4", "jax>=0.2.8", "jaxlib>=0.1.59", + "jieba", "keras2onnx", "nltk", "numpy>=1.17", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9108904b9c92b6..f71e075eaaed40 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -163,6 +163,7 @@ ], "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], + "models.cpm": ["CpmTokenizer"], "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], @@ -1501,6 +1502,7 @@ ) from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer + from .models.cpm import CpmTokenizer from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index bd070d7bdf254f..82968ff299491a 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -21,6 +21,7 @@ "isort": "isort>=5.5.4", "jax": "jax>=0.2.8", "jaxlib": "jaxlib>=0.1.59", + "jieba": "jieba", "keras2onnx": "keras2onnx", "nltk": "nltk", "numpy": "numpy>=1.17", diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 97b8c8de890faa..0092c46a976768 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -30,6 +30,7 @@ blenderbot_small, camembert, convbert, + cpm, ctrl, deberta, dialogpt, diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 212c32cb4a6d83..13089e21171c0e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -115,6 +115,7 @@ from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer from ..big_bird.tokenization_big_bird import BigBirdTokenizer from ..camembert.tokenization_camembert import CamembertTokenizer + from ..cpm.tokenization_cpm import CpmTokenizer from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer from ..m2m_100 import M2M100Tokenizer from ..marian.tokenization_marian import MarianTokenizer @@ -134,6 +135,7 @@ BertGenerationTokenizer = None BigBirdTokenizer = None CamembertTokenizer = None + CpmTokenizer = None DebertaV2Tokenizer = None MarianTokenizer = None MBartTokenizer = None @@ -273,6 +275,7 @@ NO_CONFIG_TOKENIZER = [ BertJapaneseTokenizer, BertweetTokenizer, + CpmTokenizer, HerbertTokenizer, HerbertTokenizerFast, PhobertTokenizer, diff --git a/src/transformers/models/cpm/__init__.py b/src/transformers/models/cpm/__init__.py new file mode 100644 index 00000000000000..8c687ad8fc56e9 --- /dev/null +++ b/src/transformers/models/cpm/__init__.py @@ -0,0 +1,48 @@ +# flake8: noqa 
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...file_utils import _BaseLazyModule
+
+
+_import_structure = {
+    "tokenization_cpm": ["CpmTokenizer"],
+}
+
+
+if TYPE_CHECKING:
+    from .tokenization_cpm import CpmTokenizer
+
+else:
+    import importlib
+    import os
+    import sys
+
+    class _LazyModule(_BaseLazyModule):
+        """
+        Module class that surfaces all objects but only performs associated imports when the objects are requested.
+        """
+
+        __file__ = globals()["__file__"]
+        __path__ = [os.path.dirname(__file__)]
+
+        def _get_module(self, module_name: str):
+            return importlib.import_module("." + module_name, self.__name__)
+
+    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
diff --git a/src/transformers/models/cpm/tokenization_cpm.py b/src/transformers/models/cpm/tokenization_cpm.py
new file mode 100644
index 00000000000000..447b86b1294363
--- /dev/null
+++ b/src/transformers/models/cpm/tokenization_cpm.py
@@ -0,0 +1,109 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes."""
+from ...utils import logging
+from ..xlnet.tokenization_xlnet import XLNetTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "TsinghuaAI/CPM-Generate": "https://huggingface.co/TsinghuaAI/CPM-Generate/resolve/main/spiece.model",
+    }
+}
+
+
+class CpmTokenizer(XLNetTokenizer):
+    """Runs pre-tokenization with the Jieba segmentation tool. It is used in CPM models."""
+
+    def __init__(self, *args, **kwargs):
+        """
+        Construct a CPM tokenizer. Based on `Jieba <https://pypi.org/project/jieba/>` and `SentencePiece
+        <https://github.com/google/sentencepiece>`__.
+
+        This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main
+        methods. Users should refer to this superclass for more information regarding those methods.
+
+        Args:
+            vocab_file (:obj:`str`):
+                `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension) that
+                contains the vocabulary necessary to instantiate a tokenizer.
+            do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether to lowercase the input when tokenizing.
+            remove_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether to strip the text when tokenizing (removing excess spaces before and after the string).
+            keep_accents (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether to keep accents when tokenizing.
+            bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
+                The beginning of sequence token that was used during pretraining. Can be used as a sequence
+                classifier token.
+
+                .. note::
+
+                    When building a sequence using special tokens, this is not the token that is used for the beginning
+                    of sequence. The token used is the :obj:`cls_token`.
+            eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
+                The end of sequence token.
+
+                .. note::
+
+                    When building a sequence using special tokens, this is not the token that is used for the end of
+                    sequence. The token used is the :obj:`sep_token`.
+            unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
+                The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+                this token instead.
+            sep_token (:obj:`str`, `optional`, defaults to :obj:`"<sep>"`):
+                The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+                for sequence classification or for a text and a question for question answering. It is also used as the
+                last token of a sequence built with special tokens.
+            pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
+                The token used for padding, for example when batching sequences of different lengths.
+            cls_token (:obj:`str`, `optional`, defaults to :obj:`"<cls>"`):
+                The classifier token which is used when doing sequence classification (classification of the whole
+                sequence instead of per-token classification). It is the first token of the sequence when built with
+                special tokens.
+            mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
+                The token used for masking values. This is the token used when training this model with masked language
+                modeling. This is the token which the model will try to predict.
+            additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<eop>", "<eod>"]`):
+                Additional special tokens used by the tokenizer.
+
+        Attributes:
+            sp_model (:obj:`SentencePieceProcessor`):
+                The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
+        """
+        super().__init__(*args, **kwargs)
+        try:
+            import jieba
+        except ModuleNotFoundError as error:
+            raise error.__class__(
+                "You need to install jieba to use CpmTokenizer. "
+                "See https://pypi.org/project/jieba/ for installation."
+            )
+        self.jieba = jieba
+        self.translator = str.maketrans(" \n", "\u2582\u2583")
+
+    def _tokenize(self, text, *args, **kwargs):
+        text = [x.translate(self.translator) for x in self.jieba.cut(text, cut_all=False)]
+        text = " ".join(text)
+        return super()._tokenize(text, *args, **kwargs)
+
+    def _decode(self, *args, **kwargs):
+        text = super()._decode(*args, **kwargs)
+        text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
+        return text
diff --git a/tests/test_tokenization_cpm.py b/tests/test_tokenization_cpm.py
new file mode 100644
index 00000000000000..c65e8f07528d0e
--- /dev/null
+++ b/tests/test_tokenization_cpm.py
@@ -0,0 +1,39 @@
+# coding=utf-8
+# Copyright 2018 HuggingFace Inc. team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from transformers.models.cpm.tokenization_cpm import CpmTokenizer
+from transformers.testing_utils import custom_tokenizers
+
+from .test_modeling_xlnet import XLNetModelTest
+
+
+@custom_tokenizers
+class CpmTokenizationTest(XLNetModelTest):
+    def test_pre_tokenization(self):
+        tokenizer = CpmTokenizer.from_pretrained("TsinghuaAI/CPM-Generate")
+        text = "Hugging Face大法好,谁用谁知道。"
+        normalized_text = "Hugging Face大法好,谁用谁知道。<unk>"
+        bpe_tokens = "▁Hu gg ing ▁ ▂ ▁F ace ▁大法 ▁好 ▁ , ▁谁 ▁用 ▁谁 ▁知 道 ▁ 。".split()
+
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + [tokenizer.unk_token]
+
+        input_bpe_tokens = [13789, 13283, 1421, 8, 10, 1164, 13608, 16528, 63, 8, 9, 440, 108, 440, 121, 90, 8, 12, 0]
+        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+        reconstructed_text = tokenizer.decode(input_bpe_tokens)
+        self.assertEqual(reconstructed_text, normalized_text)
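
Usage sketch (not part of the diff above): a minimal round trip with the new tokenizer, mirroring what test_pre_tokenization checks. It assumes jieba and sentencepiece are installed; the checkpoint name comes from PRETRAINED_VOCAB_FILES_MAP in tokenization_cpm.py.

from transformers import CpmTokenizer

tokenizer = CpmTokenizer.from_pretrained("TsinghuaAI/CPM-Generate")

text = "Hugging Face大法好,谁用谁知道。"

# _tokenize() segments with jieba first, mapping spaces and newlines to the placeholders
# \u2582 and \u2583 so they survive SentencePiece, then applies the inherited XLNet tokenization.
tokens = tokenizer.tokenize(text)
ids = tokenizer.convert_tokens_to_ids(tokens)

# _decode() drops the artificial token-joining spaces and maps the placeholders back to
# a real space and newline, so the original sentence is recovered.
print(tokenizer.decode(ids))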
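
The note in docs/source/model_doc/cpm.rst says only a tokenizer is added because the model architecture is the same as GPT-2. For context, generation would therefore pair this tokenizer with the existing GPT-2 classes; the following is a sketch only, and it assumes (this diff does not establish it) that the TsinghuaAI/CPM-Generate checkpoint on the Hub is stored with a GPT-2 config.

import torch
from transformers import CpmTokenizer, GPT2LMHeadModel

# Assumption: the Hub checkpoint uses the GPT-2 architecture, per the note in cpm.rst.
tokenizer = CpmTokenizer.from_pretrained("TsinghuaAI/CPM-Generate")
model = GPT2LMHeadModel.from_pretrained("TsinghuaAI/CPM-Generate")

inputs = tokenizer("Hugging Face大法好,", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(inputs["input_ids"], max_length=32, do_sample=True, top_p=0.9)
print(tokenizer.decode(output_ids[0]))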