fix pickle of barthez & camembert
PhilipMay committed Apr 26, 2021
1 parent 9653e8e · commit 05e9d86
Showing 3 changed files with 19 additions and 9 deletions.
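
Context from the diff below: both tokenizers drop the unpicklable SentencePiece processor when pickled and rebuild it in __setstate__. Before this commit, sp_model_kwargs was only a local variable in __init__, so the rebuilt processor silently lost any custom kwargs; the new hasattr guard also keeps pickles created before this change loadable, since those objects carry no sp_model_kwargs attribute.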
src/transformers/models/barthez/tokenization_barthez.py (9 additions, 4 deletions)

@@ -130,7 +130,7 @@ def __init__(
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
 
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
 
         super().__init__(
             bos_token=bos_token,
@@ -140,12 +140,12 @@ def __init__(
             cls_token=cls_token,
             pad_token=pad_token,
             mask_token=mask_token,
-            sp_model_kwargs=sp_model_kwargs,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
 
         self.vocab_file = vocab_file
-        self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
 
         self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
@@ -261,7 +261,12 @@ def __getstate__(self):
 
     def __setstate__(self, d):
         self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor()
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
     def convert_tokens_to_string(self, tokens):
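
The hunk header references __getstate__, which the diff does not show. A minimal self-contained sketch of the full pickle-hook pattern, assuming __getstate__ matches the usual SentencePiece-tokenizer implementation in the repository; the class name here is illustrative, not real API:

import sentencepiece as spm

class SentencePieceTokenizerSketch:
    # Illustrative stand-in for the pickle hooks used by BarthezTokenizer.

    def __init__(self, vocab_file, sp_model_kwargs=None):
        self.vocab_file = vocab_file
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None  # the C++-backed processor cannot be pickled
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        # pickles created before this commit have no sp_model_kwargs attribute
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)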
src/transformers/models/camembert/tokenization_camembert.py (9 additions, 4 deletions)

@@ -127,7 +127,7 @@ def __init__(
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
 
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
 
         super().__init__(
             bos_token=bos_token,
@@ -138,10 +138,10 @@ def __init__(
             pad_token=pad_token,
             mask_token=mask_token,
             additional_special_tokens=additional_special_tokens,
-            sp_model_kwargs=sp_model_kwargs,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
-        self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
         # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
@@ -261,7 +261,12 @@ def __getstate__(self):
 
     def __setstate__(self, d):
        self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor()
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
     def convert_tokens_to_string(self, tokens):
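
A usage sketch of the failure mode being fixed; the model path is a hypothetical placeholder, and enable_sampling is just one example of a SentencePiece kwarg worth preserving:

import pickle

from transformers import CamembertTokenizer

# "sentencepiece.bpe.model" is a placeholder; any CamemBERT SentencePiece
# model file would do here.
tok = CamembertTokenizer("sentencepiece.bpe.model", sp_model_kwargs={"enable_sampling": True})

restored = pickle.loads(pickle.dumps(tok))

# After this commit the kwargs survive the round trip; before it, __setstate__
# rebuilt the processor with no kwargs at all.
assert restored.sp_model_kwargs == {"enable_sampling": True}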
tests/test_tokenization_barthez.py (1 addition, 1 deletion)

@@ -26,7 +26,7 @@
 
 @require_tokenizers
 @require_sentencepiece
-@slow
+@slow  # see https://github.com/huggingface/transformers/issues/11457
 class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = BarthezTokenizer
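
The @slow marker keeps this suite out of the default test run while the linked issue is open. A minimal sketch of how such a gate typically works; the real decorator lives in transformers.testing_utils and reads the RUN_SLOW environment variable, so everything else here is illustrative:

import os
import unittest

def slow(test_case):
    # Skip the decorated test unless RUN_SLOW=1 is set in the environment.
    return unittest.skipUnless(os.getenv("RUN_SLOW", "0") == "1", "test is slow")(test_case)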
