add sp_model_kwargs to unpickle of xlm roberta tok
add test for pickle

simplify test

fix test code style

add missing pickle import

fix test

fix test

fix test
PhilipMay committed Apr 25, 2021
1 parent 52166f6 commit acf636c
Showing 2 changed files with 22 additions and 4 deletions.
src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py (13 changes: 9 additions & 4 deletions)

@@ -135,7 +135,7 @@ def __init__(
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

         super().__init__(
             bos_token=bos_token,
@@ -145,11 +145,11 @@
             cls_token=cls_token,
             pad_token=pad_token,
             mask_token=mask_token,
-            sp_model_kwargs=sp_model_kwargs,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )

-        self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file

@@ -175,7 +175,12 @@ def __getstate__(self):

     def __setstate__(self, d):
         self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor()
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

     def build_inputs_with_special_tokens(
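Reviewer note: the core of this change is how SentencePiece interacts with pickling. SentencePieceProcessor is a C-extension object that cannot be pickled, so __getstate__ drops it from the instance state and __setstate__ rebuilds it from the vocab file; before this commit the rebuild passed no kwargs, so any sampling options were silently lost on unpickling. A minimal standalone sketch of that pattern, assuming sentencepiece is installed (the class name PicklableSpTokenizer is illustrative, not the actual transformers code):

    import sentencepiece as spm

    class PicklableSpTokenizer:
        """Minimal sketch of the __getstate__/__setstate__ pattern in this diff."""

        def __init__(self, vocab_file, sp_model_kwargs=None):
            self.vocab_file = vocab_file
            self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
            self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
            self.sp_model.Load(vocab_file)

        def __getstate__(self):
            # The C-extension processor is not picklable; drop it from the state.
            state = self.__dict__.copy()
            state["sp_model"] = None
            return state

        def __setstate__(self, d):
            self.__dict__ = d
            # Pickles created before sp_model_kwargs existed lack the attribute.
            if not hasattr(self, "sp_model_kwargs"):
                self.sp_model_kwargs = {}
            # Rebuild the processor with the same options as the original.
            self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
            self.sp_model.Load(self.vocab_file)

Because only sp_model is dropped, keeping sp_model_kwargs on self (rather than as a local variable in __init__) is what makes the options survive the round trip.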
tests/test_tokenization_xlm_roberta.py (13 changes: 13 additions & 0 deletions)

@@ -16,6 +16,7 @@

 import itertools
 import os
+import pickle
 import unittest

 from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
@@ -142,6 +143,18 @@ def test_subword_regularization_tokenizer(self):

         self.assertFalse(all_equal)

+    def test_pickle_subword_regularization_tokenizer(self):
+        """Google pickle __getstate__ __setstate__ if you are struggling with this."""
+        # Subword regularization is only available for the slow tokenizer.
+        sp_model_kwargs = {"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
+        tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs=sp_model_kwargs)
+        tokenizer_bin = pickle.dumps(tokenizer)
+        tokenizer_new = pickle.loads(tokenizer_bin)
+
+        self.assertIsNotNone(tokenizer_new.sp_model_kwargs)
+        self.assertTrue(isinstance(tokenizer_new.sp_model_kwargs, dict))
+        self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
+
     @cached_property
     def big_tokenizer(self):
         return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
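Usage note: the kwargs the test pins down are exactly what enable subword regularization, so losing them across a pickle round trip would silently disable sampling after unpickling. A hedged sketch of the behavior this protects, assuming a local SentencePiece model file ("sample.model" is a placeholder path, not from this commit):

    import pickle

    from transformers import XLMRobertaTokenizer

    # "sample.model" is a placeholder; any SentencePiece vocab file works here.
    tokenizer = XLMRobertaTokenizer(
        "sample.model",
        sp_model_kwargs={"enable_sampling": True, "alpha": 0.1, "nbest_size": -1},
    )

    # With sampling enabled, repeated calls may tokenize the same text differently.
    print(tokenizer.tokenize("This is a test"))
    print(tokenizer.tokenize("This is a test"))

    # After this commit, the sampling options survive a pickle round trip.
    restored = pickle.loads(pickle.dumps(tokenizer))
    assert restored.sp_model_kwargs == tokenizer.sp_model_kwargs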
