
[Dict] Add extra special tokens #2828

Merged
19 commits merged on Jul 16, 2020
5 changes: 5 additions & 0 deletions parlai/agents/special_tok/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
118 changes: 118 additions & 0 deletions parlai/agents/special_tok/agents.py
@@ -0,0 +1,118 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.utils.logging import logging

import functools
import torch
from typing import Dict, Any, List

SPECIAL_TOKS = 'PARTY,PARROT'


def recursive_getattr(obj, attr, *args):
"""
Recursive call to getattr for nested attributes.
"""

def _getattr(obj, attr):
return getattr(obj, attr, *args)

return functools.reduce(_getattr, [obj] + attr.split('.'))


def add_common_args(argparser):
"""
Add cmdline args.
"""
argparser.add_argument(
'--special-tok-lst',
type=str,
default=SPECIAL_TOKS,
help='Comma separated list of special tokens',
)

return argparser


class SpecialTokenMixin:
"""
Mixin adding special tokens to the dictionary.
"""

def _get_special_tokens(self) -> List[str]:
"""
Return list of special tokens.

Made easily overridable for special cases.
"""
return self.opt['special_tok_lst'].split(',')

def build_dictionary(self):
"""
Return the constructed dictionary, which will be set to self.dict.

If you need to add additional tokens to the dictionary, this is likely the right
place to do it.
"""
d = self.dictionary_class()(self.opt)
d.add_extra_special_tokens(self._get_special_tokens())

return d

def _resize_token_embeddings(self):
"""
Must define this for your agent.

Must make a call to resize the token embeddings and load the model state dict.
"""
        raise RuntimeError('Must define this function for your specific agent.')

def load_state_dict(self, state_dict):
"""
Load the state dict into model.

        Override from TorchAgent to resize the token embeddings.
"""
try:
self.model.load_state_dict(state_dict)
return False
except RuntimeError as msg:
msg_ = str(msg)
if 'size mismatch' in msg_ and 'embedding' in msg_:
self._resize_token_embeddings(state_dict, msg_)
return True # resized
else:
raise (msg)

def init_optim(self, params, optim_states=None, saved_optim_type=None):
"""
Override: do not load optimizer state if resized.
"""
if hasattr(self, 'resized') and self.resized:
optim_states = None
logging.warn('Not loading optimizer due to resize in token embeddings')

super().init_optim(params, optim_states, saved_optim_type)

def load(self, path: str) -> Dict[str, Any]:
"""
Return opt and model states.

        Override this method to catch a resize.
"""
import parlai.utils.pickle

states = torch.load(
path, map_location=lambda cpu, _: cpu, pickle_module=parlai.utils.pickle
)
self.resized = False
if 'model' in states:
self.resized = self.load_state_dict(states['model'])
if 'optimizer' in states and hasattr(self, 'optimizer'):
self.optimizer.load_state_dict(states['optimizer'])

return states
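
For illustration only (not part of the diff), here is a minimal sketch of how recursive_getattr above walks a dotted attribute path; this is how the transformer agent in the next file fetches nested embedding weights such as 'encoder.embeddings.weight'. The toy classes are hypothetical stand-ins, not real ParlAI models.

import functools

def recursive_getattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))

class _Embeddings:
    weight = 'embedding weight tensor'

class _Encoder:
    embeddings = _Embeddings()

class _ToyModel:
    encoder = _Encoder()

# 'encoder.embeddings.weight' is resolved one attribute at a time.
assert recursive_getattr(_ToyModel(), 'encoder.embeddings.weight') == 'embedding weight tensor'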
48 changes: 48 additions & 0 deletions parlai/agents/special_tok/transformer_generator.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.agents.transformer.transformer import TransformerGeneratorAgent as Base
from .agents import SpecialTokenMixin, add_common_args, recursive_getattr

from parlai.utils.logging import logging


class TransformerGeneratorAgent(SpecialTokenMixin, Base):
"""
TransformerGeneratorAgent with special tokens added.
"""

@classmethod
def add_cmdline_args(cls, argparser):
argparser = add_common_args(argparser)
argparser = super(cls, TransformerGeneratorAgent).add_cmdline_args(argparser)
return argparser

def _resize_token_embeddings(self, state_dict, msg=None):
# map extra special tokens carefully
new_size = self.model.embeddings.weight.size()[0]
orig_size = state_dict['embeddings.weight'].size()[0]
logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
if new_size <= orig_size:
# new size should be greater than original size,
# as we are adding special tokens
raise RuntimeError(msg)

for emb_weights in [
'embeddings.weight',
'encoder.embeddings.weight',
'decoder.embeddings.weight',
]:
# get new_embs
old_embs = state_dict[emb_weights]
new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
# copy over old weights
new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
# reset in state dict
state_dict[emb_weights] = new_embs

# now try loading again
self.model.load_state_dict(state_dict)
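
A minimal sketch of the resize performed above, using plain torch rather than ParlAI internals: a larger embedding matrix is created for the grown vocabulary, the pretrained rows are copied back in, and only the rows for the newly added special tokens keep their fresh initialization. Sizes and names here are illustrative assumptions, not taken from the PR.

import torch
import torch.nn as nn

orig_size, new_size, dim = 8, 8 + 2, 16   # e.g. 2 extra special tokens
old_embs = nn.Embedding(orig_size, dim)   # stands in for the checkpoint weights
new_embs = nn.Embedding(new_size, dim)    # stands in for the resized model weights

# copy over old weights; rows orig_size..new_size-1 stay randomly initialized
new_embs.weight.data[:orig_size, :] = old_embs.weight.data[:orig_size, :]

assert torch.equal(new_embs.weight.data[:orig_size], old_embs.weight.data)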
38 changes: 36 additions & 2 deletions parlai/core/dict.py
@@ -20,6 +20,7 @@
import json
import re
import parlai.utils.logging as logging
from typing import List

RETOK = re.compile(r'\w+|[^\w\s]|\n', re.UNICODE)

@@ -324,6 +325,35 @@ def __init__(self, opt: Opt, shared=None):
if opt.get('dict_file'):
self.save_path = opt['dict_file']

def add_extra_special_tokens(self, extra_special_tokens: List[str]):
"""
        Add additional special tokens to the dictionary.

Should only be called after initialization of the existing dictionary.
"""
self.extra_special_tokens = extra_special_tokens

if self.extra_special_tokens and not self.supports_extra_special_tokens():
raise RuntimeError(
f'{self.tokenizer} does not currently support adding additional special tokens'
)

for tok in self.extra_special_tokens:
self.add_token(tok)

for i, tok in enumerate(self.extra_special_tokens):
self.freq[tok] = 1000000000 + 4 + i

if self.tokenizer == 'bytelevelbpe':
self.bpe.add_special_tokens(self, self.extra_special_tokens)

def supports_extra_special_tokens(self):
"""
Indicates whether the dictionary supports additional special tokens.
"""
# TODO: add to others
return self.tokenizer in ['bytelevelbpe', 'split', 'space']

def is_prebuilt(self):
"""
Indicates whether the dictionary is fixed, and does not require building.
@@ -708,9 +738,13 @@ def vec2txt(self, vector, delimiter=' '):
text = self.bpe.decode(tokens, vector, delimiter)
elif self.tokenizer == 'bytelevelbpe':
# We add special tokens in the beginning of ParlAI dict but in the
# end of Hugging Face dict,there is an offset of 4 between them.
# end of Hugging Face dict, there is an offset of #(extra tokens) between them.
extra_tokens = 4 # length of special tokens
vector = [
idx + len(self.tok2ind) - 4 if idx < 4 else idx - 4 for idx in vector
self.bpe.special_tok_map[idx]
if idx in self.bpe.special_tok_map
else idx - extra_tokens
for idx in vector
]
tokens = [self[int(idx)] for idx in vector]
text = self.bpe.decode(tokens, vector, delimiter)
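
To make the vec2txt remapping above concrete, here is a hedged, self-contained illustration: ParlAI places its special tokens at the front of the dictionary while the Hugging Face vocabulary keeps them at the end, so special ids are translated through special_tok_map and all other ids are shifted down by the number of special tokens. The ids below are made up for the example.

extra_tokens = 4  # number of base ParlAI special tokens (null/start/end/unk)
special_tok_map = {0: 100, 1: 101, 2: 102, 3: 103}  # ParlAI id -> HF id (illustrative)

parlai_vector = [1, 7, 8, 2]
hf_vector = [
    special_tok_map[idx] if idx in special_tok_map else idx - extra_tokens
    for idx in parlai_vector
]
print(hf_vector)  # [101, 3, 4, 102]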
28 changes: 25 additions & 3 deletions parlai/utils/bpe.py
@@ -122,6 +122,13 @@ def add_cmdline_args(argparser):
hidden=True,
help='add prefix space before encoding',
)
parser.add_argument(
'--hf-skip-special-tokens',
hidden=True,
type='bool',
default=True,
help='do not decode special tokens with bytelevelbpe',
)
return parser

@final
@@ -689,7 +696,9 @@ class HuggingFaceBpeHelper(BPEHelper):
def __init__(self, opt: Opt, shared: TShared = None):
super().__init__(opt, shared)
# Default true for HF
        self.special_tok_map = {}  # map from ParlAI token id to Hugging Face token id
self.add_prefix_space = opt.get('bpe_add_prefix_space', True)
self.skip_special_tokens = opt.get('hf_skip_special_tokens', True)
if self.add_prefix_space is None:
self.add_prefix_space = True
if opt.get('dict_loaded'):
@@ -769,9 +778,21 @@ def helper_decode(
:return text:
decoded text
"""
text = self.tokenizer.decode(token_ids)
text = self.tokenizer.decode(
token_ids, skip_special_tokens=self.skip_special_tokens
)

return text

def add_special_tokens(self, dict_agent, special_tokens: List[str]):
"""
Add special tokens to the tokenizer and dict_agent.
"""

logging.info(f'adding the following special tokens: {special_tokens}')
self.tokenizer.add_special_tokens(special_tokens) # add to HF

for tok in special_tokens:
parlai_key = dict_agent[tok]
hf_key = self.tokenizer.token_to_id(tok)
self.special_tok_map[parlai_key] = hf_key

def sync_with_dict(self, dict_agent):
"""
Sync the dictionary agent with Hugging Face tokenizer's BPE dict.
@@ -784,8 +805,9 @@ def sync_with_dict(self, dict_agent):
dict_agent.end_token,
dict_agent.unk_token,
]
self.tokenizer.add_special_tokens(special_tokens)
for i in range(self.tokenizer.get_vocab_size() - 4):
self.add_special_tokens(dict_agent, special_tokens)

for i in range(self.tokenizer.get_vocab_size() - len(special_tokens)):
token = self.tokenizer.id_to_token(i)
dict_agent.add_token(token)
# We don't have access to the hugging face word frequency table,
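
The add_special_tokens helper above records a ParlAI-id to Hugging-Face-id correspondence in special_tok_map. A rough sketch of that bookkeeping with toy stand-ins (the real code uses the Hugging Face tokenizers object and the ParlAI DictionaryAgent; everything below is assumed for illustration) might look like this:

class _ToyHFTokenizer:
    """Stand-in for the Hugging Face tokenizer used by HuggingFaceBpeHelper."""

    def __init__(self, base_vocab):
        self._vocab = {tok: i for i, tok in enumerate(base_vocab)}

    def add_special_tokens(self, toks):
        # special tokens are appended to the end of the vocabulary
        for tok in toks:
            self._vocab.setdefault(tok, len(self._vocab))

    def token_to_id(self, tok):
        return self._vocab[tok]


class _ToyDictAgent(dict):
    """Stand-in for DictionaryAgent: maps token -> ParlAI id."""

    def add_token(self, tok):
        self.setdefault(tok, len(self))


tokenizer = _ToyHFTokenizer(['hel', 'lo', 'wor', 'ld'])  # pretend BPE vocab
dict_agent = _ToyDictAgent()
special_tok_map = {}  # ParlAI id -> HF id

special_tokens = ['__null__', '__start__', '__end__', '__unk__', 'PARTY', 'PARROT']
tokenizer.add_special_tokens(special_tokens)
for tok in special_tokens:
    dict_agent.add_token(tok)
    special_tok_map[dict_agent[tok]] = tokenizer.token_to_id(tok)

print(special_tok_map)  # {0: 4, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9}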
24 changes: 24 additions & 0 deletions tests/test_dict.py
@@ -379,6 +379,30 @@ def test_save_reload(self):
)
assert da2.txt2vec("hello") == da.txt2vec("hello")

def test_add_special_tokens(self):
"""
Add a list of special tokens to the dictionary.
"""
special_toks_lst = ['MY', 'NAME', 'IS', 'EMILY']
# create Dictionary Agent
parser = ParlaiParser()
parser.set_params(
dict_tokenizer='bytelevelbpe',
bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
hf_skip_special_tokens=False,
)
opt = parser.parse_args([], print_args=False)
agent = DictionaryAgent(opt)
agent.add_extra_special_tokens(special_toks_lst)

self.assertEqual(agent.extra_special_tokens, special_toks_lst)
phrases = ['Hi what is up EMILY', 'What IS your NAME', 'That is MY dog']
for phrase in phrases:
vec = agent.txt2vec(phrase)
text = agent.vec2txt(vec)
self.assertEqual(phrase, text)


class TestBuildDict(unittest.TestCase):
def _run_test(self, opt):