diff --git a/community/iverxin/bert-base-japanese-char-whole-word-masking/README.md b/community/iverxin/bert-base-japanese-char-whole-word-masking/README.md new file mode 100644 index 000000000000..2e5839e58e8b --- /dev/null +++ b/community/iverxin/bert-base-japanese-char-whole-word-masking/README.md @@ -0,0 +1,64 @@ + + +# BERT base Japanese (character tokenization, whole word masking enabled) + +This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language. + +This version of the model processes input texts with word-level tokenization based on the IPA dictionary, followed by character-level tokenization. + +Additionally, the model is trained with the whole word masking enabled for the masked language modeling (MLM) objective. + +The codes for the pretraining are available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/tree/v1.0). + +## Model architecture + +The model architecture is the same as the original BERT base model; 12 layers, 768 dimensions of hidden states, and 12 attention heads. + +## Training Data + +The model is trained on Japanese Wikipedia as of September 1, 2019. + +To generate the training corpus, [WikiExtractor](https://github.com/attardi/wikiextractor) is used to extract plain texts from a dump file of Wikipedia articles. + +The text files used for the training are 2.6GB in size, consisting of approximately 17M sentences. + +## Tokenization + +The texts are first tokenized by [MeCab](https://taku910.github.io/mecab/) morphological parser with the IPA dictionary and then split into characters. + +The vocabulary size is 4000. + +## Training + +The model is trained with the same configuration as the original BERT; 512 tokens per instance, 256 instances per batch, and 1M training steps. + +For the training of the MLM (masked language modeling) objective, we introduced the **Whole Word Masking** in which all of the subword tokens corresponding to a single word (tokenized by MeCab) are masked at once. + +## Licenses + +The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/). + +## Acknowledgments + +For training models, we used Cloud TPUs provided by [TensorFlow Research Cloud](https://www.tensorflow.org/tfrc/) program. 
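+
+## Tokenization example
+
+For intuition, the character-level second stage behaves roughly like the sketch below: NFKC normalization followed by a per-character split, with out-of-vocabulary characters mapped to `[UNK]`. This mirrors the `CharacterTokenizer` added in this patch; the toy vocabulary is only for illustration.
+
+```python
+import unicodedata
+
+
+def char_tokenize(text, vocab, unk_token="[UNK]"):
+    # NFKC-normalize, then emit one token per character, falling back to
+    # [UNK] for characters that are not in the vocabulary.
+    text = unicodedata.normalize("NFKC", text)
+    return [ch if ch in vocab else unk_token for ch in text]
+
+
+print(char_tokenize("こんにちは", vocab={"こ", "ん", "に", "ち", "は"}))
+# ['こ', 'ん', 'に', 'ち', 'は']
+```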
+
+## Usage
+```python
+import paddle
+from paddlenlp.transformers import BertJapaneseTokenizer, BertForMaskedLM
+
+path = "iverxin/bert-base-japanese-char-whole-word-masking/"
+tokenizer = BertJapaneseTokenizer.from_pretrained(path)
+model = BertForMaskedLM.from_pretrained(path)
+text1 = "こんにちは"
+
+model.eval()
+inputs = tokenizer(text1)
+inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
+output = model(**inputs)
+print(output.shape)
+# [1, seq_len, 4000]  (the character vocabulary of this model has 4000 entries)
+```
+
+## Weights source
+https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking
diff --git a/community/iverxin/bert-base-japanese-char/files.json b/community/iverxin/bert-base-japanese-char-whole-word-masking/files.json
new file mode 100644
index 000000000000..1c799cd70cdf
--- /dev/null
+++ b/community/iverxin/bert-base-japanese-char-whole-word-masking/files.json
@@ -0,0 +1,6 @@
+{
+    "model_config_file":"https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char-whole-word-masking/model_config.json",
+    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char-whole-word-masking/model_state.pdparams",
+    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char-whole-word-masking/tokenizer_config.pdparams",
+    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char-whole-word-masking/vocab.txt"
+}
diff --git a/community/iverxin/bert-base-japanese-char/README.md b/community/iverxin/bert-base-japanese-char/README.md
new file mode 100644
index 000000000000..8083d5ec774e
--- /dev/null
+++ b/community/iverxin/bert-base-japanese-char/README.md
@@ -0,0 +1,60 @@
+
+
+# BERT base Japanese (character tokenization)
+
+This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language.
+
+This version of the model processes input texts with word-level tokenization based on the IPA dictionary, followed by character-level tokenization.
+
+The code for the pretraining is available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/tree/v1.0).
+
+## Model architecture
+
+The model architecture is the same as the original BERT base model: 12 layers, 768 dimensions of hidden states, and 12 attention heads.
+
+## Training Data
+
+The model is trained on Japanese Wikipedia as of September 1, 2019.
+
+To generate the training corpus, [WikiExtractor](https://github.com/attardi/wikiextractor) is used to extract plain texts from a dump file of Wikipedia articles.
+
+The text files used for the training are 2.6GB in size, consisting of approximately 17M sentences.
+
+## Tokenization
+
+The texts are first tokenized by the [MeCab](https://taku910.github.io/mecab/) morphological parser with the IPA dictionary and then split into characters.
+
+The vocabulary size is 4000.
+
+## Training
+
+The model is trained with the same configuration as the original BERT: 512 tokens per instance, 256 instances per batch, and 1M training steps.
+
+## Licenses
+
+The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/) license.
+
+## Acknowledgments
+
+For training models, we used Cloud TPUs provided by the [TensorFlow Research Cloud](https://www.tensorflow.org/tfrc/) program.
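+
+## Predicting a masked character
+
+Building on the basic forward pass shown in the Usage section below, the following sketch masks one character and decodes the model's prediction from the MLM logits. It is illustrative only: the masked index assumes the `[CLS] こ ん に ち は [SEP]` encoding produced by this character-level tokenizer, and the helper methods (`convert_tokens_to_ids`, `convert_ids_to_tokens`) are assumed to behave as in other PaddleNLP tokenizers.
+
+```python
+import paddle
+from paddlenlp.transformers import BertJapaneseTokenizer, BertForMaskedLM
+
+path = "iverxin/bert-base-japanese-char/"
+tokenizer = BertJapaneseTokenizer.from_pretrained(path)
+model = BertForMaskedLM.from_pretrained(path)
+model.eval()
+
+inputs = tokenizer("こんにちは")
+input_ids = inputs["input_ids"]  # [CLS] こ ん に ち は [SEP]
+mask_pos = 3                     # position of "に"
+input_ids[mask_pos] = tokenizer.convert_tokens_to_ids("[MASK]")
+
+logits = model(paddle.to_tensor([input_ids]))  # shape: [1, seq_len, vocab_size]
+pred_id = int(paddle.argmax(logits[0, mask_pos]).numpy())
+print(tokenizer.convert_ids_to_tokens(pred_id))
+```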
+
+## Usage
+```python
+import paddle
+from paddlenlp.transformers import BertJapaneseTokenizer, BertForMaskedLM
+
+path = "iverxin/bert-base-japanese-char/"
+tokenizer = BertJapaneseTokenizer.from_pretrained(path)
+model = BertForMaskedLM.from_pretrained(path)
+text1 = "こんにちは"
+
+model.eval()
+inputs = tokenizer(text1)
+inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
+output = model(**inputs)
+print(output.shape)
+# [1, seq_len, 4000]  (the character vocabulary of this model has 4000 entries)
+```
+
+## Weights source
+https://huggingface.co/cl-tohoku/bert-base-japanese-char
diff --git a/community/iverxin/bert-base-japanese-char/files.json b/community/iverxin/bert-base-japanese-char/files.json
new file mode 100644
index 000000000000..3d10c499bda0
--- /dev/null
+++ b/community/iverxin/bert-base-japanese-char/files.json
@@ -0,0 +1,6 @@
+{
+    "model_config_file":"https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char/model_config.json",
+    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char/model_state.pdparams",
+    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char/tokenizer_config.pdparams",
+    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-char/vocab.txt"
+}
diff --git a/community/iverxin/bert-base-japanese-whole-word-masking/README.md b/community/iverxin/bert-base-japanese-whole-word-masking/README.md
new file mode 100644
index 000000000000..0945f9c51634
--- /dev/null
+++ b/community/iverxin/bert-base-japanese-whole-word-masking/README.md
@@ -0,0 +1,63 @@
+
+
+# BERT base Japanese (IPA dictionary, whole word masking enabled)
+
+This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language.
+
+This version of the model processes input texts with word-level tokenization based on the IPA dictionary, followed by WordPiece subword tokenization.
+
+Additionally, the model is trained with whole word masking enabled for the masked language modeling (MLM) objective.
+
+The code for the pretraining is available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/tree/v1.0).
+
+## Model architecture
+
+The model architecture is the same as the original BERT base model: 12 layers, 768 dimensions of hidden states, and 12 attention heads.
+
+## Training Data
+
+The model is trained on Japanese Wikipedia as of September 1, 2019.
+
+To generate the training corpus, [WikiExtractor](https://github.com/attardi/wikiextractor) is used to extract plain texts from a dump file of Wikipedia articles.
+
+The text files used for the training are 2.6GB in size, consisting of approximately 17M sentences.
+
+## Tokenization
+
+The texts are first tokenized by the [MeCab](https://taku910.github.io/mecab/) morphological parser with the IPA dictionary and then split into subwords by the WordPiece algorithm.
+
+The vocabulary size is 32000.
+
+## Training
+
+The model is trained with the same configuration as the original BERT: 512 tokens per instance, 256 instances per batch, and 1M training steps.
+
+For the training of the MLM (masked language modeling) objective, we introduced **Whole Word Masking**, in which all of the subword tokens corresponding to a single word (tokenized by MeCab) are masked at once.
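+
+To make the difference concrete, the sketch below illustrates the idea on a toy example: subword tokens are grouped by the MeCab word they came from, and every subword of a selected word is masked together (the word list and subword splits here are made up for illustration; the actual data pipeline lives in the cl-tohoku/bert-japanese repository).
+
+```python
+# Hypothetical MeCab words and their WordPiece subwords (splits are illustrative).
+words = [["こんにち", "##は"], ["、"], ["世界"]]
+
+
+def whole_word_mask(words, masked_word_idx, mask_token="[MASK]"):
+    # Mask every subword of the selected word; token-level MLM would instead
+    # mask individual subwords independently of word boundaries.
+    tokens = []
+    for i, subwords in enumerate(words):
+        tokens.extend([mask_token] * len(subwords) if i == masked_word_idx else subwords)
+    return tokens
+
+
+print(whole_word_mask(words, masked_word_idx=0))
+# ['[MASK]', '[MASK]', '、', '世界']
+```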
+ +## Licenses + +The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/). + +## Acknowledgments + +For training models, we used Cloud TPUs provided by [TensorFlow Research Cloud](https://www.tensorflow.org/tfrc/) program. + +## Usage +```python +import paddle +from paddlenlp.transformers import BertJapaneseTokenizer, BertForMaskedLM + +path = "iverxin/bert-base-japanese-whole-word-masking/" +tokenizer = BertJapaneseTokenizer.from_pretrained(path) +model = BertForMaskedLM.from_pretrained(path) +text1 = "こんにちは" + +model.eval() +inputs = tokenizer(text1) +inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()} +output = model(**inputs) +print(output.shape) +``` + +## Weights source +https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking diff --git a/community/iverxin/bert-base-japanese-whole-word-masking/files.json b/community/iverxin/bert-base-japanese-whole-word-masking/files.json new file mode 100644 index 000000000000..04e128b55d8d --- /dev/null +++ b/community/iverxin/bert-base-japanese-whole-word-masking/files.json @@ -0,0 +1,6 @@ +{ + "model_config_file":"https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-whole-word-masking/model_config.json", + "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-whole-word-masking/model_state.pdparams", + "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-whole-word-masking/tokenizer_config.pdparams", + "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese-whole-word-masking/vocab.txt" +} diff --git a/community/iverxin/bert-base-japanese/README.md b/community/iverxin/bert-base-japanese/README.md new file mode 100644 index 000000000000..2d9089e95805 --- /dev/null +++ b/community/iverxin/bert-base-japanese/README.md @@ -0,0 +1,59 @@ +# BERT base Japanese (IPA dictionary) + +This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language. + +This version of the model processes input texts with word-level tokenization based on the IPA dictionary, followed by the WordPiece subword tokenization. + +The codes for the pretraining are available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/tree/v1.0). + +## Model architecture + +The model architecture is the same as the original BERT base model; 12 layers, 768 dimensions of hidden states, and 12 attention heads. + +## Training Data + +The model is trained on Japanese Wikipedia as of September 1, 2019. + +To generate the training corpus, [WikiExtractor](https://github.com/attardi/wikiextractor) is used to extract plain texts from a dump file of Wikipedia articles. + +The text files used for the training are 2.6GB in size, consisting of approximately 17M sentences. + +## Tokenization + +The texts are first tokenized by [MeCab](https://taku910.github.io/mecab/) morphological parser with the IPA dictionary and then split into subwords by the WordPiece algorithm. + +The vocabulary size is 32000. + +## Training + +The model is trained with the same configuration as the original BERT; 512 tokens per instance, 256 instances per batch, and 1M training steps. 
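+
+## Tokenization example
+
+The two stages described in the Tokenization section can also be run separately. The sketch below uses the `MecabTokenizer` class added in this patch for the word-level stage and the pretrained tokenizer for the full word-plus-WordPiece pipeline. It assumes `fugashi` and `ipadic` are installed, that the community weights path resolves, and that the inherited `tokenize` method is available; the exact splits depend on the dictionary and the released vocabulary.
+
+```python
+from paddlenlp.transformers import BertJapaneseTokenizer, MecabTokenizer
+
+# Word-level stage only: MeCab with the IPA dictionary.
+word_tokenizer = MecabTokenizer(mecab_dic="ipadic")
+print(word_tokenizer.tokenize("日本語の文章を解析する"))
+
+# Full pipeline: MeCab words, then WordPiece subwords from the 32000-entry vocab.
+tokenizer = BertJapaneseTokenizer.from_pretrained("iverxin/bert-base-japanese/")
+print(tokenizer.tokenize("日本語の文章を解析する"))
+```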
+ +## Licenses + +The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/). + +## Acknowledgments + +For training models, we used Cloud TPUs provided by [TensorFlow Research Cloud](https://www.tensorflow.org/tfrc/) program. + + +## Usage +```python +import paddle +from paddlenlp.transformers import BertJapaneseTokenizer, BertForMaskedLM + +path = "iverxin/bert-base-japanese/" +tokenizer = BertJapaneseTokenizer.from_pretrained(path) +model = BertForMaskedLM.from_pretrained(path) +text1 = "こんにちは" + +model.eval() +inputs = tokenizer(text1) +inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()} +output = model(**inputs) +print(output.shape) +``` + + +## Weights source +https://huggingface.co/cl-tohoku/bert-base-japanese diff --git a/community/iverxin/bert-base-japanese/files.json b/community/iverxin/bert-base-japanese/files.json new file mode 100644 index 000000000000..053deb857bbf --- /dev/null +++ b/community/iverxin/bert-base-japanese/files.json @@ -0,0 +1,6 @@ +{ + "model_config_file":"https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese/model_config.json", + "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese/model_state.pdparams", + "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese/tokenizer_config.pdparams", + "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/iverxin/bert-base-japanese/vocab.txt" +} diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 01d874b2b247..4819c7602821 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -18,6 +18,7 @@ from .bert.modeling import * from .bert.tokenizer import * +from .bert_japanese.tokenizer import * from .ernie.modeling import * from .ernie.tokenizer import * from .gpt.modeling import * diff --git a/paddlenlp/transformers/bert/tokenizer.py b/paddlenlp/transformers/bert/tokenizer.py index dac665d0c343..1086df17358b 100644 --- a/paddlenlp/transformers/bert/tokenizer.py +++ b/paddlenlp/transformers/bert/tokenizer.py @@ -14,16 +14,17 @@ # limitations under the License. import copy -import io -import json import os -import six import unicodedata from .. import PretrainedTokenizer from ..tokenizer_utils import convert_to_unicode, whitespace_tokenize, _is_whitespace, _is_control, _is_punctuation -__all__ = ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'] +__all__ = [ + 'BasicTokenizer', + 'BertTokenizer', + 'WordpieceTokenizer', +] class BasicTokenizer(object): @@ -290,9 +291,9 @@ class BertTokenizer(PretrainedTokenizer): .. code-block:: from paddlenlp.transformers import BertTokenizer - berttokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - inputs = berttokenizer.tokenize('He was a puppeteer') + inputs = tokenizer('He was a puppeteer') print(inputs) ''' @@ -554,7 +555,7 @@ def create_token_type_ids_from_sequences(self, 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | - If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). 
Args: token_ids_0 (List[int]): diff --git a/paddlenlp/transformers/bert_japanese/__init__.py b/paddlenlp/transformers/bert_japanese/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/paddlenlp/transformers/bert_japanese/convert_bert_japanese_params.py b/paddlenlp/transformers/bert_japanese/convert_bert_japanese_params.py new file mode 100644 index 000000000000..75a63df67a49 --- /dev/null +++ b/paddlenlp/transformers/bert_japanese/convert_bert_japanese_params.py @@ -0,0 +1,69 @@ +import paddle +import torch +import numpy as np +from paddle.utils.download import get_path_from_url + +model_names = [ + "bert-base-japanese", "bert-base-japanese-whole-word-masking", + "bert-base-japanese-char", "bert-base-japanese-char-whole-word-masking" +] + +for model_name in model_names: + torch_model_url = "https://huggingface.co/cl-tohoku/%s/resolve/main/pytorch_model.bin" % model_name + torch_model_path = get_path_from_url(torch_model_url, '../bert') + torch_state_dict = torch.load(torch_model_path) + + paddle_model_path = "%s.pdparams" % model_name + paddle_state_dict = {} + + # State_dict's keys mapping: from torch to paddle + keys_dict = { + # about embeddings + "embeddings.LayerNorm.gamma": "embeddings.layer_norm.weight", + "embeddings.LayerNorm.beta": "embeddings.layer_norm.bias", + + # about encoder layer + 'encoder.layer': 'encoder.layers', + 'attention.self.query': 'self_attn.q_proj', + 'attention.self.key': 'self_attn.k_proj', + 'attention.self.value': 'self_attn.v_proj', + 'attention.output.dense': 'self_attn.out_proj', + 'attention.output.LayerNorm.gamma': 'norm1.weight', + 'attention.output.LayerNorm.beta': 'norm1.bias', + 'intermediate.dense': 'linear1', + 'output.dense': 'linear2', + 'output.LayerNorm.gamma': 'norm2.weight', + 'output.LayerNorm.beta': 'norm2.bias', + + # about cls predictions + 'cls.predictions.transform.dense': 'cls.predictions.transform', + 'cls.predictions.decoder.weight': 'cls.predictions.decoder_weight', + 'cls.predictions.transform.LayerNorm.gamma': + 'cls.predictions.layer_norm.weight', + 'cls.predictions.transform.LayerNorm.beta': + 'cls.predictions.layer_norm.bias', + 'cls.predictions.bias': 'cls.predictions.decoder_bias' + } + + for torch_key in torch_state_dict: + paddle_key = torch_key + for k in keys_dict: + if k in paddle_key: + paddle_key = paddle_key.replace(k, keys_dict[k]) + + if ('linear' in paddle_key) or ('proj' in paddle_key) or ( + 'vocab' in paddle_key and 'weight' in paddle_key) or ( + "dense.weight" in paddle_key) or ( + 'transform.weight' in paddle_key) or ( + 'seq_relationship.weight' in paddle_key): + paddle_state_dict[paddle_key] = paddle.to_tensor(torch_state_dict[ + torch_key].cpu().numpy().transpose()) + else: + paddle_state_dict[paddle_key] = paddle.to_tensor(torch_state_dict[ + torch_key].cpu().numpy()) + + print("torch: ", torch_key, "\t", torch_state_dict[torch_key].shape) + print("paddle: ", paddle_key, "\t", paddle_state_dict[paddle_key].shape, + "\n") + + paddle.save(paddle_state_dict, paddle_model_path) diff --git a/paddlenlp/transformers/bert_japanese/tokenizer.py b/paddlenlp/transformers/bert_japanese/tokenizer.py new file mode 100644 index 000000000000..db5161c1bd71 --- /dev/null +++ b/paddlenlp/transformers/bert_japanese/tokenizer.py @@ -0,0 +1,327 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import unicodedata +import collections + +from .. import BertTokenizer, BasicTokenizer, WordpieceTokenizer + +__all__ = ['BertJapaneseTokenizer', 'MecabTokenizer', 'CharacterTokenizer'] + + +class BertJapaneseTokenizer(BertTokenizer): + """ + Construct a BERT tokenizer for Japanese text, based on a MecabTokenizer. + + Args: + vocab_file (str): + The vocabulary file path (ends with '.txt') required to instantiate + a `WordpieceTokenizer`. + do_lower_case (bool, optional): + Whether or not to lowercase the input when tokenizing. + Defaults to`False`. + do_word_tokenize (bool, optional): + Whether to do word tokenization. Defaults to`True`. + do_subword_tokenize (bool, optional): + Whether to do subword tokenization. Defaults to`True`. + word_tokenizer_type (str, optional): + Type of word tokenizer. Defaults to`basic`. + subword_tokenizer_type (str, optional): + Type of subword tokenizer. Defaults to`wordpiece`. + never_split (bool, optional): + Kept for backward compatibility purposes. Defaults to`None`. + mecab_kwargs (str, optional): + Dictionary passed to the `MecabTokenizer` constructor. + unk_token (str): + A special token representing the *unknown (out-of-vocabulary)* token. + An unknown token is set to be `unk_token` inorder to be converted to an ID. + Defaults to "[UNK]". + sep_token (str): + A special token separating two different sentences in the same input. + Defaults to "[SEP]". + pad_token (str): + A special token used to make arrays of tokens the same size for batching purposes. + Defaults to "[PAD]". + cls_token (str): + A special token used for sequence classification. It is the last token + of the sequence when built with special tokens. Defaults to "[CLS]". + mask_token (str): + A special token representing a masked token. This is the token used + in the masked language modeling task which the model tries to predict the original unmasked ones. + Defaults to "[MASK]". + + + Examples: + .. code-block:: + + from paddlenlp.transformers import BertJapaneseTokenizer + tokenizer = BertJapaneseTokenizer.from_pretrained('iverxin/bert-base-japanese/') + + inputs = tokenizer('こんにちは') + print(inputs) + + ''' + {'input_ids': [2, 10350, 25746, 28450, 3], 'token_type_ids': [0, 0, 0, 0, 0]} + ''' + + """ + + def __init__(self, + vocab_file, + do_lower_case=False, + do_word_tokenize=True, + do_subword_tokenize=True, + word_tokenizer_type="basic", + subword_tokenizer_type="wordpiece", + never_split=None, + mecab_kwargs=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]"): + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. 
To load the " + "vocabulary from a pretrained model please use " + "`tokenizer = BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + .format(vocab_file)) + + self.vocab = self.load_vocabulary(vocab_file, unk_token=unk_token) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.idx_to_token.items()]) + + self.do_word_tokenize = do_word_tokenize + self.word_tokenizer_type = word_tokenizer_type + self.lower_case = do_lower_case + self.never_split = never_split + self.mecab_kwargs = copy.deepcopy(mecab_kwargs) + if do_word_tokenize: + if word_tokenizer_type == "basic": + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, ) + elif word_tokenizer_type == "mecab": + self.basic_tokenizer = MecabTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + **(mecab_kwargs or {})) + else: + raise ValueError( + f"Invalid word_tokenizer_type '{word_tokenizer_type}' is specified." + ) + + self.do_subword_tokenize = do_subword_tokenize + self.subword_tokenizer_type = subword_tokenizer_type + if do_subword_tokenize: + if subword_tokenizer_type == "wordpiece": + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=unk_token) + elif subword_tokenizer_type == "character": + self.wordpiece_tokenizer = CharacterTokenizer( + vocab=self.vocab, unk_token=unk_token) + else: + raise ValueError( + f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified." + ) + + @property + def do_lower_case(self): + return self.lower_case + + def __getstate__(self): + state = dict(self.__dict__) + if self.word_tokenizer_type == "mecab": + del state["basic_tokenizer"] + return state + + def __setstate__(self, state): + self.__dict__ = state + if self.word_tokenizer_type == "mecab": + self.basic_tokenizer = MecabTokenizer( + do_lower_case=self.do_lower_case, + never_split=self.never_split, + **(self.mecab_kwargs or {})) + + def _tokenize(self, text): + if self.do_word_tokenize: + tokens = self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens) + else: + tokens = [text] + + if self.do_subword_tokenize: + split_tokens = [ + sub_token + for token in tokens + for sub_token in self.wordpiece_tokenizer.tokenize(token) + ] + else: + split_tokens = tokens + + return split_tokens + + +class MecabTokenizer: + """Runs basic tokenization with MeCab morphological parser.""" + + def __init__( + self, + do_lower_case=False, + never_split=None, + normalize_text=True, + mecab_dic="ipadic", + mecab_option=None, ): + """ + Constructs a MecabTokenizer. + + Args: + do_lower_case (bool): + Whether to lowercase the input. Defaults to`True`. + never_split: (list): + Kept for backward compatibility purposes. Defaults to`None`. + normalize_text (bool): + Whether to apply unicode normalization to text before tokenization. Defaults to`True`. + mecab_dic (string): + Name of dictionary to be used for MeCab initialization. If you are using a system-installed dictionary, + set this option to `None` and modify `mecab_option`. Defaults to`ipadic`. + mecab_option (string): + String passed to MeCab constructor. Defaults to`None`. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split if never_split is not None else [] + self.normalize_text = normalize_text + + try: + import fugashi + except ModuleNotFoundError as error: + raise error.__class__( + "You need to install fugashi to use MecabTokenizer. 
" + "See https://pypi.org/project/fugashi/ for installation.") + + mecab_option = mecab_option or "" + + if mecab_dic is not None: + if mecab_dic == "ipadic": + try: + import ipadic + except ModuleNotFoundError as error: + raise error.__class__( + "The ipadic dictionary is not installed. " + "See https://github.com/polm/ipadic-py for installation." + ) + + dic_dir = ipadic.DICDIR + + elif mecab_dic == "unidic_lite": + try: + import unidic_lite + except ModuleNotFoundError as error: + raise error.__class__( + "The unidic_lite dictionary is not installed. " + "See https://github.com/polm/unidic-lite for installation." + ) + + dic_dir = unidic_lite.DICDIR + + elif mecab_dic == "unidic": + try: + import unidic + except ModuleNotFoundError as error: + raise error.__class__( + "The unidic dictionary is not installed. " + "See https://github.com/polm/unidic-py for installation." + ) + + dic_dir = unidic.DICDIR + if not os.path.isdir(dic_dir): + raise RuntimeError( + "The unidic dictionary itself is not found." + "See https://github.com/polm/unidic-py for installation." + ) + else: + raise ValueError("Invalid mecab_dic is specified.") + + mecabrc = os.path.join(dic_dir, "mecabrc") + mecab_option = f'-d "{dic_dir}" -r "{mecabrc}" ' + mecab_option + + self.mecab = fugashi.GenericTagger(mecab_option) + + def tokenize(self, text, never_split=None, **kwargs): + """Tokenizes a piece of text.""" + if self.normalize_text: + text = unicodedata.normalize("NFKC", text) + + never_split = self.never_split + (never_split + if never_split is not None else []) + tokens = [] + + for word in self.mecab(text): + token = word.surface + + if self.do_lower_case and token not in never_split: + token = token.lower() + + tokens.append(token) + + return tokens + + +class CharacterTokenizer: + """Runs Character tokenization.""" + + def __init__(self, vocab, unk_token, normalize_text=True): + """ + Constructs a CharacterTokenizer. + + Args: + vocab: + Vocabulary object. + unk_token (str): + A special symbol for out-of-vocabulary token. + normalize_text (boolean): + Whether to apply unicode normalization to text before tokenization. Defaults to True. + """ + self.vocab = vocab + self.unk_token = unk_token + self.normalize_text = normalize_text + + def tokenize(self, text): + """ + Tokenizes a piece of text into characters. + + For example, `input = "apple""` wil return as output `["a", "p", "p", "l", "e"]`. + + Args: + text: A single token or whitespace separated tokens. + This should have already been passed through `BasicTokenizer`. + + Returns: + A list of characters. + """ + if self.normalize_text: + text = unicodedata.normalize("NFKC", text) + + output_tokens = [] + for char in text: + if char not in self.vocab: + output_tokens.append(self.unk_token) + continue + + output_tokens.append(char) + + return output_tokens diff --git a/tests/transformers/bert_japanese/__init__.py b/tests/transformers/bert_japanese/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/transformers/bert_japanese/test_tokenizer.py b/tests/transformers/bert_japanese/test_tokenizer.py new file mode 100644 index 000000000000..8db04a288b44 --- /dev/null +++ b/tests/transformers/bert_japanese/test_tokenizer.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import os +import unittest +from paddlenlp.transformers import BertTokenizer, BertJapaneseTokenizer +from paddlenlp.data import Vocab + +from common_test import CpuCommonTest +from util import slow, assert_raises +import unittest + + +class TestBertJapaneseTokenizerFromPretrained(CpuCommonTest): + @slow + def test_from_pretrained(self): + tokenizer = BertJapaneseTokenizer.from_pretrained("bert-base-japanese") + text1 = "こんにちは" + text2 = "櫓を飛ばす" + # test batch_encode + expected_input_ids = [ + 2, 10350, 25746, 28450, 3, 20301, 11, 787, 12222, 3, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0 + ] + expected_token_type_ids = [ + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ] + expected_attention_mask = [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ] + expected_special_tokens_mask = [ + 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ] + results = tokenizer( + [text1], [text2], + 20, + stride=1, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_special_tokens_mask=True) + + self.check_output_equal(results[0]['input_ids'], expected_input_ids) + self.check_output_equal(results[0]['token_type_ids'], + expected_token_type_ids) + self.check_output_equal(results[0]['attention_mask'], + expected_attention_mask) + self.check_output_equal(results[0]['special_tokens_mask'], + expected_special_tokens_mask) + # test encode + results = tokenizer(text1, text2, 20, stride=1, pad_to_max_seq_len=True) + self.check_output_equal(results['input_ids'], expected_input_ids) + self.check_output_equal(results['token_type_ids'], + expected_token_type_ids) + + @slow + def test_from_pretrained_pad_left(self): + tokenizer = BertJapaneseTokenizer.from_pretrained("bert-base-japanese") + tokenizer.padding_side = "left" + text1 = "こんにちは" + text2 = "櫓を飛ばす" + # test batch_encode + expected_input_ids = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 10350, 25746, 28450, 3, 20301, 11, + 787, 12222, 3 + ] + expected_token_type_ids = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 + ] + expected_attention_mask = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ] + expected_special_tokens_mask = [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1 + ] + results = tokenizer( + [text1], [text2], + 20, + stride=1, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_special_tokens_mask=True) + + self.check_output_equal(results[0]['input_ids'], expected_input_ids) + self.check_output_equal(results[0]['token_type_ids'], + expected_token_type_ids) + self.check_output_equal(results[0]['attention_mask'], + expected_attention_mask) + self.check_output_equal(results[0]['special_tokens_mask'], + expected_special_tokens_mask) + # test encode + results = tokenizer(text1, text2, 20, stride=1, pad_to_max_seq_len=True) + self.check_output_equal(results['input_ids'], expected_input_ids) + self.check_output_equal(results['token_type_ids'], + expected_token_type_ids) + + +if __name__ == "__main__": + unittest.main()