[MBartTokenizer] remove dep on xlm-roberta tokenizer #15201

Merged: merged 1 commit on Jan 18, 2022
180 changes: 143 additions & 37 deletions src/transformers/models/mbart/tokenization_mbart.py
@@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager
from typing import List, Optional
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

from ...tokenization_utils import BatchEncoding
import sentencepiece as spm

from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
from ...utils import logging
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer


logger = logging.get_logger(__name__)

SPIECE_UNDERLINE = "▁"

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

@@ -38,41 +42,17 @@
"facebook/mbart-large-cc25": 1024,
}

FAIRSEQ_LANGUAGE_CODES = [
"ar_AR",
"cs_CZ",
"de_DE",
"en_XX",
"es_XX",
"et_EE",
"fi_FI",
"fr_XX",
"gu_IN",
"hi_IN",
"it_IT",
"ja_XX",
"kk_KZ",
"ko_KR",
"lt_LT",
"lv_LV",
"my_MM",
"ne_NP",
"nl_XX",
"ro_RO",
"ru_RU",
"si_LK",
"tr_TR",
"vi_VN",
"zh_CN",
]


class MBartTokenizer(XLMRobertaTokenizer):
# fmt: off
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN"]
# fmt: on


class MBartTokenizer(PreTrainedTokenizer):
"""
Construct an MBART tokenizer.

[`MBartTokenizer`] is a subclass of [`XLMRobertaTokenizer`]. Refer to superclass [`XLMRobertaTokenizer`] for usage
examples and documentation concerning the initialization parameters and other methods.
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
[SentencePiece](https://github.com/google/sentencepiece).

The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code> <tokens> <eos>` for target language documents.
@@ -94,22 +74,66 @@ class MBartTokenizer(XLMRobertaTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
model_input_names = ["input_ids", "attention_mask"]

prefix_tokens: List[int] = []
suffix_tokens: List[int] = []

def __init__(
self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
self,
vocab_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
tokenizer_file=None,
src_lang=None,
tgt_lang=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
additional_special_tokens=None,
**kwargs
):

# The mask token behaves like a normal word, i.e. it includes the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

super().__init__(
*args,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
tokenizer_file=tokenizer_file,
src_lang=src_lang,
tgt_lang=tgt_lang,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=self.sp_model_kwargs,
**kwargs,
)

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file

# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'

# Mimic fairseq token-to-id alignment for the first 4 tokens
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self.fairseq_offset = 1

self.sp_model_size = len(self.sp_model)
self.lang_code_to_id = {
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
@@ -132,6 +156,22 @@ def __init__(
self.tgt_lang = tgt_lang
self.set_src_lang_special_tokens(self._src_lang)

def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state

def __setstate__(self, d):
self.__dict__ = d

# for backward compatibility
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
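
# --- Illustrative sketch, not part of this diff. __getstate__/__setstate__ exist because
# SentencePieceProcessor itself cannot be pickled; the serialized model proto is stored in
# the state and reloaded on unpickling. The checkpoint name below is just a public example.
# import pickle
# from transformers import MBartTokenizer
#
# tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
# restored = pickle.loads(pickle.dumps(tok))
# assert restored.sp_model.serialized_model_proto() == tok.sp_model.serialized_model_proto()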

@property
def vocab_size(self):
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
@@ -202,6 +242,31 @@ def build_inputs_with_special_tokens(
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
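
# --- Illustrative sketch, not part of this diff. For source text the layout produced above is
# `<tokens> </s> <src lang code>`: prefix_tokens is empty and suffix_tokens is [eos, lang code].
# The checkpoint and sentence are assumptions chosen for illustration.
# from transformers import MBartTokenizer
#
# tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
# ids = tok("UN Chief Says There Is No Military Solution in Syria").input_ids
# assert ids[-2] == tok.eos_token_id
# assert ids[-1] == tok.lang_code_to_id["en_XX"]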

def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
make use of token type ids, therefore a list of zeros is returned.

Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of zeros.

"""

sep = [self.sep_token_id]
cls = [self.cls_token_id]

if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
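
# --- Illustrative sketch, not part of this diff: mBART has no segment embeddings, so the
# token type ids returned here are always all zero. Checkpoint name is an assumption.
# from transformers import MBartTokenizer
#
# tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
# enc = tok("UN Chief Says There Is No Military Solution in Syria", return_token_type_ids=True)
# assert set(enc["token_type_ids"]) == {0}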

def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
@@ -214,6 +279,47 @@ def _build_translation_inputs(
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs
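
# --- Illustrative sketch, not part of this diff: the `forced_bos_token_id` produced above is
# what `generate` uses to force the first decoded token to be the target language code.
# The model class usage, checkpoint, and sentence are assumptions; requires PyTorch.
# from transformers import MBartForConditionalGeneration, MBartTokenizer
#
# tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
# model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
# batch = tok("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")
# generated = model.generate(**batch, forced_bos_token_id=tok.lang_code_to_id["ro_RO"])
# print(tok.batch_decode(generated, skip_special_tokens=True))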

def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab

def _tokenize(self, text: str) -> List[str]:
return self.sp_model.encode(text, out_type=str)

def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)

# Need to return unknown token if the SP model returned 0
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
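
# --- Illustrative sketch, not part of this diff: the four fairseq control tokens get fixed
# ids, while ordinary SentencePiece pieces are shifted by `fairseq_offset` (1), exactly as the
# alignment table in __init__ describes. Checkpoint name is an assumption.
# from transformers import MBartTokenizer
#
# tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
# assert tok.convert_tokens_to_ids("<unk>") == 3
# assert tok.convert_tokens_to_ids(",") == tok.sp_model.PieceToId(",") + tok.fairseq_offset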

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)

if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)

return (out_vocab_file,)
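
# --- Illustrative sketch, not part of this diff: save_vocabulary copies the underlying
# SentencePiece model file into the target directory, so the tokenizer can be rebuilt from
# that file alone. The temporary directory and checkpoint name are assumptions.
# import os
# import tempfile
# from transformers import MBartTokenizer
#
# tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
# with tempfile.TemporaryDirectory() as tmp_dir:
#     (vocab_path,) = tok.save_vocabulary(tmp_dir)
#     assert os.path.isfile(vocab_path)
#     rebuilt = MBartTokenizer(vocab_path)
#     assert rebuilt.vocab_size == tok.vocab_size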

def prepare_seq2seq_batch(
self,
src_texts: List[str],