From f5f430e5c80b85b57bb910435e45d84746210133 Mon Sep 17 00:00:00 2001 From: Shinya Otani <67080255+SO0529@users.noreply.github.com> Date: Wed, 14 Sep 2022 23:17:40 +0900 Subject: [PATCH] Add support for Japanese GPT-NeoX-based model by ABEJA, Inc. (#18814) * add gpt-neox-japanese model and tokenizer as new model * Correction to PR's comment for GPT NeoX Japanese - Fix to be able to use gpu - Add comment # Copied... at the top of RotaryEmbedding - Implement nn.Linear instead of original linear class - Add generation test under @slow * fix bias treatment for gpt-neox-japanese * Modidy gpt-neox-japanese following PR - add doc for bias_dropout_add - style change following a PR comment * add document for gpt-neox-japanese * remove unused import from gpt-neox-japanese * fix README for gpt-neox-japanese --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.mdx | 2 + .../source/en/model_doc/gpt_neox_japanese.mdx | 66 ++ src/transformers/__init__.py | 20 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/tokenization_auto.py | 1 + .../models/gpt_neox_japanese/__init__.py | 66 ++ .../configuration_gpt_neox_japanese.py | 124 +++ .../modeling_gpt_neox_japanese.py | 724 ++++++++++++++++++ .../tokenization_gpt_neox_japanese.py | 379 +++++++++ src/transformers/utils/dummy_pt_objects.py | 31 + .../utils/dummy_tokenizers_objects.py | 7 + tests/models/gpt_neox_japanese/__init__.py | 0 .../test_modeling_gpt_neox_japanese.py | 255 ++++++ .../test_tokenization_gpt_neox_japanese.py | 137 ++++ 21 files changed, 1825 insertions(+) create mode 100644 docs/source/en/model_doc/gpt_neox_japanese.mdx create mode 100644 src/transformers/models/gpt_neox_japanese/__init__.py create mode 100644 src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py create mode 100755 src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py create mode 100644 src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py create mode 100644 tests/models/gpt_neox_japanese/__init__.py create mode 100644 tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py create mode 100644 tests/models/gpt_neox_japanese/test_tokenization_gpt_neox_japanese.py diff --git a/README.md b/README.md index 570f12ac44e0..abadb9f57406 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. 
**[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach +1. **[GPT NeoX Japanese](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox_japanese)** (from ABEJA) released by Shinya Otani, Takayoshi Makabe, Anuj Arora, and Kyo Hattori. 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)** (from UCSD, NVIDIA) released with the paper [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) by Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang. diff --git a/README_ko.md b/README_ko.md index c6016624a861..2a42664e20d7 100644 --- a/README_ko.md +++ b/README_ko.md @@ -257,6 +257,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach +1. **[GPT NeoX Japanese](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox_japanese)** (from ABEJA) released by Shinya Otani, Takayoshi Makabe, Anuj Arora, and Kyo Hattori. 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. 
**[GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)** (from UCSD, NVIDIA) released with the paper [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) by Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang. diff --git a/README_zh-hans.md b/README_zh-hans.md index f3c07bfb361d..7d0642cc9abf 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -281,6 +281,7 @@ conda install -c huggingface transformers 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (来自 OpenAI) 伴随论文 [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) 由 Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever 发布。 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (来自 EleutherAI) 随仓库 [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) 发布。作者为 Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy 发布。 1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach +1. **[GPT NeoX Japanese](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox_japanese)** (来自 ABEJA) 由 Shinya Otani, Takayoshi Makabe, Anuj Arora, Kyo Hattori。 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (来自 OpenAI) 伴随论文 [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) 由 Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever** 发布。 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (来自 EleutherAI) 伴随论文 [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) 由 Ben Wang and Aran Komatsuzaki 发布。 1. **[GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)** (来自 UCSD, NVIDIA) 伴随论文 [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) 由 Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 2ef861d05923..c4de5181002d 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -293,6 +293,7 @@ conda install -c huggingface transformers 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. 
**[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach +1. **[GPT NeoX Japanese](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox_japanese)** (from ABEJA) released by Shinya Otani, Takayoshi Makabe, Anuj Arora, and Kyo Hattori. 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released with the paper [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)** (from UCSD, NVIDIA) released with the paper [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) by Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b59fdfbc46d9..c21388b60a6f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -253,6 +253,8 @@ title: GPT Neo - local: model_doc/gpt_neox title: GPT NeoX + - local: model_doc/gpt_neox_japanese + title: GPT NeoX Japanese - local: model_doc/gptj title: GPT-J - local: model_doc/gpt2 diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 265fe39f25fd..f118359bc57b 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -97,6 +97,7 @@ The documentation is organized into five sections: 1. **[GPT](model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT Neo](model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[GPT NeoX](model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach +1. **[GPT NeoX Japanese](model_doc/gpt_neox_japanese)** (from ABEJA) released by Shinya Otani, Takayoshi Makabe, Anuj Arora, and Kyo Hattori. 1. **[GPT-2](model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. 
**[GPT-J](model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[GroupViT](model_doc/groupvit)** (from UCSD, NVIDIA) released with the paper [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) by Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang. @@ -242,6 +243,7 @@ Flax), PyTorch, and/or TensorFlow. | GLPN | ❌ | ❌ | ✅ | ❌ | ❌ | | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ | | GPT NeoX | ❌ | ✅ | ✅ | ❌ | ❌ | +| GPT NeoX Japanese | ✅ | ❌ | ✅ | ❌ | ❌ | | GPT-J | ❌ | ❌ | ✅ | ✅ | ✅ | | GroupViT | ❌ | ❌ | ✅ | ❌ | ❌ | | Hubert | ❌ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/gpt_neox_japanese.mdx b/docs/source/en/model_doc/gpt_neox_japanese.mdx new file mode 100644 index 000000000000..da94b7497603 --- /dev/null +++ b/docs/source/en/model_doc/gpt_neox_japanese.mdx @@ -0,0 +1,66 @@ + + +# GPT-NeoX-Japanese + +## Overview + +We introduce GPT-NeoX-Japanese, which is an autoregressive language model for Japanese, trained on top of [https://github.com/EleutherAI/gpt-neox](https://github.com/EleutherAI/gpt-neox). +Japanese is a unique language with its large vocabulary and a combination of hiragana, katakana, and kanji writing scripts. +To address this distinct structure of the Japanese language, we use a [special sub-word tokenizer](https://github.com/tanreinama/Japanese-BPEEncoder_V2). We are very grateful to *tanreinama* for open-sourcing this incredibly helpful tokenizer. +Following the recommendations from Google's research on [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html), we have removed bias parameters from transformer blocks, achieving better model performance. Please refer [this article](https://medium.com/ml-abeja/training-a-better-gpt-2-93b157662ae4) in detail. + +Development of the model was led by [Shinya Otani](https://github.com/SO0529), [Takayoshi Makabe](https://github.com/spider-man-tm), [Anuj Arora](https://github.com/Anuj040), and [Kyo Hattori](https://github.com/go5paopao) from [ABEJA, Inc.](https://www.abejainc.com/). For more information on this model-building activity, please refer [here (ja)](https://tech-blog.abeja.asia/entry/abeja-gpt-project-202207). + +### Generation + +The `generate()` method can be used to generate text using GPT NeoX Japanese model. + +```python +>>> from transformers import GPTNeoXJapaneseForCausalLM, GPTNeoXJapaneseTokenizer + +>>> model = GPTNeoXJapaneseForCausalLM.from_pretrained("abeja/gpt-neox-japanese-2.7b") +>>> tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b") + +>>> prompt = "人とAIが協調するためには、" + +>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +>>> gen_tokens = model.generate( +... input_ids, +... do_sample=True, +... temperature=0.9, +... max_length=100, +... 
) +>>> gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0] + +>>> print(gen_text) +人とAIが協調するためには、AIと人が共存し、AIを正しく理解する必要があります。 +``` + +## GPTNeoXJapaneseConfig + +[[autodoc]] GPTNeoXJapaneseConfig + +## GPTNeoXJapaneseTokenizer + +[[autodoc]] GPTNeoXJapaneseTokenizer + +## GPTNeoXJapaneseModel + +[[autodoc]] GPTNeoXJapaneseModel + - forward + +## GPTNeoXJapaneseForCausalLM + +[[autodoc]] GPTNeoXJapaneseForCausalLM + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2671b37d8ebd..c6fd12595edf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -225,6 +225,7 @@ "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"], + "models.gpt_neox_japanese": ["GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXJapaneseConfig"], "models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"], "models.groupvit": [ "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -558,6 +559,7 @@ _import_structure["models.funnel"].append("FunnelTokenizerFast") _import_structure["models.gpt2"].append("GPT2TokenizerFast") _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast") + _import_structure["models.gpt_neox_japanese"].append("GPTNeoXJapaneseTokenizer") _import_structure["models.herbert"].append("HerbertTokenizerFast") _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast") @@ -1291,6 +1293,15 @@ "GPTNeoXPreTrainedModel", ] ) + _import_structure["models.gpt_neox_japanese"].extend( + [ + "GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTNeoXJapaneseForCausalLM", + "GPTNeoXJapaneseLayer", + "GPTNeoXJapaneseModel", + "GPTNeoXJapanesePreTrainedModel", + ] + ) _import_structure["models.gptj"].extend( [ "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3114,6 +3125,7 @@ from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig + from .models.gpt_neox_japanese import GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXJapaneseConfig from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig from .models.groupvit import ( GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -3412,6 +3424,7 @@ from .models.funnel import FunnelTokenizerFast from .models.gpt2 import GPT2TokenizerFast from .models.gpt_neox import GPTNeoXTokenizerFast + from .models.gpt_neox_japanese import GPTNeoXJapaneseTokenizer from .models.herbert import HerbertTokenizerFast from .models.layoutlm import LayoutLMTokenizerFast from .models.layoutlmv2 import LayoutLMv2TokenizerFast @@ -4008,6 +4021,13 @@ GPTNeoXModel, GPTNeoXPreTrainedModel, ) + from .models.gpt_neox_japanese import ( + GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTNeoXJapaneseForCausalLM, + GPTNeoXJapaneseLayer, + GPTNeoXJapaneseModel, + GPTNeoXJapanesePreTrainedModel, + ) from .models.gptj import ( GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST, GPTJForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 6a206bb96842..fbdbfd579cb9 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -68,6 +68,7 @@ gpt2, gpt_neo, gpt_neox, + gpt_neox_japanese, 
gptj, groupvit, herbert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index ae0e88bd4a1e..1204e6608a76 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -72,6 +72,7 @@ ("gpt2", "GPT2Config"), ("gpt_neo", "GPTNeoConfig"), ("gpt_neox", "GPTNeoXConfig"), + ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"), ("gptj", "GPTJConfig"), ("groupvit", "GroupViTConfig"), ("hubert", "HubertConfig"), @@ -201,6 +202,7 @@ ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neox", "GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt_neox_japanese", "GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("groupvit", "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -331,6 +333,7 @@ ("gpt2", "OpenAI GPT-2"), ("gpt_neo", "GPT Neo"), ("gpt_neox", "GPT NeoX"), + ("gpt_neox_japanese", "GPT NeoX Japanese"), ("gptj", "GPT-J"), ("groupvit", "GroupViT"), ("herbert", "HerBERT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 9edfae0c89be..caa4c9d4dffd 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -71,6 +71,7 @@ ("gpt2", "GPT2Model"), ("gpt_neo", "GPTNeoModel"), ("gpt_neox", "GPTNeoXModel"), + ("gpt_neox_japanese", "GPTNeoXJapaneseModel"), ("gptj", "GPTJModel"), ("groupvit", "GroupViTModel"), ("hubert", "HubertModel"), @@ -234,6 +235,7 @@ ("gpt2", "GPT2LMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), + ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("ibert", "IBertForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"), @@ -292,6 +294,7 @@ ("gpt2", "GPT2LMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), + ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("marian", "MarianForCausalLM"), ("mbart", "MBartForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7aa7627cab2c..97e048885e18 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -129,6 +129,7 @@ ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/gpt_neox_japanese/__init__.py b/src/transformers/models/gpt_neox_japanese/__init__.py new file mode 100644 index 000000000000..0d18143c0f02 --- /dev/null +++ b/src/transformers/models/gpt_neox_japanese/__init__.py @@ -0,0 +1,66 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable + + +_import_structure = { + "configuration_gpt_neox_japanese": ["GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXJapaneseConfig"], + "tokenization_gpt_neox_japanese": ["GPTNeoXJapaneseTokenizer"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gpt_neox_japanese"] = [ + "GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTNeoXJapaneseForCausalLM", + "GPTNeoXJapaneseLayer", + "GPTNeoXJapaneseModel", + "GPTNeoXJapanesePreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_gpt_neox_japanese import GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXJapaneseConfig + from .tokenization_gpt_neox_japanese import GPTNeoXJapaneseTokenizer + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gpt_neox_japanese import ( + GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTNeoXJapaneseForCausalLM, + GPTNeoXJapaneseLayer, + GPTNeoXJapaneseModel, + GPTNeoXJapanesePreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py new file mode 100644 index 000000000000..749a9400b961 --- /dev/null +++ b/src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPTNeoX Japanese model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "abeja/gpt-neox-japanese-2.7b": "https://huggingface.co/abeja/gpt-neox-japanese-2.7b/resolve/main/config.json", +} + + +class GPTNeoXJapaneseConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GPTNeoXModelJapanese`]. 
It is used to instantiate + a GPTNeoX model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPTNeoXJapanese + [abeja/gpt-neox-japanese-2.7b](https://huggingface.co/abeja/gpt-neox-japanese-2.7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. The default configuration is set to that of the 2.7B model. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the GPTNeoXJapanese model. Defines the number of different tokens that can be + represented by the `input_ids` passed when calling [`GPTNeoXJapanese`]. + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_multiple_size (`int`, *optional*, defaults to 4): + Dimension of the "intermediate" layer in the Transformer encoder, calculated as hidden_size * + intermediate_multiple_size. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + rotary_pct (`float`, *optional*, defaults to 1.00): + Percentage of hidden dimensions to allocate to rotary embeddings. + rotary_emb_base (`int`, *optional*, defaults to 10000): + Base for computing the rotary embedding frequencies. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + weight_tying (`bool`, *optional*, defaults to `True`): + Whether or not to tie the input and output embedding weights. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden layer. 
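+ For reference, with these 2.7B defaults the feed-forward (MLP) intermediate dimension works out to `hidden_size * intermediate_multiple_size = 2560 * 4 = 10240`.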
+ Example: + + ```python + >>> from transformers import GPTNeoXJapaneseModel, GPTNeoXJapaneseConfig + + >>> # Initializing a GPTNeoXJapanese gpt-neox-japanese-2.7b style configuration + >>> configuration = GPTNeoXJapaneseConfig() + + >>> # Initializing a model from the gpt-neox-japanese-2.7b style configuration + >>> model = GPTNeoXJapaneseModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "gpt_neox_japanese" + + def __init__( + self, + vocab_size=32000, + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=32, + intermediate_multiple_size=4, + hidden_act="gelu", + rotary_pct=1.00, + rotary_emb_base=10000, + max_position_embeddings=2048, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + bos_token_id=31996, + eos_token_id=31999, + weight_tying=True, + attention_dropout=0.1, + hidden_dropout=0.0, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_multiple_size = intermediate_multiple_size + self.hidden_act = hidden_act + self.rotary_pct = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.weight_tying = weight_tying + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py new file mode 100755 index 000000000000..b79ef4f1e41b --- /dev/null +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -0,0 +1,724 @@ +# coding=utf-8 +# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
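+ # Implementation note: following the PaLM-style design described in the model docs, the linear layers below are created with bias=False; only the attention output of the final layer keeps a learned bias (`dense_bias`).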
+""" PyTorch GPTNeoX model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "abeja/gpt-neox-japanese-2.7b" +_CONFIG_FOR_DOC = "GPTNeoXJapaneseConfig" +_TOKENIZER_FOR_DOC = "GPTNeoXJapaneseTokenizer" + +GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST = { + "https://huggingface.co/abeja/gpt-neox-japanese-2.7b/resolve/main/config.json", + # See all GPTNeoXJapanese models at https://huggingface.co/models?filter=gpt_neox_japanese +} + + +class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTNeoXJapaneseConfig + base_model_prefix = "gpt_neox_japanese" + supports_gradient_checkpointing = True + _no_split_modules = ["GPTNeoXJapaneseLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTNeoXJapaneseModel): + module.gradient_checkpointing = value + + +class GPTNeoXJapaneseAttention(nn.Module): + def __init__(self, config, use_bias=False): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + + self.rotary_ndims = int(self.head_size * config.rotary_pct) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base + ) + self.max_positions = config.max_position_embeddings + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + + self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # Activate bias if the last layer + self.use_bias = use_bias + self.dense_bias = nn.Parameter(torch.zeros(config.hidden_size)) if use_bias else None + + def forward( + self, + hidden_states, + attention_mask, + head_mask=None, + layer_past=None, + use_cache=False, + output_attentions=False, + ): + has_layer_past = layer_past is not None and layer_past[0].numel() > 0 + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + 
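+ # (with the 2.7B defaults: hidden_size=2560, num_attention_heads=32, so head_size=80 and each head gets 3 * 80 = 240 values)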
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + offset = 0 + if has_layer_past: + offset = layer_past[0].shape[-2] + seq_len += offset + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs, self.dense_bias + + @classmethod + def _split_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _create_casual_mask(self, key_length, query_length): + casual_mask = torch.tril( + torch.ones((self.max_positions, self.max_positions), dtype=torch.uint8).view( + 1, 1, self.max_positions, self.max_positions + ) + ) + return casual_mask[:, :, key_length - query_length : key_length, :key_length].bool() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + causal_mask = self._create_casual_mask(key_length, query_length) + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + 
query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) + causal_mask = causal_mask.to(attn_scores.device) + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = self.attention_dropout(attn_weights) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
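+ # If a sequence longer than the cached maximum arrives (e.g. during generation), the cos/sin tables are rebuilt below on the input tensor's device.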
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos = cos[..., offset : q.shape[-2] + offset, :] + sin = sin[..., offset : q.shape[-2] + offset, :] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def bias_dropout_add(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float, training: bool) -> Tensor: + """add bias to x, apply dropout and residual connection + + Args: + x (Tensor): main path of output + bias (Tensor): None or attn_bias of the last attention layer + residual (Optional[Tensor]): residual value + prob (float): dropout probability + training (bool): whether in training mode or not + + Returns: + Tensor: dropout(x + bias) + residual + """ + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + if residual is not None: + out = residual + out + return out + + +class GPTNeoXJapaneseMLP(nn.Module): + def __init__(self, config): + super().__init__() + intermediate_size = int(config.hidden_size * config.intermediate_multiple_size) + self.dense_h_to_4h = nn.Linear(config.hidden_size, intermediate_size, bias=False) + # Project back to h. 
+ self.dense_4h_to_h = nn.Linear(intermediate_size, config.hidden_size, bias=False) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + intermediate = self.dense_h_to_4h(hidden_states) + intermediate = self.act(intermediate) + output = self.dense_4h_to_h(intermediate) + return output + + +class GPTNeoXJapaneseLayer(nn.Module): + def __init__(self, config, layer_number): + super().__init__() + self.layer_number = layer_number + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # activate bias only last layer + self.attention = GPTNeoXJapaneseAttention(config=config, use_bias=layer_number == config.num_hidden_layers - 1) + self.mlp = GPTNeoXJapaneseMLP(config) + self.hidden_dropout = config.hidden_dropout + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + use_cache=False, + layer_past=None, + output_attentions=False, + ): + residual = hidden_states + ln_out = self.input_layernorm(hidden_states) + attention_layer_outputs, attn_bias = self.attention( + ln_out, + attention_mask=attention_mask, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions) + outputs = attention_layer_outputs[1:] + + # attn_output = (atten_output + bias) + residual + attn_output = bias_dropout_add( + attn_output, + bias=attn_bias.expand_as(residual) if attn_bias is not None else attn_bias, + residual=residual, + prob=self.hidden_dropout, + training=self.training, + ) + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + + # attn_output = (mlp_output + mlp_bias) + atten_output + attn_output = bias_dropout_add( + mlp_output, bias=None, residual=attn_output, prob=self.hidden_dropout, training=self.training + ) + + if use_cache: + outputs = (attn_output,) + outputs + else: + outputs = (attn_output,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +GPT_NEOX_JAPANESE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`~GPTNeoXJapaneseConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT_NEOX_JAPANESE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`GPTNeoXJapaneseTokenizer`]. + + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. 
+ + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare GPTNeoXJapanese Model transformer outputting raw hidden-states without any specific head on top.", + GPT_NEOX_JAPANESE_START_DOCSTRING, +) +class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [GPTNeoXJapaneseLayer(config=config, layer_number=i) for i in range(config.num_hidden_layers)] + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_in + + def set_input_embeddings(self, value): + self.embed_in = value + + @add_start_docstrings_to_model_forward(GPT_NEOX_JAPANESE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import GPTNeoXJapaneseTokenizer, GPTNeoXJapaneseModel + >>> import torch + + >>> tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> model = GPTNeoXJapaneseModel.from_pretrained("abeja/gpt-neox-japanese-2.7b") + + >>> inputs = tokenizer("日本語のGPT-neoxがHugging Faceで使えます😀", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values = tuple([None] * self.config.num_hidden_layers) + + # Attention mask. + if attention_mask is not None: + if not batch_size > 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
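+ # For example, a padding position with mask value 0.0 is mapped to torch.finfo(self.dtype).min below, so it receives ~zero weight after the softmax.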
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + outputs = layer( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_layer_norm(hidden_states) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +@add_start_docstrings( + """GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning.""", + GPT_NEOX_JAPANESE_START_DOCSTRING, +) +class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel): + + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.gpt_neox_japanese = GPTNeoXJapaneseModel(config) + self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.embed_out + + def set_output_embeddings(self, new_embeddings): + self.embed_out = new_embeddings + + @add_start_docstrings_to_model_forward(GPT_NEOX_JAPANESE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each 
tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are + only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see + `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import GPTNeoXJapaneseTokenizer, GPTNeoXJapaneseForCausalLM, GPTNeoXJapaneseConfig + >>> import torch + + >>> tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> config = GPTNeoXJapaneseConfig.from_pretrained("abeja/gpt-neox-japanese-2.7b") + >>> config.is_decoder = True + >>> model = GPTNeoXJapaneseForCausalLM.from_pretrained("abeja/gpt-neox-japanese-2.7b", config=config) + + >>> inputs = tokenizer("日本語のGPT-neoxがHugging Faceで使えます😀", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.gpt_neox_japanese( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past and past[0] is not None: + input_ids = input_ids[:, -1:] + 
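+ # the keys/values for earlier positions are already stored in `past`, so only the newest token id needs to be fed through the model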
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past diff --git a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py new file mode 100644 index 000000000000..a132d999a313 --- /dev/null +++ b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py @@ -0,0 +1,379 @@ +# coding=utf-8 +# Copyright 2022 ABEJA, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for GPTNeoXJapanese.""" +import collections +import json +import os +import re +from typing import TYPE_CHECKING, List, Optional, Tuple + +import numpy as np + +from ...tokenization_utils_fast import PreTrainedTokenizer +from ...utils import logging + + +if TYPE_CHECKING: + from transformers.pipelines.conversational import Conversation + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "abeja/gpt-neox-japanese-2.7b": "https://huggingface.co/abeja/gpt-neox-japanese-2.7b/resolve/main/vocab.txt", + }, + "emoji_file": { + "abeja/gpt-neox-japanese-2.7b": "https://huggingface.co/abeja/gpt-neox-japanese-2.7b/resolve/main/emoji.json", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "abeja/gpt-neox-japanese-2.7b": 2048, +} + + +def load_vocab_and_emoji(vocab_file, emoji_file): + """Loads a vocabulary file and emoji file into a dictionary.""" + with open(emoji_file, "r", encoding="utf-8") as f: + emoji = json.loads(f.read()) + + vocab = collections.OrderedDict() + raw_vocab = collections.OrderedDict() + ids_to_tokens = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as f: + token = f.readlines() + token = [[t.rstrip("\n")] if (t == "," or "," not in t) else t.rstrip("\n").split(",") for t in token] + for idx, b in enumerate(token): + ids_to_tokens[idx] = b + raw_vocab[",".join(b)] = idx + for wd in b: + vocab[wd] = idx + + return vocab, raw_vocab, ids_to_tokens, emoji + + +class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer): + """ + This tokenizer inherits from [`PreTrainedTokenizer`] and is based on Japanese special Sub-Word-Encoding that is + used in this repository (https://github.com/tanreinama/Japanese-BPEEncoder_V2). Check the repository for details. + Japanese has a relatively large vocabulary and there is no separation between words. Furthermore, the language is a + combination of hiragana, katakana, and kanji, and variants such as "1" and "①" are often used. 
In order to cope
+    with these, this tokenizer has the following features:
+    - Subword-by-subword segmentation, which is intermediate between byte strings and morphological analysis.
+    - BPEs are created for each Kanji, Hiragana, and Katakana character, and there are no BPEs that cross character
+      types, such as Kanji + Hiragana or Hiragana + Katakana.
+    - All-byte encoding that does not require `<unk>`.
+    - Independent of UTF codes such as 2-byte and 3-byte characters
+    - Conversion of heterographs to the same token_id
+    - Emoji and Emoticon are grouped into 12 types as special tags.
+
+    Example:
+
+    ```python
+    >>> from transformers import GPTNeoXJapaneseTokenizer
+
+    >>> tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained("abeja/gpt-neox-japanese-2.7b")
+    >>> # You can confirm both 慶応 and 慶應 are encoded to 17749
+    >>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]
+    [30014, 26883, 26638, 27228, 25, 26650, 31732, 31679, 27809, 26638, 17749, 31592, 17749, 31593, 321, 1281]
+
+    >>> # Both 慶応 and 慶應 are decoded to 慶応
+    >>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"])
+    '吾輩は猫である🐯。実は慶応(慶応)大学出身'
+    ```
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        emoji_file (`str`):
+            File containing the emoji.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding.
+        bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
+            The beginning of sequence token.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        do_clean_text (`bool`, *optional*, defaults to `False`):
+            Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        emoji_file,
+        unk_token="<|endoftext|>",
+        pad_token="<|endoftext|>",
+        bos_token="<|startoftext|>",
+        eos_token="<|endoftext|>",
+        do_clean_text=False,
+        **kwargs
+    ):
+        super().__init__(
+            unk_token=unk_token,
+            pad_token=pad_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            do_clean_text=do_clean_text,
+            **kwargs,
+        )
+        if not os.path.isfile(vocab_file):
+            raise ValueError(
+                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+                " model use `tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        if not os.path.isfile(emoji_file):
+            raise ValueError(
+                f"Can't find an emoji file at path '{emoji_file}'. To load the emoji information from a Google"
+                " pretrained model use `tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        self.do_clean_text = do_clean_text
+        self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = load_vocab_and_emoji(vocab_file, emoji_file)
+        self.subword_tokenizer = SubWordJapaneseTokenizer(
+            vocab=self.vocab, ids_to_tokens=self.ids_to_tokens, emoji=self.emoji
+        )
+
+    @property
+    def vocab_size(self):
+        # self.vocab contains support for character fluctuation unique to Japanese, and is therefore larger than raw_vocab
+        return len(self.raw_vocab)
+
+    def get_vocab(self):
+        return dict(self.raw_vocab, **self.added_tokens_encoder)
+
+    def _tokenize(self, text):
+        return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        return self.subword_tokenizer.convert_id_to_token(index)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        out_string = "".join(tokens).strip()
+        return out_string
+
+    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+        """This corresponds to DialoGPT variants of models."""
+        input_ids = []
+        for is_user, text in conversation.iter_texts():
+            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
+
+        if len(input_ids) > self.model_max_length:
+            input_ids = input_ids[-self.model_max_length :]
+        return input_ids
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        index = 0
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+            emoji_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["emoji_file"]
+            )
+        else:
+            vocab_file = (
+                (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["vocab_file"]
+            )
+            emoji_file = (
+                (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["emoji_file"]
+            )
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token_index, token in self.ids_to_tokens.items():
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                        " Please check that the vocabulary is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(",".join(token) + "\n")
+                index += 1
+        with open(emoji_file, "w", encoding="utf-8") as writer:
+            json.dump(self.emoji, writer)
+        return vocab_file, emoji_file
+
+
+class SubWordJapaneseTokenizer(object):
+    """
+    https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT License according to the
+    original repository.
+
+    MIT License
+
+    Copyright (c) 2020 tanreinama
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+    documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+    rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+    permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+    the Software.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+    THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+    SOFTWARE.
+    """
+
+    def __init__(self, vocab, ids_to_tokens, emoji):
+        self.vocab = vocab  # same as swe
+        self.ids_to_tokens = ids_to_tokens  # same as bpe
+        self.emoji = emoji
+        self.maxlen = np.max([len(w) for w in self.vocab.keys()])
+        self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)")
+        self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*")
+        self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}")
+        self.content_repatter4 = re.compile(
+            r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
+        )
+        self.content_repatter5 = re.compile(
+            r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
+        )
+        self.content_repatter6 = re.compile(
+            r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
+        )
+        keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
+        blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
+        self.content_trans1 = str.maketrans({k: "<BLOCK>" for k in keisen + blocks})
+
+    def __len__(self):
+        return len(self.ids_to_tokens)
+
+    def clean_text(self, content):
+        content = self.content_repatter1.sub("<URL>", content)
+        content = self.content_repatter2.sub("<EMAIL>", content)
+        content = self.content_repatter3.sub("<TEL>", content)
+        content = self.content_repatter4.sub("<DATE>", content)
+        content = self.content_repatter5.sub("<DATE>", content)
+        content = self.content_repatter6.sub("<PRICE>", content)
+        content = content.translate(self.content_trans1)
+        while "<BLOCK><BLOCK>" in content:
+            content = content.replace("<BLOCK><BLOCK>", "<BLOCK>")
+        return content
+
+    def tokenize(self, text, clean=False):
+        text = text.replace(" ", "<SP>")
+        text = text.replace("　", "<SP>")
+        text = text.replace("\r\n", "<BR>")
+        text = text.replace("\n", "<BR>")
+        text = text.replace("\r", "<BR>")
+        text = text.replace("\t", "<TAB>")
+        text = text.replace("—", "ー")
+        text = text.replace("−", "ー")
+        for k, v in self.emoji["emoji"].items():
+            if k in text:
+                text = text.replace(k, v)
+        if clean:
+            text = self.clean_text(text)
+
+        def check_simbol(x):
+            e = x.encode()
+            if len(x) == 1 and len(e) == 2:
+                c = (int(e[0]) << 8) + int(e[1])
+                if (
+                    (c >= 0xC2A1 and c <= 0xC2BF)
+                    or (c >= 0xC780 and c <= 0xC783)
+                    or (c >= 0xCAB9 and c <= 0xCBBF)
+                    or (c >= 0xCC80 and c <= 0xCDA2)
+                ):
+                    return True
+            return False
+
+        def checku2e(x):
+            e = x.encode()
+            if len(x) == 1 and len(e) == 3:
+                c = (int(e[0]) << 16) + (int(e[1]) << 8) + int(e[2])
+                if c >= 0xE28080 and c <= 0xE2B07F:
+                    return True
+            return False
+
+        pos = 0
+        result = []
+        while pos < len(text):
+            end = min(len(text), pos + self.maxlen + 1) if text[pos] == "<" else pos + 3
+            candidates = []  # (token_id, token, pos)
+            for e in range(end, pos, -1):
+                wd = text[pos:e]
+                if wd in self.vocab:
+                    if wd[0] == "<" and len(wd) > 2:
+                        candidates = [(self.vocab[wd], wd, e)]
+                        break
+                    else:
+                        candidates.append((self.vocab[wd], wd, e))
+            if len(candidates) > 0:
+                # the smallest token_id is adopted
+                _, wd, e = sorted(candidates, key=lambda x: x[0])[0]
+                result.append(wd)
+                pos = e
+            else:
+                end = pos + 1
+                wd = text[pos:end]
+                if check_simbol(wd):
+                    result.append("<KIGOU>")
+                elif checku2e(wd):
+                    result.append("<U2000U2BFF>")
+                else:
+                    for i in wd.encode("utf-8"):
+                        result.append("<|byte%d|>" % i)
+                pos = end
+        return result
+
+    def convert_id_to_token(self, index, breakline="\n"):
+        words = []
+        byte_tokens = []
+        word = self.ids_to_tokens[index][0]
+        if word[:6] == "<|byte" and word[-2:] == "|>":
+            byte_tokens.append(int(word[6:-2]))
+        else:
+            if len(byte_tokens) > 0:
+                words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
+                byte_tokens = []
+            if word[:7] == "<|emoji" and word[-2:] == "|>":
+                words.append(self.emoji["emoji_inv"][word])
+            elif word == "<SP>":
+                words.append(" ")
+            elif word == "<BR>
": + words.append(breakline) + elif word == "": + words.append("\t") + elif word == "": + words.append("▀") + elif word == "": + words.append("ǀ") + elif word == "": + words.append("‖") + else: + words.append(word) + if len(byte_tokens) > 0: + words.append(bytearray(byte_tokens).decode("utf-8", errors="replace")) + text = "".join(words) + return text diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 34a4cf335274..b656cee9c89b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2355,6 +2355,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +GPT_NEOX_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class GPTNeoXJapaneseForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GPTNeoXJapaneseLayer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GPTNeoXJapaneseModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GPTNeoXJapanesePreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 755be5c48ae5..7a469bdff361 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -171,6 +171,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class GPTNeoXJapaneseTokenizer(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class HerbertTokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/models/gpt_neox_japanese/__init__.py b/tests/models/gpt_neox_japanese/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py b/tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py new file mode 100644 index 000000000000..32f118ba0606 --- /dev/null +++ b/tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py @@ -0,0 +1,255 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch GPTNeoXJapanese model. 
""" + + +import unittest + +from transformers import GPTNeoXJapaneseConfig, is_torch_available +from transformers.models.gpt_neox_japanese.tokenization_gpt_neox_japanese import GPTNeoXJapaneseTokenizer +from transformers.testing_utils import require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import GPTNeoXJapaneseForCausalLM, GPTNeoXJapaneseModel + + +class GPTNeoXJapaneseModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_multiple_size=4, + hidden_act="gelu", + hidden_dropout=0.0, + attention_dropout=0.1, + weight_tying=True, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_multiple_size = intermediate_multiple_size + self.hidden_act = hidden_act + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.weight_tying = weight_tying + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, input_ids, input_mask, token_labels + + def get_config(self): + return GPTNeoXJapaneseConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_multiple_size=self.intermediate_multiple_size, + hidden_act=self.hidden_act, + hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_dropout, + weight_tying=self.weight_tying, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def prepare_config_and_inputs_for_decoder(self): + config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs() + + config.is_decoder = True + + return config, input_ids, input_mask, token_labels + + def create_and_check_model(self, config, input_ids, input_mask): + model = GPTNeoXJapaneseModel(config=config) + model.to(torch_device) + model.eval() + _ = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, 
(self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder(self, config, input_ids, input_mask): + config.add_cross_attention = True + model = GPTNeoXJapaneseModel(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels): + model = GPTNeoXJapaneseForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask): + config.is_decoder = True + model = GPTNeoXJapaneseForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True) + output_from_no_past = output_from_no_past["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask, token_labels = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class GPTNeoXModelJapaneseTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (GPTNeoXJapaneseModel, GPTNeoXJapaneseForCausalLM) if is_torch_available() else () + all_generative_model_classes = (GPTNeoXJapaneseForCausalLM,) if is_torch_available() else () + test_pruning = False + test_missing_keys = False + test_model_parallel = False + test_head_masking = False + + def setUp(self): + self.model_tester = GPTNeoXJapaneseModelTester(self) + self.config_tester = ConfigTester(self, config_class=GPTNeoXJapaneseConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(config, input_ids, input_mask) + + def test_model_as_decoder(self): + config, input_ids, input_mask, token_labels = 
self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask) + + def test_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask) + + def test_decoder_model_past_large_inputs(self): + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask) + + def test_model_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + @slow + def test_generation(self): + model_id = "abeja/gpt-neox-japanese-2.7b" + + prompts = ["データサイエンティストとは、", "100年後に必要とされる会社は、", "フルリモートの環境で働くために必要なことは、", "国境の長いトンネルを抜けると", "美味しい日本食といえば、"] + + EXPECTED_OUTPUTS = [ + "データサイエンティストとは、データを分析し、ビジネスに役立つ知見を導き出す専門家のことです。", + "100年後に必要とされる会社は、「人」が中心の会社です。", + "フルリモートの環境で働くために必要なことは、「自分の時間をコントロールする」ことです。", + "国境の長いトンネルを抜けると、そこは雪国だった。", + "美味しい日本食といえば、やっぱりお寿司ですよね。", + ] + + tokenizer = GPTNeoXJapaneseTokenizer.from_pretrained(model_id) + model = GPTNeoXJapaneseForCausalLM.from_pretrained(model_id) + + predicted_outputs = [] + for prompt in prompts: + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + generated_ids = model.generate(input_ids, max_length=50) + generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + predicted_outputs += generated_string + self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS) diff --git a/tests/models/gpt_neox_japanese/test_tokenization_gpt_neox_japanese.py b/tests/models/gpt_neox_japanese/test_tokenization_gpt_neox_japanese.py new file mode 100644 index 000000000000..4af4da30a7b5 --- /dev/null +++ b/tests/models/gpt_neox_japanese/test_tokenization_gpt_neox_japanese.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import unittest + +from transformers.models.gpt_neox_japanese.tokenization_gpt_neox_japanese import ( + VOCAB_FILES_NAMES, + GPTNeoXJapaneseTokenizer, +) +from transformers.testing_utils import require_tokenizers, slow + +from ...test_tokenization_common import TokenizerTesterMixin + + +@require_tokenizers +class GPTNeoXJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = GPTNeoXJapaneseTokenizer + test_rust_tokenizer = False + from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False} + + def setUp(self): + super().setUp() + + vocab_tokens = [ + "こん", + "こんに", + "にちは", + "ばんは", + "世界,㔺界", + "、", + "。", + "
", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "<|emoji1|>", + "", + "<|startoftext|>", + "<|endoftext|>", + ] + emoji_tokens = {"emoji": {"\ud83d\ude00": "<|emoji1|>"}, "emoji_inv": {"<|emoji1|>": "\ud83d\ude00"}} # 😀 + self.special_tokens_map = {"unk_token": ""} + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + self.emoji_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["emoji_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + with open(self.emoji_file, "w") as emoji_writer: + emoji_writer.write(json.dumps(emoji_tokens)) + + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return GPTNeoXJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_input_output_texts(self, tokenizer): + input_text = "こんにちは、世界。 \nこんばんは、㔺界。😀" + output_text = "こんにちは、世界。 \nこんばんは、世界。😀" + return input_text, output_text + + def get_clean_sequence(self, tokenizer): + input_text, output_text = self.get_input_output_texts(tokenizer) + ids = tokenizer.encode(output_text, add_special_tokens=False) + text = tokenizer.decode(ids, clean_up_tokenization_spaces=False) + return text, ids + + def test_pretokenized_inputs(self): + pass # TODO add if relevant + + def test_maximum_encoding_length_pair_input(self): + pass # TODO add if relevant + + def test_maximum_encoding_length_single_input(self): + pass # TODO add if relevant + + def test_full_tokenizer(self): + tokenizer = self.get_tokenizer() + + # Testing tokenization + input_text = "こんにちは、世界。 こんばんは、㔺界。" + expected_token = ["こん", "にちは", "、", "世界", "。", "", "こん", "ばんは", "、", "㔺界", "。"] + tokens = tokenizer.tokenize(input_text) + self.assertListEqual(tokens, expected_token) + + # Testing conversion to ids without special tokens + expected_ids = [0, 2, 5, 4, 6, 8, 0, 3, 5, 4, 6] + input_ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(input_ids, expected_ids) + + # Testing conversion to ids with special tokens + input_tokens = tokens + [tokenizer.unk_token] + expected_ids = [0, 2, 5, 4, 6, 8, 0, 3, 5, 4, 6, 19] + input_ids = tokenizer.convert_tokens_to_ids(input_tokens) + self.assertListEqual(input_ids, expected_ids) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("abeja/gpt-neox-japanese-2.7b") + + ids_1 = tokenizer.encode("ありがとう。", add_special_tokens=False) + ids_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False) + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(ids_1) + encoded_pair = tokenizer.build_inputs_with_special_tokens(ids_1, ids_2) + + assert encoded_sentence == ids_1 + assert encoded_pair == ids_1 + ids_2 + + def test_conversion_reversible(self): + # Intentionally convert some words to accommodate character fluctuations unique to Japanese + pass + + def test_padding_different_model_input_name(self): + # tokenizer has no padding token + pass