From 4b6088aa7a4bcee316897d74ac68d6d4576de9cc Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 21:55:58 +0800 Subject: [PATCH 01/15] add RoFormerTokenizerFast into AutoTokenizer --- src/transformers/models/auto/tokenization_auto.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index bd2210af11ba8d..fe854b8a7cc595 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -184,6 +184,7 @@ from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast + from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast from ..t5.tokenization_t5_fast import T5TokenizerFast from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast @@ -218,6 +219,7 @@ ReformerTokenizerFast = None RetriBertTokenizerFast = None RobertaTokenizerFast = None + RoFormerTokenizerFast = None SqueezeBertTokenizerFast = None T5TokenizerFast = None XLMRobertaTokenizerFast = None @@ -230,7 +232,7 @@ TOKENIZER_MAPPING = OrderedDict( [ (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)), - (RoFormerConfig, (RoFormerTokenizer, None)), + (RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)), (T5Config, (T5Tokenizer, T5TokenizerFast)), (MT5Config, (MT5Tokenizer, MT5TokenizerFast)), (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)), From d2f2bea9cbdf8620f935c80e347fdafb350a3ffe Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 21:57:01 +0800 Subject: [PATCH 02/15] fix typo in roformer docs --- docs/source/model_doc/roformer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/model_doc/roformer.rst b/docs/source/model_doc/roformer.rst index 6ca558abea056c..21f1fe6bbe35e3 100644 --- a/docs/source/model_doc/roformer.rst +++ b/docs/source/model_doc/roformer.rst @@ -56,7 +56,7 @@ RoFormerTokenizer create_token_type_ids_from_sequences, save_vocabulary -RobertaTokenizerFast +RoFormerTokenizerFast ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: transformers.RoFormerTokenizerFast From d96a2a0ec91d7aec6b5d4c683d425f424742a615 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 22:00:33 +0800 Subject: [PATCH 03/15] make onnx export happy --- src/transformers/models/roformer/modeling_roformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 480d466b489654..8fc375d5534e4e 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -327,9 +327,9 @@ def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, val # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] sin, cos = sinusoidal_pos.chunk(2, dim=-1) # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - sin_pos = torch.repeat_interleave(sin, 2, dim=-1) + sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos) # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - cos_pos = torch.repeat_interleave(cos, 2, dim=-1) + cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos) # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as( query_layer From bae367561511ed3a19e3ac8178849a874df7a4d9 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 22:04:35 +0800 Subject: [PATCH 04/15] update RoFormerConfig embedding_size --- src/transformers/models/roformer/configuration_roformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index 24e3e2c30f168a..965922f9957f93 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -43,7 +43,7 @@ class RoFormerConfig(PretrainedConfig): Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or :class:`~transformers.TFRoFormerModel`. - embedding_size (:obj:`int`, `optional`, defaults to 768): + embedding_size (:obj:`int`, `optional`, defaults to None): Dimensionality of the encoder layers and the pooler layer. hidden_size (:obj:`int`, `optional`, defaults to 768): Dimension of the encoder layers and the pooler layer. 
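
A note on the apply_rotary_position_embeddings rewrite in PATCH 03 above: torch.repeat_interleave(x, 2, dim=-1) and torch.stack([x, x], dim=-1) followed by a reshape produce the same interleaved [x0, x0, x1, x1, ...] layout, but the stack-and-reshape form avoids repeat_interleave, which the ONNX exporter could not handle at the time. A minimal self-check of the equivalence, assuming only PyTorch (the shapes are illustrative, not taken from the model):

    import torch

    # Stand-in for the sin (or cos) half of the sinusoidal table:
    # [batch, heads, seq_len, head_dim // 2]
    x = torch.randn(2, 12, 8, 32)

    a = torch.repeat_interleave(x, 2, dim=-1)                   # [x0, x0, x1, x1, ...]
    b = torch.stack([x, x], dim=-1).reshape(*x.shape[:-1], -1)  # same layout, export-friendly
    assert torch.equal(a, b)  # both are pure copies, so exact equality holds
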
@@ -96,7 +96,7 @@ class RoFormerConfig(PretrainedConfig): def __init__( self, vocab_size=50000, - embedding_size=768, + embedding_size=None, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, @@ -117,7 +117,7 @@ def __init__( super().__init__(pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size - self.embedding_size = embedding_size + self.embedding_size = hidden_size if embedding_size is None else embedding_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads From 8175b1a1fbdb674dc00de7bb66d634551ba28873 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 22:12:10 +0800 Subject: [PATCH 05/15] use jieba not rjieba --- .../models/roformer/tokenization_roformer.py | 29 +++---------------- .../models/roformer/tokenization_utils.py | 13 ++------- tests/test_tokenization_roformer.py | 12 ++++---- 3 files changed, 13 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index efb5d83051f9b3..9e5bd35c3f296a 100644 --- a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -16,6 +16,7 @@ import collections import os +import jieba from typing import List, Optional, Tuple from ...tokenization_utils import PreTrainedTokenizer @@ -45,7 +46,7 @@ class RoFormerTokenizer(PreTrainedTokenizer): r""" - Construct a RoFormer tokenizer. Based on `Rust Jieba `. + Construct a RoFormer tokenizer. Based on `Jieba `. This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. @@ -142,14 +143,7 @@ def __init__( strip_accents=strip_accents, ) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) - try: - import rjieba - except ImportError: - raise ImportError( - "You need to install rjieba to use RoFormerTokenizer." - "See https://pypi.org/project/rjieba/ for installation." - ) - self.jieba = rjieba + @property def do_lower_case(self): @@ -159,21 +153,6 @@ def do_lower_case(self): def vocab_size(self): return len(self.vocab) - def __getstate__(self): - state = self.__dict__.copy() - state["jieba"] = None - return state - - def __setstate__(self, d): - self.__dict__ = d - try: - import rjieba - except ImportError: - raise ImportError( - "You need to install rjieba to use RoFormerTokenizer." - "See https://pypi.org/project/rjieba/ for installation." - ) - self.jieba = rjieba def get_vocab(self): return dict(self.vocab, **self.added_tokens_encoder) @@ -181,7 +160,7 @@ def get_vocab(self): def _tokenize(self, text, use_jieba=True): split_tokens = [] if use_jieba: - for wholword in self.jieba.cut(text, False): + for wholword in jieba.cut(text, HMM=False): if wholword in self.vocab: split_tokens.append(wholword) else: diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py index d956d5214cb3ee..70bc240fc6a7bf 100644 --- a/src/transformers/models/roformer/tokenization_utils.py +++ b/src/transformers/models/roformer/tokenization_utils.py @@ -14,6 +14,7 @@ # limitations under the License. 
"""Tokenization utils for RoFormer.""" +import jieba from typing import List from tokenizers import NormalizedString, PreTokenizedString, normalizers @@ -28,20 +29,12 @@ def __init__(self, vocab) -> None: strip_accents=False, lowercase=False, ) - try: - import rjieba - except ImportError: - raise ImportError( - "You need to install rjieba to use RoFormerTokenizer." - "See https://pypi.org/project/rjieba/ for installation." - ) - self.jieba = rjieba def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]: splits = [] # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass - # for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False): + # for token, start, end in jieba.tokenize(str(normalized_string), hmm=False): # if token in self.vocab: # splits.append(normalized_string.slice((start, end))) # else: @@ -53,7 +46,7 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma # start = end # this code test_alignement_methods can't pass but fast (300ms) - for token in self.jieba.cut(str(normalized_string), False): + for token in jieba.cut(str(normalized_string), False): if token in self.vocab: splits.append(NormalizedString(token)) else: diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py index 19c7fb65431e1c..22111d4f279e91 100644 --- a/tests/test_tokenization_roformer.py +++ b/tests/test_tokenization_roformer.py @@ -22,21 +22,21 @@ from .test_tokenization_common import TokenizerTesterMixin -def is_rjieba_available(): - return importlib.util.find_spec("rjieba") is not None +def is_jieba_available(): + return importlib.util.find_spec("jieba") is not None -def require_rjieba(test_case): +def require_jieba(test_case): """ Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed. 
""" - if not is_rjieba_available(): - return unittest.skip("test requires rjieba")(test_case) + if not is_jieba_available(): + return unittest.skip("test requires jieba")(test_case) else: return test_case -@require_rjieba +@require_jieba @require_tokenizers class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): From c0daacd694321c7e77a123998f12306840c665d1 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 22:15:53 +0800 Subject: [PATCH 06/15] fix 12244 and make test_alignement passed --- .../models/roformer/tokenization_utils.py | 30 +++++++++---------- tests/test_tokenization_roformer.py | 3 -- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py index 70bc240fc6a7bf..e2bc705ea1e400 100644 --- a/src/transformers/models/roformer/tokenization_utils.py +++ b/src/transformers/models/roformer/tokenization_utils.py @@ -34,26 +34,26 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma splits = [] # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass - # for token, start, end in jieba.tokenize(str(normalized_string), hmm=False): - # if token in self.vocab: - # splits.append(normalized_string.slice((start, end))) - # else: - # token_list = self.normalizers.normalize_str(token).split() - # for token in token_list: - # if token: - # end = start + len(token) - # splits.append(normalized_string.slice((start, end))) - # start = end - - # this code test_alignement_methods can't pass but fast (300ms) - for token in jieba.cut(str(normalized_string), False): + for token, start, end in jieba.tokenize(str(normalized_string), HMM=False): if token in self.vocab: - splits.append(NormalizedString(token)) + splits.append(normalized_string[start:end]) else: token_list = self.normalizers.normalize_str(token).split() for token in token_list: if token: - splits.append(NormalizedString(token)) + end = start + len(token) + splits.append(normalized_string[start:end]) + start = end + + # this code test_alignement_methods can't pass but fast (300ms) + # for token in jieba.cut(str(normalized_string), HMM=False): + # if token in self.vocab: + # splits.append(NormalizedString(token)) + # else: + # token_list = self.normalizers.normalize_str(token).split() + # for token in token_list: + # if token: + # splits.append(NormalizedString(token)) return splits diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py index 22111d4f279e91..b3462756e15e58 100644 --- a/tests/test_tokenization_roformer.py +++ b/tests/test_tokenization_roformer.py @@ -79,6 +79,3 @@ def test_rust_tokenizer(self): exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens) - # due to custom pre_tokenize , char_to_token may be error - def test_alignement_methods(self): - pass From 72c3095f1b05c9ca77e150cce917db9c9db23f96 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 22:31:43 +0800 Subject: [PATCH 07/15] update ARCHIVE_MAP --- .../models/roformer/configuration_roformer.py | 6 +++++- src/transformers/models/roformer/modeling_roformer.py | 6 +++++- .../models/roformer/modeling_tf_roformer.py | 6 +++++- .../models/roformer/tokenization_roformer.py | 10 +++++++++- .../models/roformer/tokenization_roformer_fast.py | 10 +++++++++- 5 files changed, 33 
insertions(+), 5 deletions(-) diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index 965922f9957f93..53d00e6f57357a 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -22,7 +22,11 @@ ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json", - "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json" + "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json", + "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json", + "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json", + "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json", + "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json", # See all RoFormer models at https://huggingface.co/models?filter=roformer } diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 8fc375d5534e4e..ded6e0ce067519 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -60,7 +60,11 @@ ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "junnyu/roformer_chinese_small", - "junnyu/roformer_chinese_base" + "junnyu/roformer_chinese_base", + "junnyu/roformer_chinese_char_small", + "junnyu/roformer_chinese_char_base", + "junnyu/roformer_small_discriminator", + "junnyu/roformer_small_generator" # See all RoFormer models at https://huggingface.co/models?filter=roformer ] diff --git a/src/transformers/models/roformer/modeling_tf_roformer.py b/src/transformers/models/roformer/modeling_tf_roformer.py index dae6e180b11b46..436acdbd30d2f2 100644 --- a/src/transformers/models/roformer/modeling_tf_roformer.py +++ b/src/transformers/models/roformer/modeling_tf_roformer.py @@ -65,7 +65,11 @@ TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "junnyu/roformer_chinese_small", - "junnyu/roformer_chinese_base" + "junnyu/roformer_chinese_base", + "junnyu/roformer_chinese_char_small", + "junnyu/roformer_chinese_char_base", + "junnyu/roformer_small_discriminator", + "junnyu/roformer_small_generator" # See all RoFormer models at https://huggingface.co/models?filter=roformer ] diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index 9e5bd35c3f296a..a9dc7488639756 100644 --- a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -32,15 +32,23 @@ "vocab_file": { "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt", "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt", + "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt", + "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt", + "junnyu/roformer_small_discriminator": 
"https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt", + "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt", } } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128} PRETRAINED_INIT_CONFIGURATION = { "junnyu/roformer_chinese_small": {"do_lower_case": True}, "junnyu/roformer_chinese_base": {"do_lower_case": True}, + "junnyu/roformer_chinese_char_small": {"do_lower_case": True}, + "junnyu/roformer_chinese_char_base": {"do_lower_case": True}, + "junnyu/roformer_small_discriminator": {"do_lower_case": True}, + "junnyu/roformer_small_generator": {"do_lower_case": True}, } diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py index bafd60e3f6b18f..57e5eef6cf5372 100644 --- a/src/transformers/models/roformer/tokenization_roformer_fast.py +++ b/src/transformers/models/roformer/tokenization_roformer_fast.py @@ -33,15 +33,23 @@ "vocab_file": { "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt", "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt", + "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt", + "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt", + "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt", + "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt", } } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128} PRETRAINED_INIT_CONFIGURATION = { "junnyu/roformer_chinese_small": {"do_lower_case": True}, "junnyu/roformer_chinese_base": {"do_lower_case": True}, + "junnyu/roformer_chinese_char_small": {"do_lower_case": True}, + "junnyu/roformer_chinese_char_base": {"do_lower_case": True}, + "junnyu/roformer_small_discriminator": {"do_lower_case": True}, + "junnyu/roformer_small_generator": {"do_lower_case": True}, } From 3c0210640834dd8726fc08369fc69a2fa1b5b550 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 22:36:48 +0800 Subject: [PATCH 08/15] make style & quality & fixup --- .../models/roformer/tokenization_roformer.py | 14 ++++++++++---- .../models/roformer/tokenization_roformer_fast.py | 9 ++++++++- .../models/roformer/tokenization_utils.py | 2 +- tests/test_tokenization_roformer.py | 1 - 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index a9dc7488639756..646534d95fc3ff 100644 --- 
a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -16,9 +16,10 @@ import collections import os -import jieba from typing import List, Optional, Tuple +import jieba + from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab @@ -39,7 +40,14 @@ } } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "junnyu/roformer_chinese_small": 1536, + "junnyu/roformer_chinese_base": 1536, + "junnyu/roformer_chinese_char_small": 512, + "junnyu/roformer_chinese_char_base": 512, + "junnyu/roformer_small_discriminator": 128, + "junnyu/roformer_small_generator": 128, +} PRETRAINED_INIT_CONFIGURATION = { @@ -152,7 +160,6 @@ def __init__( ) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) - @property def do_lower_case(self): return self.basic_tokenizer.do_lower_case @@ -161,7 +168,6 @@ def do_lower_case(self): def vocab_size(self): return len(self.vocab) - def get_vocab(self): return dict(self.vocab, **self.added_tokens_encoder) diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py index 57e5eef6cf5372..9d6be92d9bf9fa 100644 --- a/src/transformers/models/roformer/tokenization_roformer_fast.py +++ b/src/transformers/models/roformer/tokenization_roformer_fast.py @@ -40,7 +40,14 @@ } } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "junnyu/roformer_chinese_small": 1536, + "junnyu/roformer_chinese_base": 1536, + "junnyu/roformer_chinese_char_small": 512, + "junnyu/roformer_chinese_char_base": 512, + "junnyu/roformer_small_discriminator": 128, + "junnyu/roformer_small_generator": 128, +} PRETRAINED_INIT_CONFIGURATION = { diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py index e2bc705ea1e400..056d0025eb13b5 100644 --- a/src/transformers/models/roformer/tokenization_utils.py +++ b/src/transformers/models/roformer/tokenization_utils.py @@ -14,9 +14,9 @@ # limitations under the License. 
"""Tokenization utils for RoFormer.""" -import jieba from typing import List +import jieba from tokenizers import NormalizedString, PreTokenizedString, normalizers diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py index b3462756e15e58..bb31456f03ca58 100644 --- a/tests/test_tokenization_roformer.py +++ b/tests/test_tokenization_roformer.py @@ -78,4 +78,3 @@ def test_rust_tokenizer(self): input_tokens = tokens + [tokenizer.unk_token] exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens) - From aa14de9afa6c938babd79de30750d4ec2a0c02f2 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 23:00:25 +0800 Subject: [PATCH 09/15] update --- .../models/roformer/tokenization_roformer.py | 27 ++++++++++++++++++- .../models/roformer/tokenization_utils.py | 14 +++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index 646534d95fc3ff..eb9c31c83d60e6 100644 --- a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -20,6 +20,7 @@ import jieba + from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab @@ -159,6 +160,14 @@ def __init__( strip_accents=strip_accents, ) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) + try: + import jieba + except ImportError: + raise ImportError( + "You need to install jieba to use RoFormerTokenizer." + "See https://pypi.org/project/jieba/ for installation." + ) + self.jieba = jieba @property def do_lower_case(self): @@ -168,13 +177,29 @@ def do_lower_case(self): def vocab_size(self): return len(self.vocab) + def __getstate__(self): + state = self.__dict__.copy() + state["jieba"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + try: + import jieba + except ImportError: + raise ImportError( + "You need to install jieba to use RoFormerTokenizer." + "See https://pypi.org/project/jieba/ for installation." + ) + self.jieba = jieba + def get_vocab(self): return dict(self.vocab, **self.added_tokens_encoder) def _tokenize(self, text, use_jieba=True): split_tokens = [] if use_jieba: - for wholword in jieba.cut(text, HMM=False): + for wholword in self.jieba.cut(text, HMM=False): if wholword in self.vocab: split_tokens.append(wholword) else: diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py index 056d0025eb13b5..9dc5b9c5fc8ecb 100644 --- a/src/transformers/models/roformer/tokenization_utils.py +++ b/src/transformers/models/roformer/tokenization_utils.py @@ -16,7 +16,7 @@ from typing import List -import jieba + from tokenizers import NormalizedString, PreTokenizedString, normalizers @@ -29,12 +29,20 @@ def __init__(self, vocab) -> None: strip_accents=False, lowercase=False, ) + try: + import jieba + except ImportError: + raise ImportError( + "You need to install jieba to use RoFormerTokenizer." + "See https://pypi.org/project/jieba/ for installation." 
+ ) + self.jieba = jieba def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]: splits = [] # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass - for token, start, end in jieba.tokenize(str(normalized_string), HMM=False): + for token, start, end in self.jieba.tokenize(str(normalized_string), HMM=False): if token in self.vocab: splits.append(normalized_string[start:end]) else: @@ -46,7 +54,7 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma start = end # this code test_alignement_methods can't pass but fast (300ms) - # for token in jieba.cut(str(normalized_string), HMM=False): + # for token in self.jieba.cut(str(normalized_string), HMM=False): # if token in self.vocab: # splits.append(NormalizedString(token)) # else: From 23c723aaa7678512e627fc7afae4234c23d6aa91 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 23:03:05 +0800 Subject: [PATCH 10/15] make style & quality & fixup --- src/transformers/models/roformer/tokenization_roformer.py | 3 --- src/transformers/models/roformer/tokenization_utils.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index eb9c31c83d60e6..6abde913d799c6 100644 --- a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -18,9 +18,6 @@ import os from typing import List, Optional, Tuple -import jieba - - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py index 9dc5b9c5fc8ecb..c703ae857a04f3 100644 --- a/src/transformers/models/roformer/tokenization_utils.py +++ b/src/transformers/models/roformer/tokenization_utils.py @@ -16,7 +16,6 @@ from typing import List - from tokenizers import NormalizedString, PreTokenizedString, normalizers From bbf0e0e333b2df5387f58a8fa32fc4adee6f3aff Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 28 Jun 2021 23:59:28 +0800 Subject: [PATCH 11/15] make style quality fixup --- .../tensorflow/question-answering/utils_qa.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py index 2f8f0a60c45fe5..1157849c99100f 100644 --- a/examples/tensorflow/question-answering/utils_qa.py +++ b/examples/tensorflow/question-answering/utils_qa.py @@ -38,7 +38,7 @@ def postprocess_qa_predictions( null_score_diff_threshold: float = 0.0, output_dir: Optional[str] = None, prefix: Optional[str] = None, - is_world_process_zero: bool = True, + log_level: Optional[int] = logging.WARNING, ): """ Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the @@ -70,8 +70,8 @@ def postprocess_qa_predictions( answers, are saved in `output_dir`. prefix (:obj:`str`, `optional`): If provided, the dictionaries mentioned above are saved with `prefix` added to their names. - is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether this process is the main process or not (used to determine if logging/saves should be done). 
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) """ assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." all_start_logits, all_end_logits = predictions @@ -91,7 +91,7 @@ def postprocess_qa_predictions( scores_diff_json = collections.OrderedDict() # Logging. - logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) + logger.setLevel(log_level) logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") # Let's loop over all the examples! @@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search( end_n_top: int = 5, output_dir: Optional[str] = None, prefix: Optional[str] = None, - is_world_process_zero: bool = True, + log_level: Optional[int] = logging.WARNING, ): """ Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the @@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search( answers, are saved in `output_dir`. prefix (:obj:`str`, `optional`): If provided, the dictionaries mentioned above are saved with `prefix` added to their names. - is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether this process is the main process or not (used to determine if logging/saves should be done). + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) """ assert len(predictions) == 5, "`predictions` should be a tuple with five elements." start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions @@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search( scores_diff_json = collections.OrderedDict() if version_2_with_negative else None # Logging. - logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) + logger.setLevel(log_level) logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") # Let's loop over all the examples! 
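
The is_world_process_zero -> log_level change in this patch moves the logging decision to the caller: instead of the post-processing helper guessing whether it runs on the main process, the caller passes in the verbosity it wants. A hedged sketch of the pattern (the function name and the caller are hypothetical stand-ins, not the script's real code):

    import logging

    logging.basicConfig()
    logger = logging.getLogger(__name__)

    def postprocess(num_examples: int, log_level: int = logging.WARNING) -> None:
        # Verbosity is injected; the helper no longer inspects the process rank.
        logger.setLevel(log_level)
        logger.info(f"Post-processing {num_examples} example predictions.")

    # Hypothetical caller: the main process logs at INFO, replicas stay at WARNING.
    is_main_process = True  # stand-in for e.g. trainer.is_world_process_zero()
    postprocess(10, log_level=logging.INFO if is_main_process else logging.WARNING)
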
@@ -413,14 +413,14 @@ def postprocess_qa_predictions_with_beam_search( output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" ) - print(f"Saving predictions to {prediction_file}.") + logger.info(f"Saving predictions to {prediction_file}.") with open(prediction_file, "w") as writer: writer.write(json.dumps(all_predictions, indent=4) + "\n") - print(f"Saving nbest_preds to {nbest_file}.") + logger.info(f"Saving nbest_preds to {nbest_file}.") with open(nbest_file, "w") as writer: writer.write(json.dumps(all_nbest_json, indent=4) + "\n") if version_2_with_negative: - print(f"Saving null_odds to {null_odds_file}.") + logger.info(f"Saving null_odds to {null_odds_file}.") with open(null_odds_file, "w") as writer: writer.write(json.dumps(scores_diff_json, indent=4) + "\n") From 90102cb56b0ed049922d21d9c6d0fef45f88c787 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Fri, 2 Jul 2021 18:20:16 +0800 Subject: [PATCH 12/15] update --- src/transformers/file_utils.py | 4 ++++ .../models/roformer/tokenization_roformer.py | 9 ++------- src/transformers/testing_utils.py | 11 +++++++++++ tests/test_tokenization_roformer.py | 17 +---------------- 4 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 5f522a440c61ca..66186eaaeb11d5 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -315,6 +315,10 @@ def is_datasets_available(): return _datasets_available +def is_jieba_available(): + return importlib.util.find_spec("jieba") is not None + + def is_psutil_available(): return importlib.util.find_spec("psutil") is not None diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index 6abde913d799c6..514235f33a5450 100644 --- a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -181,13 +181,8 @@ def __getstate__(self): def __setstate__(self, d): self.__dict__ = d - try: - import jieba - except ImportError: - raise ImportError( - "You need to install jieba to use RoFormerTokenizer." - "See https://pypi.org/project/jieba/ for installation." - ) + import jieba + self.jieba = jieba def get_vocab(self): diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index d315785ed95c9c..018bef0cc57506 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -33,6 +33,7 @@ is_datasets_available, is_faiss_available, is_flax_available, + is_jieba_available, is_onnx_available, is_pandas_available, is_scatter_available, @@ -223,6 +224,16 @@ def require_git_lfs(test_case): return test_case +def require_jieba(test_case): + """ + Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed. + """ + if not is_jieba_available(): + return unittest.skip("test requires jieba")(test_case) + else: + return test_case + + def require_onnx(test_case): if not is_onnx_available(): return unittest.skip("test requires ONNX")(test_case) diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py index bb31456f03ca58..0f4cd3a1cbb626 100644 --- a/tests/test_tokenization_roformer.py +++ b/tests/test_tokenization_roformer.py @@ -13,29 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import importlib import unittest from transformers import RoFormerTokenizer, RoFormerTokenizerFast -from transformers.testing_utils import require_tokenizers +from transformers.testing_utils import require_jieba, require_tokenizers from .test_tokenization_common import TokenizerTesterMixin -def is_jieba_available(): - return importlib.util.find_spec("jieba") is not None - - -def require_jieba(test_case): - """ - Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed. - """ - if not is_jieba_available(): - return unittest.skip("test requires jieba")(test_case) - else: - return test_case - - @require_jieba @require_tokenizers class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): From 1c5dd6a1d8ab3560cb396dd051585ac19c9bb602 Mon Sep 17 00:00:00 2001 From: yujun <50394665+JunnYu@users.noreply.github.com> Date: Mon, 5 Jul 2021 17:14:03 +0800 Subject: [PATCH 13/15] suggestion from LysandreJik Co-authored-by: Lysandre Debut --- src/transformers/models/roformer/configuration_roformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index 53d00e6f57357a..d0f82a1c996fdd 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -48,7 +48,7 @@ class RoFormerConfig(PretrainedConfig): the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or :class:`~transformers.TFRoFormerModel`. embedding_size (:obj:`int`, `optional`, defaults to None): - Dimensionality of the encoder layers and the pooler layer. + Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not provided. hidden_size (:obj:`int`, `optional`, defaults to 768): Dimension of the encoder layers and the pooler layer. num_hidden_layers (:obj:`int`, `optional`, defaults to 12): From 1f08622612825d8a8bbc985a0e1698c2aaaf5e4a Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Mon, 5 Jul 2021 18:05:41 +0800 Subject: [PATCH 14/15] make style --- src/transformers/models/roformer/configuration_roformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index d0f82a1c996fdd..945d1064a10ea8 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -48,7 +48,8 @@ class RoFormerConfig(PretrainedConfig): the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or :class:`~transformers.TFRoFormerModel`. embedding_size (:obj:`int`, `optional`, defaults to None): - Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not provided. + Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not + provided. hidden_size (:obj:`int`, `optional`, defaults to 768): Dimension of the encoder layers and the pooler layer. 
num_hidden_layers (:obj:`int`, `optional`, defaults to 12): From 52eb0f5de48599533684f1da833154e03cfd7d70 Mon Sep 17 00:00:00 2001 From: junnyu <573009727@qq.com> Date: Tue, 6 Jul 2021 00:11:17 +0800 Subject: [PATCH 15/15] use rjieba --- src/transformers/file_utils.py | 4 ++-- .../models/roformer/tokenization_roformer.py | 16 ++++++++-------- .../models/roformer/tokenization_utils.py | 12 ++++++------ src/transformers/testing_utils.py | 10 +++++----- tests/test_tokenization_roformer.py | 12 ++++++++++-- 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 66186eaaeb11d5..c3717a9289d336 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -315,8 +315,8 @@ def is_datasets_available(): return _datasets_available -def is_jieba_available(): - return importlib.util.find_spec("jieba") is not None +def is_rjieba_available(): + return importlib.util.find_spec("rjieba") is not None def is_psutil_available(): diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py index 514235f33a5450..a425ec934c81cf 100644 --- a/src/transformers/models/roformer/tokenization_roformer.py +++ b/src/transformers/models/roformer/tokenization_roformer.py @@ -60,7 +60,7 @@ class RoFormerTokenizer(PreTrainedTokenizer): r""" - Construct a RoFormer tokenizer. Based on `Jieba `. + Construct a RoFormer tokenizer. Based on `Rust Jieba `. This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. @@ -158,13 +158,13 @@ def __init__( ) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) try: - import jieba + import rjieba except ImportError: raise ImportError( - "You need to install jieba to use RoFormerTokenizer." - "See https://pypi.org/project/jieba/ for installation." + "You need to install rjieba to use RoFormerTokenizer." + "See https://pypi.org/project/rjieba/ for installation." ) - self.jieba = jieba + self.jieba = rjieba @property def do_lower_case(self): @@ -181,9 +181,9 @@ def __getstate__(self): def __setstate__(self, d): self.__dict__ = d - import jieba + import rjieba - self.jieba = jieba + self.jieba = rjieba def get_vocab(self): return dict(self.vocab, **self.added_tokens_encoder) @@ -191,7 +191,7 @@ def get_vocab(self): def _tokenize(self, text, use_jieba=True): split_tokens = [] if use_jieba: - for wholword in self.jieba.cut(text, HMM=False): + for wholword in self.jieba.cut(text, False): if wholword in self.vocab: split_tokens.append(wholword) else: diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py index c703ae857a04f3..195e6eff2dadbe 100644 --- a/src/transformers/models/roformer/tokenization_utils.py +++ b/src/transformers/models/roformer/tokenization_utils.py @@ -29,19 +29,19 @@ def __init__(self, vocab) -> None: lowercase=False, ) try: - import jieba + import rjieba except ImportError: raise ImportError( - "You need to install jieba to use RoFormerTokenizer." - "See https://pypi.org/project/jieba/ for installation." + "You need to install rjieba to use RoFormerTokenizer." + "See https://pypi.org/project/rjieba/ for installation." 
) - self.jieba = jieba + self.jieba = rjieba def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]: splits = [] # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass - for token, start, end in self.jieba.tokenize(str(normalized_string), HMM=False): + for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False): if token in self.vocab: splits.append(normalized_string[start:end]) else: @@ -53,7 +53,7 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma start = end # this code test_alignement_methods can't pass but fast (300ms) - # for token in self.jieba.cut(str(normalized_string), HMM=False): + # for token in self.jieba.cut(str(normalized_string), False): # if token in self.vocab: # splits.append(NormalizedString(token)) # else: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 018bef0cc57506..439cee385d1253 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -33,9 +33,9 @@ is_datasets_available, is_faiss_available, is_flax_available, - is_jieba_available, is_onnx_available, is_pandas_available, + is_rjieba_available, is_scatter_available, is_sentencepiece_available, is_soundfile_availble, @@ -224,12 +224,12 @@ def require_git_lfs(test_case): return test_case -def require_jieba(test_case): +def require_rjieba(test_case): """ - Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed. + Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed. """ - if not is_jieba_available(): - return unittest.skip("test requires jieba")(test_case) + if not is_rjieba_available(): + return unittest.skip("test requires rjieba")(test_case) else: return test_case diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py index 0f4cd3a1cbb626..c5e19b66b20023 100644 --- a/tests/test_tokenization_roformer.py +++ b/tests/test_tokenization_roformer.py @@ -16,12 +16,12 @@ import unittest from transformers import RoFormerTokenizer, RoFormerTokenizerFast -from transformers.testing_utils import require_jieba, require_tokenizers +from transformers.testing_utils import require_rjieba, require_tokenizers from .test_tokenization_common import TokenizerTesterMixin -@require_jieba +@require_rjieba @require_tokenizers class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @@ -63,3 +63,11 @@ def test_rust_tokenizer(self): input_tokens = tokens + [tokenizer.unk_token] exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens) + + # can't train new_tokenizer via Tokenizers lib + def test_training_new_tokenizer(self): + pass + + # can't train new_tokenizer via Tokenizers lib + def test_training_new_tokenizer_with_special_tokens_change(self): + pass