
Preserve list type of additional_special_tokens in `special_tokens_map` (#12759)

* preserve type of `additional_special_tokens` in `special_tokens_map`

* format

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
SaulLu and sgugger authored Jul 16, 2021
1 parent fbf1397 commit 6e87010
Showing 2 changed files with 9 additions and 1 deletion.
src/transformers/tokenization_utils_base.py (5 additions, 1 deletion)
@@ -1192,7 +1192,11 @@ def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
             attr_value = getattr(self, "_" + attr)
             if attr_value:
-                set_attr[attr] = str(attr_value)
+                set_attr[attr] = (
+                    type(attr_value)(str(attr_value_sub) for attr_value_sub in attr_value)
+                    if isinstance(attr_value, (list, tuple))
+                    else str(attr_value)
+                )
         return set_attr
 
     @property
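For context, a minimal standalone sketch of what this hunk changes (the DummyAddedToken class and sample values are illustrative stand-ins, not part of the diff): before the commit, str() was applied to the whole attribute value, so a list of special tokens collapsed into its single repr string; afterwards each element is converted individually and the container type is preserved.

# Illustrative sketch only; DummyAddedToken stands in for a non-str token
# object (e.g. tokenizers.AddedToken) and is not part of the commit.
class DummyAddedToken:
    def __init__(self, content):
        self.content = content

    def __str__(self):
        return self.content

attr_value = [DummyAddedToken("<tok1>"), "<tok2>"]

# Old behavior: str() over the whole list returns one repr string.
old_result = str(attr_value)
print(type(old_result))  # <class 'str'> -- the list type is lost

# New behavior: convert element-wise, keeping the list (or tuple) type.
new_result = (
    type(attr_value)(str(attr_value_sub) for attr_value_sub in attr_value)
    if isinstance(attr_value, (list, tuple))
    else str(attr_value)
)
print(new_result)  # ['<tok1>', '<tok2>'] -- still a list of strings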
tests/test_tokenization_common.py (4 additions, 0 deletions)
@@ -2462,6 +2462,10 @@ def test_add_tokens(self):
                 self.assertEqual(
                     tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
                 )
+                self.assertIn("<testtoken3>", tokenizer_r.special_tokens_map["additional_special_tokens"])
+                self.assertIsInstance(tokenizer_r.special_tokens_map["additional_special_tokens"], list)
+                self.assertGreaterEqual(len(tokenizer_r.special_tokens_map["additional_special_tokens"]), 2)
+
                 self.assertEqual(len(tokenizer_r), vocab_size + 8)
 
     def test_offsets_mapping(self):
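In user-facing terms, the new assertions guarantee behavior along the lines of the sketch below (the checkpoint name is an arbitrary example, not taken from the test):

# Usage sketch; any fast tokenizer checkpoint would do here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]})

extra = tokenizer.special_tokens_map["additional_special_tokens"]
assert isinstance(extra, list)   # before the fix this was one repr string
assert "<testtoken3>" in extra   # element-wise membership now works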
