From 4b6088aa7a4bcee316897d74ac68d6d4576de9cc Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 21:55:58 +0800
Subject: [PATCH 01/15] add RoFormerTokenizerFast to AutoTokenizer
---
src/transformers/models/auto/tokenization_auto.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index bd2210af11ba8d..fe854b8a7cc595 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -184,6 +184,7 @@
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
+ from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
from ..t5.tokenization_t5_fast import T5TokenizerFast
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
@@ -218,6 +219,7 @@
ReformerTokenizerFast = None
RetriBertTokenizerFast = None
RobertaTokenizerFast = None
+ RoFormerTokenizerFast = None
SqueezeBertTokenizerFast = None
T5TokenizerFast = None
XLMRobertaTokenizerFast = None
@@ -230,7 +232,7 @@
TOKENIZER_MAPPING = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
- (RoFormerConfig, (RoFormerTokenizer, None)),
+ (RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
(T5Config, (T5Tokenizer, T5TokenizerFast)),
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
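
With this mapping entry in place, AutoTokenizer can hand back the fast tokenizer for RoFormer checkpoints instead of None. A minimal sketch of the resulting behavior (downloads the vocab; assumes the jieba backend used by this series is installed):

    from transformers import AutoTokenizer

    # The (RoFormerTokenizer, RoFormerTokenizerFast) pair means use_fast=True
    # now resolves to the Rust-backed tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base", use_fast=True)
    print(type(tokenizer).__name__)  # RoFormerTokenizerFast
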
From d2f2bea9cbdf8620f935c80e347fdafb350a3ffe Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 21:57:01 +0800
Subject: [PATCH 02/15] fix typo in roformer docs
---
docs/source/model_doc/roformer.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/model_doc/roformer.rst b/docs/source/model_doc/roformer.rst
index 6ca558abea056c..21f1fe6bbe35e3 100644
--- a/docs/source/model_doc/roformer.rst
+++ b/docs/source/model_doc/roformer.rst
@@ -56,7 +56,7 @@ RoFormerTokenizer
create_token_type_ids_from_sequences, save_vocabulary
-RobertaTokenizerFast
+RoFormerTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RoFormerTokenizerFast
From d96a2a0ec91d7aec6b5d4c683d425f424742a615 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 22:00:33 +0800
Subject: [PATCH 03/15] make onnx export happy
---
src/transformers/models/roformer/modeling_roformer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index 480d466b489654..8fc375d5534e4e 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -327,9 +327,9 @@ def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, val
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
sin, cos = sinusoidal_pos.chunk(2, dim=-1)
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
- sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
+ sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
- cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
+ cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
query_layer
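
For context: torch.repeat_interleave could not be exported to ONNX at the time, and the stack + reshape_as rewrite is numerically identical. A minimal sketch checking the equivalence (shapes are illustrative, not taken from the model):

    import torch

    sin = torch.randn(2, 12, 128, 32)             # [batch, heads, seq, head_dim//2]
    sinusoidal_pos = torch.randn(2, 12, 128, 64)  # [batch, heads, seq, head_dim]

    old = torch.repeat_interleave(sin, 2, dim=-1)
    new = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
    assert torch.equal(old, new)  # both produce [θ0,θ0,θ1,θ1,...]
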
From bae367561511ed3a19e3ac8178849a874df7a4d9 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 22:04:35 +0800
Subject: [PATCH 04/15] update RoFormerConfig embedding_size
---
src/transformers/models/roformer/configuration_roformer.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py
index 24e3e2c30f168a..965922f9957f93 100644
--- a/src/transformers/models/roformer/configuration_roformer.py
+++ b/src/transformers/models/roformer/configuration_roformer.py
@@ -43,7 +43,7 @@ class RoFormerConfig(PretrainedConfig):
Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
:class:`~transformers.TFRoFormerModel`.
- embedding_size (:obj:`int`, `optional`, defaults to 768):
+ embedding_size (:obj:`int`, `optional`, defaults to None):
Dimensionality of the encoder layers and the pooler layer.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimension of the encoder layers and the pooler layer.
@@ -96,7 +96,7 @@ class RoFormerConfig(PretrainedConfig):
def __init__(
self,
vocab_size=50000,
- embedding_size=768,
+ embedding_size=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
@@ -117,7 +117,7 @@ def __init__(
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
- self.embedding_size = embedding_size
+ self.embedding_size = hidden_size if embedding_size is None else embedding_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
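
The effect is that embedding_size now tracks hidden_size unless explicitly decoupled, so the ALBERT-style factorized embedding stays opt-in. A minimal sketch of the resulting behavior:

    from transformers import RoFormerConfig

    cfg = RoFormerConfig(hidden_size=384)
    assert cfg.embedding_size == 384   # falls back to hidden_size

    cfg = RoFormerConfig(hidden_size=768, embedding_size=128)
    assert cfg.embedding_size == 128   # an explicit value still wins
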
From 8175b1a1fbdb674dc00de7bb66d634551ba28873 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 22:12:10 +0800
Subject: [PATCH 05/15] use jieba not rjieba
---
.../models/roformer/tokenization_roformer.py | 29 +++----------------
.../models/roformer/tokenization_utils.py | 13 ++-------
tests/test_tokenization_roformer.py | 12 ++++----
3 files changed, 13 insertions(+), 41 deletions(-)
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index efb5d83051f9b3..9e5bd35c3f296a 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -16,6 +16,7 @@
import collections
import os
+import jieba
from typing import List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer
@@ -45,7 +46,7 @@
class RoFormerTokenizer(PreTrainedTokenizer):
r"""
-    Construct a RoFormer tokenizer. Based on `Rust Jieba <https://pypi.org/project/rjieba/>`.
+    Construct a RoFormer tokenizer. Based on `Jieba <https://pypi.org/project/jieba/>`.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
@@ -142,14 +143,7 @@ def __init__(
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
- try:
- import rjieba
- except ImportError:
- raise ImportError(
- "You need to install rjieba to use RoFormerTokenizer."
- "See https://pypi.org/project/rjieba/ for installation."
- )
- self.jieba = rjieba
+
@property
def do_lower_case(self):
@@ -159,21 +153,6 @@ def do_lower_case(self):
def vocab_size(self):
return len(self.vocab)
- def __getstate__(self):
- state = self.__dict__.copy()
- state["jieba"] = None
- return state
-
- def __setstate__(self, d):
- self.__dict__ = d
- try:
- import rjieba
- except ImportError:
- raise ImportError(
- "You need to install rjieba to use RoFormerTokenizer."
- "See https://pypi.org/project/rjieba/ for installation."
- )
- self.jieba = rjieba
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
@@ -181,7 +160,7 @@ def get_vocab(self):
def _tokenize(self, text, use_jieba=True):
split_tokens = []
if use_jieba:
- for wholword in self.jieba.cut(text, False):
+ for wholword in jieba.cut(text, HMM=False):
if wholword in self.vocab:
split_tokens.append(wholword)
else:
diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py
index d956d5214cb3ee..70bc240fc6a7bf 100644
--- a/src/transformers/models/roformer/tokenization_utils.py
+++ b/src/transformers/models/roformer/tokenization_utils.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tokenization utils for RoFormer."""
+import jieba
from typing import List
from tokenizers import NormalizedString, PreTokenizedString, normalizers
@@ -28,20 +29,12 @@ def __init__(self, vocab) -> None:
strip_accents=False,
lowercase=False,
)
- try:
- import rjieba
- except ImportError:
- raise ImportError(
- "You need to install rjieba to use RoFormerTokenizer."
- "See https://pypi.org/project/rjieba/ for installation."
- )
- self.jieba = rjieba
def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
splits = []
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
- # for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
+ # for token, start, end in jieba.tokenize(str(normalized_string), hmm=False):
# if token in self.vocab:
# splits.append(normalized_string.slice((start, end)))
# else:
@@ -53,7 +46,7 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma
# start = end
# this code test_alignement_methods can't pass but fast (300ms)
- for token in self.jieba.cut(str(normalized_string), False):
+ for token in jieba.cut(str(normalized_string), False):
if token in self.vocab:
splits.append(NormalizedString(token))
else:
diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py
index 19c7fb65431e1c..22111d4f279e91 100644
--- a/tests/test_tokenization_roformer.py
+++ b/tests/test_tokenization_roformer.py
@@ -22,21 +22,21 @@
from .test_tokenization_common import TokenizerTesterMixin
-def is_rjieba_available():
- return importlib.util.find_spec("rjieba") is not None
+def is_jieba_available():
+ return importlib.util.find_spec("jieba") is not None
-def require_rjieba(test_case):
+def require_jieba(test_case):
"""
Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
"""
- if not is_rjieba_available():
- return unittest.skip("test requires rjieba")(test_case)
+ if not is_jieba_available():
+ return unittest.skip("test requires jieba")(test_case)
else:
return test_case
-@require_rjieba
+@require_jieba
@require_tokenizers
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
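
For context, the pure-Python jieba package this patch switches to segments as below; HMM=False disables its hidden-Markov new-word discovery, matching the call in _tokenize. A minimal sketch (the exact segmentation depends on jieba's dictionary):

    import jieba

    print(list(jieba.cut("今天天气非常好", HMM=False)))
    # e.g. ['今天', '天气', '非常', '好']
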
From c0daacd694321c7e77a123998f12306840c665d1 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 22:15:53 +0800
Subject: [PATCH 06/15] fix #12244 and make test_alignement_methods pass
---
.../models/roformer/tokenization_utils.py | 30 +++++++++----------
tests/test_tokenization_roformer.py | 3 --
2 files changed, 15 insertions(+), 18 deletions(-)
diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py
index 70bc240fc6a7bf..e2bc705ea1e400 100644
--- a/src/transformers/models/roformer/tokenization_utils.py
+++ b/src/transformers/models/roformer/tokenization_utils.py
@@ -34,26 +34,26 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma
splits = []
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
- # for token, start, end in jieba.tokenize(str(normalized_string), hmm=False):
- # if token in self.vocab:
- # splits.append(normalized_string.slice((start, end)))
- # else:
- # token_list = self.normalizers.normalize_str(token).split()
- # for token in token_list:
- # if token:
- # end = start + len(token)
- # splits.append(normalized_string.slice((start, end)))
- # start = end
-
- # this code test_alignement_methods can't pass but fast (300ms)
- for token in jieba.cut(str(normalized_string), False):
+ for token, start, end in jieba.tokenize(str(normalized_string), HMM=False):
if token in self.vocab:
- splits.append(NormalizedString(token))
+ splits.append(normalized_string[start:end])
else:
token_list = self.normalizers.normalize_str(token).split()
for token in token_list:
if token:
- splits.append(NormalizedString(token))
+ end = start + len(token)
+ splits.append(normalized_string[start:end])
+ start = end
+
+ # this code test_alignement_methods can't pass but fast (300ms)
+ # for token in jieba.cut(str(normalized_string), HMM=False):
+ # if token in self.vocab:
+ # splits.append(NormalizedString(token))
+ # else:
+ # token_list = self.normalizers.normalize_str(token).split()
+ # for token in token_list:
+ # if token:
+ # splits.append(NormalizedString(token))
return splits
diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py
index 22111d4f279e91..b3462756e15e58 100644
--- a/tests/test_tokenization_roformer.py
+++ b/tests/test_tokenization_roformer.py
@@ -79,6 +79,3 @@ def test_rust_tokenizer(self):
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
- # due to custom pre_tokenize , char_to_token may be error
- def test_alignement_methods(self):
- pass
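
The key to restoring test_alignement_methods is that jieba.tokenize yields (token, start, end) character offsets, so the pre-tokenizer can slice the original NormalizedString instead of constructing new ones and losing alignment. A minimal sketch of the offsets it yields:

    import jieba

    for token, start, end in jieba.tokenize("今天天气非常好", HMM=False):
        print(token, start, end)  # offsets index into the original string
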
From 72c3095f1b05c9ca77e150cce917db9c9db23f96 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 22:31:43 +0800
Subject: [PATCH 07/15] update ARCHIVE_MAP
---
.../models/roformer/configuration_roformer.py | 6 +++++-
src/transformers/models/roformer/modeling_roformer.py | 6 +++++-
.../models/roformer/modeling_tf_roformer.py | 6 +++++-
.../models/roformer/tokenization_roformer.py | 10 +++++++++-
.../models/roformer/tokenization_roformer_fast.py | 10 +++++++++-
5 files changed, 33 insertions(+), 5 deletions(-)
diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py
index 965922f9957f93..53d00e6f57357a 100644
--- a/src/transformers/models/roformer/configuration_roformer.py
+++ b/src/transformers/models/roformer/configuration_roformer.py
@@ -22,7 +22,11 @@
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
- "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json"
+ "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
+ "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
+ "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
+ "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
+ "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
# See all RoFormer models at https://huggingface.co/models?filter=roformer
}
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index 8fc375d5534e4e..ded6e0ce067519 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -60,7 +60,11 @@
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"junnyu/roformer_chinese_small",
- "junnyu/roformer_chinese_base"
+ "junnyu/roformer_chinese_base",
+ "junnyu/roformer_chinese_char_small",
+ "junnyu/roformer_chinese_char_base",
+ "junnyu/roformer_small_discriminator",
+ "junnyu/roformer_small_generator"
# See all RoFormer models at https://huggingface.co/models?filter=roformer
]
diff --git a/src/transformers/models/roformer/modeling_tf_roformer.py b/src/transformers/models/roformer/modeling_tf_roformer.py
index dae6e180b11b46..436acdbd30d2f2 100644
--- a/src/transformers/models/roformer/modeling_tf_roformer.py
+++ b/src/transformers/models/roformer/modeling_tf_roformer.py
@@ -65,7 +65,11 @@
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"junnyu/roformer_chinese_small",
- "junnyu/roformer_chinese_base"
+ "junnyu/roformer_chinese_base",
+ "junnyu/roformer_chinese_char_small",
+ "junnyu/roformer_chinese_char_base",
+ "junnyu/roformer_small_discriminator",
+ "junnyu/roformer_small_generator"
# See all RoFormer models at https://huggingface.co/models?filter=roformer
]
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index 9e5bd35c3f296a..a9dc7488639756 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -32,15 +32,23 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
+ "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
+ "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
}
}
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128}
PRETRAINED_INIT_CONFIGURATION = {
"junnyu/roformer_chinese_small": {"do_lower_case": True},
"junnyu/roformer_chinese_base": {"do_lower_case": True},
+ "junnyu/roformer_chinese_char_small": {"do_lower_case": True},
+ "junnyu/roformer_chinese_char_base": {"do_lower_case": True},
+ "junnyu/roformer_small_discriminator": {"do_lower_case": True},
+ "junnyu/roformer_small_generator": {"do_lower_case": True},
}
diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py
index bafd60e3f6b18f..57e5eef6cf5372 100644
--- a/src/transformers/models/roformer/tokenization_roformer_fast.py
+++ b/src/transformers/models/roformer/tokenization_roformer_fast.py
@@ -33,15 +33,23 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
+ "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
+ "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
}
}
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128}
PRETRAINED_INIT_CONFIGURATION = {
"junnyu/roformer_chinese_small": {"do_lower_case": True},
"junnyu/roformer_chinese_base": {"do_lower_case": True},
+ "junnyu/roformer_chinese_char_small": {"do_lower_case": True},
+ "junnyu/roformer_chinese_char_base": {"do_lower_case": True},
+ "junnyu/roformer_small_discriminator": {"do_lower_case": True},
+ "junnyu/roformer_small_generator": {"do_lower_case": True},
}
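
With the archive maps extended, the char-level and small discriminator/generator checkpoints resolve by name. A quick smoke test (downloads weights):

    from transformers import RoFormerModel, RoFormerTokenizer

    model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_char_small")
    tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_char_small")
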
From 3c0210640834dd8726fc08369fc69a2fa1b5b550 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 22:36:48 +0800
Subject: [PATCH 08/15] make style & quality & fixup
---
.../models/roformer/tokenization_roformer.py | 14 ++++++++++----
.../models/roformer/tokenization_roformer_fast.py | 9 ++++++++-
.../models/roformer/tokenization_utils.py | 2 +-
tests/test_tokenization_roformer.py | 1 -
4 files changed, 19 insertions(+), 7 deletions(-)
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index a9dc7488639756..646534d95fc3ff 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -16,9 +16,10 @@
import collections
import os
-import jieba
from typing import List, Optional, Tuple
+import jieba
+
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab
@@ -39,7 +40,14 @@
}
}
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "junnyu/roformer_chinese_small": 1536,
+ "junnyu/roformer_chinese_base": 1536,
+ "junnyu/roformer_chinese_char_small": 512,
+ "junnyu/roformer_chinese_char_base": 512,
+ "junnyu/roformer_small_discriminator": 128,
+ "junnyu/roformer_small_generator": 128,
+}
PRETRAINED_INIT_CONFIGURATION = {
@@ -152,7 +160,6 @@ def __init__(
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
-
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
@@ -161,7 +168,6 @@ def do_lower_case(self):
def vocab_size(self):
return len(self.vocab)
-
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py
index 57e5eef6cf5372..9d6be92d9bf9fa 100644
--- a/src/transformers/models/roformer/tokenization_roformer_fast.py
+++ b/src/transformers/models/roformer/tokenization_roformer_fast.py
@@ -40,7 +40,14 @@
}
}
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small":512, "junnyu/roformer_chinese_char_base":512, "junnyu/roformer_small_discriminator":128, "junnyu/roformer_small_generator":128}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "junnyu/roformer_chinese_small": 1536,
+ "junnyu/roformer_chinese_base": 1536,
+ "junnyu/roformer_chinese_char_small": 512,
+ "junnyu/roformer_chinese_char_base": 512,
+ "junnyu/roformer_small_discriminator": 128,
+ "junnyu/roformer_small_generator": 128,
+}
PRETRAINED_INIT_CONFIGURATION = {
diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py
index e2bc705ea1e400..056d0025eb13b5 100644
--- a/src/transformers/models/roformer/tokenization_utils.py
+++ b/src/transformers/models/roformer/tokenization_utils.py
@@ -14,9 +14,9 @@
# limitations under the License.
"""Tokenization utils for RoFormer."""
-import jieba
from typing import List
+import jieba
from tokenizers import NormalizedString, PreTokenizedString, normalizers
diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py
index b3462756e15e58..bb31456f03ca58 100644
--- a/tests/test_tokenization_roformer.py
+++ b/tests/test_tokenization_roformer.py
@@ -78,4 +78,3 @@ def test_rust_tokenizer(self):
input_tokens = tokens + [tokenizer.unk_token]
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
-
From aa14de9afa6c938babd79de30750d4ec2a0c02f2 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 23:00:25 +0800
Subject: [PATCH 09/15] guard the jieba import and make RoFormerTokenizer picklable
---
.../models/roformer/tokenization_roformer.py | 27 ++++++++++++++++++-
.../models/roformer/tokenization_utils.py | 14 +++++++---
2 files changed, 37 insertions(+), 4 deletions(-)
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index 646534d95fc3ff..eb9c31c83d60e6 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -20,6 +20,7 @@
import jieba
+
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab
@@ -159,6 +160,14 @@ def __init__(
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+ try:
+ import jieba
+ except ImportError:
+ raise ImportError(
+ "You need to install jieba to use RoFormerTokenizer."
+ "See https://pypi.org/project/jieba/ for installation."
+ )
+ self.jieba = jieba
@property
def do_lower_case(self):
@@ -168,13 +177,29 @@ def do_lower_case(self):
def vocab_size(self):
return len(self.vocab)
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["jieba"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+ try:
+ import jieba
+ except ImportError:
+ raise ImportError(
+ "You need to install jieba to use RoFormerTokenizer."
+ "See https://pypi.org/project/jieba/ for installation."
+ )
+ self.jieba = jieba
+
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def _tokenize(self, text, use_jieba=True):
split_tokens = []
if use_jieba:
- for wholword in jieba.cut(text, HMM=False):
+ for wholword in self.jieba.cut(text, HMM=False):
if wholword in self.vocab:
split_tokens.append(wholword)
else:
diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py
index 056d0025eb13b5..9dc5b9c5fc8ecb 100644
--- a/src/transformers/models/roformer/tokenization_utils.py
+++ b/src/transformers/models/roformer/tokenization_utils.py
@@ -16,7 +16,7 @@
from typing import List
-import jieba
+
from tokenizers import NormalizedString, PreTokenizedString, normalizers
@@ -29,12 +29,20 @@ def __init__(self, vocab) -> None:
strip_accents=False,
lowercase=False,
)
+ try:
+ import jieba
+ except ImportError:
+ raise ImportError(
+ "You need to install jieba to use RoFormerTokenizer."
+ "See https://pypi.org/project/jieba/ for installation."
+ )
+ self.jieba = jieba
def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
splits = []
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
- for token, start, end in jieba.tokenize(str(normalized_string), HMM=False):
+ for token, start, end in self.jieba.tokenize(str(normalized_string), HMM=False):
if token in self.vocab:
splits.append(normalized_string[start:end])
else:
@@ -46,7 +54,7 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma
start = end
# this code test_alignement_methods can't pass but fast (300ms)
- # for token in jieba.cut(str(normalized_string), HMM=False):
+ # for token in self.jieba.cut(str(normalized_string), HMM=False):
# if token in self.vocab:
# splits.append(NormalizedString(token))
# else:
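
The __getstate__/__setstate__ pair exists because a module object stored on self cannot be pickled; the state drops it and the re-import restores it. A minimal standalone sketch of the same pattern (assumes jieba is installed):

    import pickle

    class NeedsJieba:
        def __init__(self):
            import jieba
            self.jieba = jieba

        def __getstate__(self):
            state = self.__dict__.copy()
            state["jieba"] = None  # module objects are not picklable
            return state

        def __setstate__(self, d):
            self.__dict__ = d
            import jieba
            self.jieba = jieba

    restored = pickle.loads(pickle.dumps(NeedsJieba()))
    assert restored.jieba is not None
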
From 23c723aaa7678512e627fc7afae4234c23d6aa91 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 23:03:05 +0800
Subject: [PATCH 10/15] make style & quality & fixup
---
src/transformers/models/roformer/tokenization_roformer.py | 3 ---
src/transformers/models/roformer/tokenization_utils.py | 1 -
2 files changed, 4 deletions(-)
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index eb9c31c83d60e6..6abde913d799c6 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -18,9 +18,6 @@
import os
from typing import List, Optional, Tuple
-import jieba
-
-
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab
diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py
index 9dc5b9c5fc8ecb..c703ae857a04f3 100644
--- a/src/transformers/models/roformer/tokenization_utils.py
+++ b/src/transformers/models/roformer/tokenization_utils.py
@@ -16,7 +16,6 @@
from typing import List
-
from tokenizers import NormalizedString, PreTokenizedString, normalizers
From bbf0e0e333b2df5387f58a8fa32fc4adee6f3aff Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 28 Jun 2021 23:59:28 +0800
Subject: [PATCH 11/15] make style quality fixup
---
.../tensorflow/question-answering/utils_qa.py | 22 +++++++++----------
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py
index 2f8f0a60c45fe5..1157849c99100f 100644
--- a/examples/tensorflow/question-answering/utils_qa.py
+++ b/examples/tensorflow/question-answering/utils_qa.py
@@ -38,7 +38,7 @@ def postprocess_qa_predictions(
null_score_diff_threshold: float = 0.0,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
- is_world_process_zero: bool = True,
+ log_level: Optional[int] = logging.WARNING,
):
"""
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
@@ -70,8 +70,8 @@ def postprocess_qa_predictions(
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
- is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
- Whether this process is the main process or not (used to determine if logging/saves should be done).
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
+ ``logging`` log level (e.g., ``logging.WARNING``)
"""
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
all_start_logits, all_end_logits = predictions
@@ -91,7 +91,7 @@ def postprocess_qa_predictions(
scores_diff_json = collections.OrderedDict()
# Logging.
- logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
+ logger.setLevel(log_level)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
@@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search(
end_n_top: int = 5,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
- is_world_process_zero: bool = True,
+ log_level: Optional[int] = logging.WARNING,
):
"""
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
@@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search(
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
- is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
- Whether this process is the main process or not (used to determine if logging/saves should be done).
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
+ ``logging`` log level (e.g., ``logging.WARNING``)
"""
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
@@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search(
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
# Logging.
- logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
+ logger.setLevel(log_level)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
@@ -413,14 +413,14 @@ def postprocess_qa_predictions_with_beam_search(
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
)
- print(f"Saving predictions to {prediction_file}.")
+ logger.info(f"Saving predictions to {prediction_file}.")
with open(prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
- print(f"Saving nbest_preds to {nbest_file}.")
+ logger.info(f"Saving nbest_preds to {nbest_file}.")
with open(nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
- print(f"Saving null_odds to {null_odds_file}.")
+ logger.info(f"Saving null_odds to {null_odds_file}.")
with open(null_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
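
On the caller side, the script now chooses a level per process and passes it through, rather than handing the post-processing function a bare boolean. A hedged sketch of the intended call pattern (the is_main_process flag is a stand-in, not code from this patch):

    import logging

    is_main_process = True  # stand-in for e.g. trainer.is_world_process_zero()
    log_level = logging.INFO if is_main_process else logging.WARNING
    # predictions = postprocess_qa_predictions(..., log_level=log_level)
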
From 90102cb56b0ed049922d21d9c6d0fef45f88c787 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Fri, 2 Jul 2021 18:20:16 +0800
Subject: [PATCH 12/15] move is_jieba_available and require_jieba into library utils
---
src/transformers/file_utils.py | 4 ++++
.../models/roformer/tokenization_roformer.py | 9 ++-------
src/transformers/testing_utils.py | 11 +++++++++++
tests/test_tokenization_roformer.py | 17 +----------------
4 files changed, 18 insertions(+), 23 deletions(-)
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index 5f522a440c61ca..66186eaaeb11d5 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -315,6 +315,10 @@ def is_datasets_available():
return _datasets_available
+def is_jieba_available():
+ return importlib.util.find_spec("jieba") is not None
+
+
def is_psutil_available():
return importlib.util.find_spec("psutil") is not None
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index 6abde913d799c6..514235f33a5450 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -181,13 +181,8 @@ def __getstate__(self):
def __setstate__(self, d):
self.__dict__ = d
- try:
- import jieba
- except ImportError:
- raise ImportError(
- "You need to install jieba to use RoFormerTokenizer."
- "See https://pypi.org/project/jieba/ for installation."
- )
+ import jieba
+
self.jieba = jieba
def get_vocab(self):
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index d315785ed95c9c..018bef0cc57506 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -33,6 +33,7 @@
is_datasets_available,
is_faiss_available,
is_flax_available,
+ is_jieba_available,
is_onnx_available,
is_pandas_available,
is_scatter_available,
@@ -223,6 +224,16 @@ def require_git_lfs(test_case):
return test_case
+def require_jieba(test_case):
+ """
+ Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
+ """
+ if not is_jieba_available():
+ return unittest.skip("test requires jieba")(test_case)
+ else:
+ return test_case
+
+
def require_onnx(test_case):
if not is_onnx_available():
return unittest.skip("test requires ONNX")(test_case)
diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py
index bb31456f03ca58..0f4cd3a1cbb626 100644
--- a/tests/test_tokenization_roformer.py
+++ b/tests/test_tokenization_roformer.py
@@ -13,29 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib
import unittest
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
-from transformers.testing_utils import require_tokenizers
+from transformers.testing_utils import require_jieba, require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
-def is_jieba_available():
- return importlib.util.find_spec("jieba") is not None
-
-
-def require_jieba(test_case):
- """
- Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
- """
- if not is_jieba_available():
- return unittest.skip("test requires jieba")(test_case)
- else:
- return test_case
-
-
@require_jieba
@require_tokenizers
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
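
Any test module can now reuse the decorator from testing_utils (note it is renamed to require_rjieba in patch 15). A minimal sketch as of this patch, with a hypothetical test class:

    import unittest
    from transformers.testing_utils import require_jieba

    @require_jieba
    class SomeJiebaTest(unittest.TestCase):
        def test_cut(self):
            import jieba  # safe: the class is skipped when jieba is missing
            self.assertTrue(list(jieba.cut("今天天气非常好", HMM=False)))
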
From 1c5dd6a1d8ab3560cb396dd051585ac19c9bb602 Mon Sep 17 00:00:00 2001
From: yujun <50394665+JunnYu@users.noreply.github.com>
Date: Mon, 5 Jul 2021 17:14:03 +0800
Subject: [PATCH 13/15] suggestion from LysandreJik
Co-authored-by: Lysandre Debut
---
src/transformers/models/roformer/configuration_roformer.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py
index 53d00e6f57357a..d0f82a1c996fdd 100644
--- a/src/transformers/models/roformer/configuration_roformer.py
+++ b/src/transformers/models/roformer/configuration_roformer.py
@@ -48,7 +48,7 @@ class RoFormerConfig(PretrainedConfig):
the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
:class:`~transformers.TFRoFormerModel`.
embedding_size (:obj:`int`, `optional`, defaults to None):
- Dimensionality of the encoder layers and the pooler layer.
+ Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not provided.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimension of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
From 1f08622612825d8a8bbc985a0e1698c2aaaf5e4a Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Mon, 5 Jul 2021 18:05:41 +0800
Subject: [PATCH 14/15] make style
---
src/transformers/models/roformer/configuration_roformer.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py
index d0f82a1c996fdd..945d1064a10ea8 100644
--- a/src/transformers/models/roformer/configuration_roformer.py
+++ b/src/transformers/models/roformer/configuration_roformer.py
@@ -48,7 +48,8 @@ class RoFormerConfig(PretrainedConfig):
the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
:class:`~transformers.TFRoFormerModel`.
embedding_size (:obj:`int`, `optional`, defaults to None):
- Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not provided.
+ Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not
+ provided.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimension of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
From 52eb0f5de48599533684f1da833154e03cfd7d70 Mon Sep 17 00:00:00 2001
From: junnyu <573009727@qq.com>
Date: Tue, 6 Jul 2021 00:11:17 +0800
Subject: [PATCH 15/15] use rjieba
---
src/transformers/file_utils.py | 4 ++--
.../models/roformer/tokenization_roformer.py | 16 ++++++++--------
.../models/roformer/tokenization_utils.py | 12 ++++++------
src/transformers/testing_utils.py | 10 +++++-----
tests/test_tokenization_roformer.py | 12 ++++++++++--
5 files changed, 31 insertions(+), 23 deletions(-)
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index 66186eaaeb11d5..c3717a9289d336 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -315,8 +315,8 @@ def is_datasets_available():
return _datasets_available
-def is_jieba_available():
- return importlib.util.find_spec("jieba") is not None
+def is_rjieba_available():
+ return importlib.util.find_spec("rjieba") is not None
def is_psutil_available():
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index 514235f33a5450..a425ec934c81cf 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -60,7 +60,7 @@
class RoFormerTokenizer(PreTrainedTokenizer):
r"""
-    Construct a RoFormer tokenizer. Based on `Jieba <https://pypi.org/project/jieba/>`.
+    Construct a RoFormer tokenizer. Based on `Rust Jieba <https://pypi.org/project/rjieba/>`.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
@@ -158,13 +158,13 @@ def __init__(
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
try:
- import jieba
+ import rjieba
except ImportError:
raise ImportError(
- "You need to install jieba to use RoFormerTokenizer."
- "See https://pypi.org/project/jieba/ for installation."
+ "You need to install rjieba to use RoFormerTokenizer."
+ "See https://pypi.org/project/rjieba/ for installation."
)
- self.jieba = jieba
+ self.jieba = rjieba
@property
def do_lower_case(self):
@@ -181,9 +181,9 @@ def __getstate__(self):
def __setstate__(self, d):
self.__dict__ = d
- import jieba
+ import rjieba
- self.jieba = jieba
+ self.jieba = rjieba
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
@@ -191,7 +191,7 @@ def get_vocab(self):
def _tokenize(self, text, use_jieba=True):
split_tokens = []
if use_jieba:
- for wholword in self.jieba.cut(text, HMM=False):
+ for wholword in self.jieba.cut(text, False):
if wholword in self.vocab:
split_tokens.append(wholword)
else:
diff --git a/src/transformers/models/roformer/tokenization_utils.py b/src/transformers/models/roformer/tokenization_utils.py
index c703ae857a04f3..195e6eff2dadbe 100644
--- a/src/transformers/models/roformer/tokenization_utils.py
+++ b/src/transformers/models/roformer/tokenization_utils.py
@@ -29,19 +29,19 @@ def __init__(self, vocab) -> None:
lowercase=False,
)
try:
- import jieba
+ import rjieba
except ImportError:
raise ImportError(
- "You need to install jieba to use RoFormerTokenizer."
- "See https://pypi.org/project/jieba/ for installation."
+ "You need to install rjieba to use RoFormerTokenizer."
+ "See https://pypi.org/project/rjieba/ for installation."
)
- self.jieba = jieba
+ self.jieba = rjieba
def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
splits = []
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
- for token, start, end in self.jieba.tokenize(str(normalized_string), HMM=False):
+ for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
if token in self.vocab:
splits.append(normalized_string[start:end])
else:
@@ -53,7 +53,7 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma
start = end
# this code test_alignement_methods can't pass but fast (300ms)
- # for token in self.jieba.cut(str(normalized_string), HMM=False):
+ # for token in self.jieba.cut(str(normalized_string), False):
# if token in self.vocab:
# splits.append(NormalizedString(token))
# else:
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 018bef0cc57506..439cee385d1253 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -33,9 +33,9 @@
is_datasets_available,
is_faiss_available,
is_flax_available,
- is_jieba_available,
is_onnx_available,
is_pandas_available,
+ is_rjieba_available,
is_scatter_available,
is_sentencepiece_available,
is_soundfile_availble,
@@ -224,12 +224,12 @@ def require_git_lfs(test_case):
return test_case
-def require_jieba(test_case):
+def require_rjieba(test_case):
"""
- Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
+ Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
"""
- if not is_jieba_available():
- return unittest.skip("test requires jieba")(test_case)
+ if not is_rjieba_available():
+ return unittest.skip("test requires rjieba")(test_case)
else:
return test_case
diff --git a/tests/test_tokenization_roformer.py b/tests/test_tokenization_roformer.py
index 0f4cd3a1cbb626..c5e19b66b20023 100644
--- a/tests/test_tokenization_roformer.py
+++ b/tests/test_tokenization_roformer.py
@@ -16,12 +16,12 @@
import unittest
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
-from transformers.testing_utils import require_jieba, require_tokenizers
+from transformers.testing_utils import require_rjieba, require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
-@require_jieba
+@require_rjieba
@require_tokenizers
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -63,3 +63,11 @@ def test_rust_tokenizer(self):
input_tokens = tokens + [tokenizer.unk_token]
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
+
+ # can't train new_tokenizer via Tokenizers lib
+ def test_training_new_tokenizer(self):
+ pass
+
+ # can't train new_tokenizer via Tokenizers lib
+ def test_training_new_tokenizer_with_special_tokens_change(self):
+ pass
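
Note the API differences the final switch relies on: per the calls in this patch, rjieba takes hmm positionally in cut and as hmm= in tokenize, versus jieba's HMM= keyword. A minimal sketch (assumes rjieba is installed, e.g. pip install rjieba):

    import rjieba

    for token in rjieba.cut("今天天气非常好", False):  # hmm passed positionally
        print(token)

    for token, start, end in rjieba.tokenize("今天天气非常好", hmm=False):
        print(token, start, end)
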