Skip to content

Commit

Permalink
Revert "fix t5 special tokens (huggingface#8435)"
Browse files Browse the repository at this point in the history
This reverts commit 2221e9a.
  • Loading branch information
fabiocapsouza authored Nov 15, 2020
1 parent 54dc65e commit 666f6fb
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 26 deletions.
13 changes: 2 additions & 11 deletions src/transformers/tokenization_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,8 @@ def _convert_id_to_token(self, index):

def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (string) in a single string.

    Special tokens (e.g. ``</s>``, ``<unk>``) are not part of the
    sentencepiece model's vocabulary, so they must not be passed to
    ``sp_model.decode_pieces``. Ordinary sub-tokens are accumulated and
    decoded with sentencepiece; each special token is stitched in
    verbatim, followed by a space.

    Note: the dead duplicate ``out_string = self.sp_model.decode_pieces(tokens)``
    / ``return out_string`` tail that followed the first ``return``
    (unreachable code left over from a diff) has been removed.
    """
    current_sub_tokens = []
    out_string = ""
    for token in tokens:
        # make sure that special tokens are not decoded using sentencepiece model
        if token in self.all_special_tokens:
            out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    # flush any trailing ordinary sub-tokens
    out_string += self.sp_model.decode_pieces(current_sub_tokens)
    return out_string.strip()

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
Expand Down
15 changes: 0 additions & 15 deletions tests/test_tokenization_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,18 +222,3 @@ def test_eos_in_input(self):

self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)

def test_fast_and_slow_same_result(self):
    """Check that the fast and slow T5 tokenizers agree on both encoding
    and decoding of text containing special tokens (<pad>, <unk>, </s>)."""
    src_text = "<pad> Today is <unk> nice day </s>"
    tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
    tgt_text = "<pad> Today is<unk> nice day</s>"

    fast_ids = self.t5_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids
    slow_ids = self.t5_base_tokenizer(src_text, add_special_tokens=False).input_ids
    self.assertEqual(tgt_ids, fast_ids)
    self.assertEqual(tgt_ids, slow_ids)

    fast_text = self.t5_base_tokenizer_fast.decode(fast_ids)
    # Fix: decode the slow tokenizer's OWN ids (was ``fast_ids``), so the
    # slow decode path is exercised independently of the fast encoder.
    slow_text = self.t5_base_tokenizer.decode(slow_ids)
    self.assertEqual(tgt_text, fast_text)
    self.assertEqual(tgt_text, slow_text)

0 comments on commit 666f6fb

Please sign in to comment.