[RoFormer] Fix some issues #12397

Merged · 18 commits · Jul 6, 2021
2 changes: 1 addition & 1 deletion docs/source/model_doc/roformer.rst
@@ -56,7 +56,7 @@ RoFormerTokenizer
     create_token_type_ids_from_sequences, save_vocabulary


-RobertaTokenizerFast
+RoFormerTokenizerFast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.RoFormerTokenizerFast
4 changes: 4 additions & 0 deletions src/transformers/file_utils.py
@@ -315,6 +315,10 @@ def is_datasets_available():
     return _datasets_available


+def is_rjieba_available():
+    return importlib.util.find_spec("rjieba") is not None
+
+
 def is_psutil_available():
     return importlib.util.find_spec("psutil") is not None
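Like the other is_*_available helpers in file_utils.py, the new probe lets callers test for the optional dependency without importing it. A minimal usage sketch, assuming the helper is re-exported from file_utils as the rest of this PR relies on:

    from transformers.file_utils import is_rjieba_available

    if is_rjieba_available():
        import rjieba  # only imported when the probe succeeds
    else:
        print("rjieba is not installed; the RoFormer tokenizers will be unavailable")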
4 changes: 3 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -198,6 +198,7 @@
 from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
 from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
 from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
+from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
 from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
 from ..t5.tokenization_t5_fast import T5TokenizerFast
 from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
@@ -232,6 +233,7 @@
     ReformerTokenizerFast = None
     RetriBertTokenizerFast = None
     RobertaTokenizerFast = None
+    RoFormerTokenizerFast = None
     SqueezeBertTokenizerFast = None
     T5TokenizerFast = None
     XLMRobertaTokenizerFast = None
@@ -245,7 +247,7 @@
 TOKENIZER_MAPPING = OrderedDict(
     [
         (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
-        (RoFormerConfig, (RoFormerTokenizer, None)),
+        (RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
         (T5Config, (T5Tokenizer, T5TokenizerFast)),
         (MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
         (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
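With RoFormerConfig now mapped to both tokenizer classes, AutoTokenizer can resolve the fast variant. A minimal sketch, assuming rjieba and tokenizers are installed and the checkpoint can be downloaded:

    from transformers import AutoTokenizer

    # use_fast=True (the default) now resolves to RoFormerTokenizerFast
    # instead of hitting the former None entry in the mapping above.
    tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
    print(type(tokenizer).__name__)  # expected: RoFormerTokenizerFast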
15 changes: 10 additions & 5 deletions src/transformers/models/roformer/configuration_roformer.py
@@ -22,7 +22,11 @@

ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json"
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
# See all RoFormer models at https://huggingface.co/models?filter=roformer
}

@@ -43,8 +47,9 @@ class RoFormerConfig(PretrainedConfig):
             Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
             the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
             :class:`~transformers.TFRoFormerModel`.
-        embedding_size (:obj:`int`, `optional`, defaults to 768):
-            Dimensionality of the encoder layers and the pooler layer.
+        embedding_size (:obj:`int`, `optional`, defaults to None):
+            Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not
+            provided.
         hidden_size (:obj:`int`, `optional`, defaults to 768):
             Dimension of the encoder layers and the pooler layer.
         num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
@@ -96,7 +101,7 @@ class RoFormerConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=50000,
-        embedding_size=768,
+        embedding_size=None,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -117,7 +122,7 @@ def __init__(
         super().__init__(pad_token_id=pad_token_id, **kwargs)

         self.vocab_size = vocab_size
-        self.embedding_size = embedding_size
+        self.embedding_size = hidden_size if embedding_size is None else embedding_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
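The new default keeps older configs working while letting embedding_size track hidden_size unless set explicitly. A quick illustration of the fallback logic above:

    from transformers import RoFormerConfig

    # embedding_size omitted: falls back to hidden_size.
    config = RoFormerConfig(hidden_size=384)
    assert config.embedding_size == 384

    # embedding_size given explicitly: preserved as-is.
    config = RoFormerConfig(hidden_size=384, embedding_size=128)
    assert config.embedding_size == 128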
10 changes: 7 additions & 3 deletions src/transformers/models/roformer/modeling_roformer.py
@@ -60,7 +60,11 @@

ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"junnyu/roformer_chinese_small",
"junnyu/roformer_chinese_base"
"junnyu/roformer_chinese_base",
"junnyu/roformer_chinese_char_small",
"junnyu/roformer_chinese_char_base",
"junnyu/roformer_small_discriminator",
"junnyu/roformer_small_generator"
# See all RoFormer models at https://huggingface.co/models?filter=roformer
]

@@ -327,9 +331,9 @@ def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, val
         # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
         sin, cos = sinusoidal_pos.chunk(2, dim=-1)
         # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
-        sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
+        sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
         # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
-        cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
+        cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
         # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
         rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
             query_layer
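The stack-and-reshape form is numerically identical to repeat_interleave: both turn [θ0, θ1, ...] into the interleaved [θ0, θ0, θ1, θ1, ...]. A small self-contained check (shapes are illustrative):

    import torch

    # [batch, heads, seq, head_dim // 2], as in the sin/cos halves above
    sin = torch.randn(2, 8, 16, 32)

    a = torch.repeat_interleave(sin, 2, dim=-1)
    b = torch.stack([sin, sin], dim=-1).reshape(2, 8, 16, 64)

    assert torch.equal(a, b)  # same interleaving, element for element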
6 changes: 5 additions & 1 deletion src/transformers/models/roformer/modeling_tf_roformer.py
@@ -65,7 +65,11 @@

TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"junnyu/roformer_chinese_small",
"junnyu/roformer_chinese_base"
"junnyu/roformer_chinese_base",
"junnyu/roformer_chinese_char_small",
"junnyu/roformer_chinese_char_base",
"junnyu/roformer_small_discriminator",
"junnyu/roformer_small_generator"
# See all RoFormer models at https://huggingface.co/models?filter=roformer
]

26 changes: 18 additions & 8 deletions src/transformers/models/roformer/tokenization_roformer.py
@@ -31,15 +31,30 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
}
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"junnyu/roformer_chinese_small": 1536,
"junnyu/roformer_chinese_base": 1536,
"junnyu/roformer_chinese_char_small": 512,
"junnyu/roformer_chinese_char_base": 512,
"junnyu/roformer_small_discriminator": 128,
"junnyu/roformer_small_generator": 128,
}


PRETRAINED_INIT_CONFIGURATION = {
"junnyu/roformer_chinese_small": {"do_lower_case": True},
"junnyu/roformer_chinese_base": {"do_lower_case": True},
"junnyu/roformer_chinese_char_small": {"do_lower_case": True},
"junnyu/roformer_chinese_char_base": {"do_lower_case": True},
"junnyu/roformer_small_discriminator": {"do_lower_case": True},
"junnyu/roformer_small_generator": {"do_lower_case": True},
}


@@ -166,13 +181,8 @@ def __getstate__(self):

def __setstate__(self, d):
self.__dict__ = d
try:
import rjieba
except ImportError:
raise ImportError(
"You need to install rjieba to use RoFormerTokenizer."
"See https://pypi.org/project/rjieba/ for installation."
)
import rjieba

self.jieba = rjieba

def get_vocab(self):
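The __getstate__/__setstate__ pair exists to keep the tokenizer picklable, since the rjieba module object itself cannot be serialized. A minimal round-trip sketch, assuming rjieba is installed and the checkpoint can be downloaded:

    import pickle

    from transformers import RoFormerTokenizer

    tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
    restored = pickle.loads(pickle.dumps(tokenizer))

    # __setstate__ re-imports rjieba, so the restored tokenizer still segments text.
    assert restored.tokenize("今天天气非常好") == tokenizer.tokenize("今天天气非常好")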
17 changes: 16 additions & 1 deletion src/transformers/models/roformer/tokenization_roformer_fast.py
@@ -33,15 +33,30 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
}
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"junnyu/roformer_chinese_small": 1536,
"junnyu/roformer_chinese_base": 1536,
"junnyu/roformer_chinese_char_small": 512,
"junnyu/roformer_chinese_char_base": 512,
"junnyu/roformer_small_discriminator": 128,
"junnyu/roformer_small_generator": 128,
}


PRETRAINED_INIT_CONFIGURATION = {
"junnyu/roformer_chinese_small": {"do_lower_case": True},
"junnyu/roformer_chinese_base": {"do_lower_case": True},
"junnyu/roformer_chinese_char_small": {"do_lower_case": True},
"junnyu/roformer_chinese_char_base": {"do_lower_case": True},
"junnyu/roformer_small_discriminator": {"do_lower_case": True},
"junnyu/roformer_small_generator": {"do_lower_case": True},
}


30 changes: 15 additions & 15 deletions src/transformers/models/roformer/tokenization_utils.py
@@ -41,26 +41,26 @@ def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[Norma
         splits = []

         # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
-        # for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
-        #     if token in self.vocab:
-        #         splits.append(normalized_string.slice((start, end)))
-        #     else:
-        #         token_list = self.normalizers.normalize_str(token).split()
-        #         for token in token_list:
-        #             if token:
-        #                 end = start + len(token)
-        #                 splits.append(normalized_string.slice((start, end)))
-        #                 start = end
-
-        # this code test_alignement_methods can't pass but fast (300ms)
-        for token in self.jieba.cut(str(normalized_string), False):
+        for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
             if token in self.vocab:
-                splits.append(NormalizedString(token))
+                splits.append(normalized_string[start:end])
             else:
                 token_list = self.normalizers.normalize_str(token).split()
                 for token in token_list:
                     if token:
-                        splits.append(NormalizedString(token))
+                        end = start + len(token)
+                        splits.append(normalized_string[start:end])
+                        start = end
+
+        # this code test_alignement_methods can't pass but fast (300ms)
+        # for token in self.jieba.cut(str(normalized_string), False):
+        #     if token in self.vocab:
+        #         splits.append(NormalizedString(token))
+        #     else:
+        #         token_list = self.normalizers.normalize_str(token).split()
+        #         for token in token_list:
+        #             if token:
+        #                 splits.append(NormalizedString(token))

         return splits
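For readers unfamiliar with the rjieba calls above: tokenize yields (word, start, end) triples over the input string, which is what lets the pre-tokenizer slice normalized_string with offsets that stay aligned to the original text. A small sketch of the assumed behavior (offsets are assumed to be character indices, matching how normalized_string is sliced above):

    import rjieba

    text = "今天天气非常好"
    # hmm=False disables HMM-based discovery of unseen words, as in the call above.
    for word, start, end in rjieba.tokenize(text, hmm=False):
        print(word, start, end)  # e.g. 今天 0 2, assuming character offsets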

11 changes: 11 additions & 0 deletions src/transformers/testing_utils.py
@@ -35,6 +35,7 @@
     is_flax_available,
     is_onnx_available,
     is_pandas_available,
+    is_rjieba_available,
     is_scatter_available,
     is_sentencepiece_available,
     is_soundfile_availble,
@@ -223,6 +224,16 @@ def require_git_lfs(test_case):
     return test_case


+def require_rjieba(test_case):
+    """
+    Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
+    """
+    if not is_rjieba_available():
+        return unittest.skip("test requires rjieba")(test_case)
+    else:
+        return test_case
+
+
 def require_onnx(test_case):
     if not is_onnx_available():
         return unittest.skip("test requires ONNX")(test_case)
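Usage mirrors the other require_* decorators in testing_utils; a sketch with a hypothetical test class, for illustration only:

    import unittest

    from transformers.testing_utils import require_rjieba


    @require_rjieba
    class RjiebaSmokeTest(unittest.TestCase):  # hypothetical test, not part of this PR
        def test_cut(self):
            import rjieba  # safe: the decorator skips this test when rjieba is absent

            words = list(rjieba.cut("今天天气非常好", False))
            self.assertTrue(len(words) > 0)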
25 changes: 7 additions & 18 deletions tests/test_tokenization_roformer.py
@@ -13,29 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import importlib
 import unittest

 from transformers import RoFormerTokenizer, RoFormerTokenizerFast
-from transformers.testing_utils import require_tokenizers
+from transformers.testing_utils import require_rjieba, require_tokenizers

 from .test_tokenization_common import TokenizerTesterMixin


-def is_rjieba_available():
-    return importlib.util.find_spec("rjieba") is not None
-
-
-def require_rjieba(test_case):
-    """
-    Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
-    """
-    if not is_rjieba_available():
-        return unittest.skip("test requires rjieba")(test_case)
-    else:
-        return test_case
-
-
 @require_rjieba
 @require_tokenizers
 class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -79,6 +64,10 @@ def test_rust_tokenizer(self):
         exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)

-    # due to custom pre_tokenize , char_to_token may be error
-    def test_alignement_methods(self):
+    # can't train new_tokenizer via Tokenizers lib
+    def test_training_new_tokenizer(self):
         pass
+
+    # can't train new_tokenizer via Tokenizers lib
+    def test_training_new_tokenizer_with_special_tokens_change(self):
+        pass
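To exercise the updated test module locally (assuming rjieba and tokenizers are installed), run from the repository root:

    python -m pytest tests/test_tokenization_roformer.py -v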