diff --git a/parlai/agents/transformer/transformer.py b/parlai/agents/transformer/transformer.py
index d6162da3dd4..041c9c0748b 100644
--- a/parlai/agents/transformer/transformer.py
+++ b/parlai/agents/transformer/transformer.py
@@ -10,6 +10,8 @@
 from parlai.core.torch_classifier_agent import TorchClassifierAgent
 from parlai.core.torch_ranker_agent import TorchRankerAgent
 from parlai.core.torch_generator_agent import TorchGeneratorAgent
+from parlai.utils.misc import recursive_getattr
+from parlai.utils.logging import logging
 
 from .modules import (
     TransformerMemNetModel,
@@ -326,6 +328,34 @@ def build_model(self, states=None):
         )
         return model
 
+    def _resize_token_embeddings(self, state_dict, msg=None):
+        """
+        Resize the token embeddings when we are adding extra special tokens.
+        """
+        # map extra special tokens carefully
+        new_size = self.model.embeddings.weight.size()[0]
+        orig_size = state_dict['embeddings.weight'].size()[0]
+        logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
+        if new_size <= orig_size:
+            # new size should be greater than original size,
+            # as we are adding special tokens
+            raise RuntimeError(msg)
+
+        for emb_weights in [
+            'embeddings.weight',
+            'encoder.embeddings.weight',
+            'decoder.embeddings.weight',
+        ]:
+            # get the new (resized) embeddings
+            old_embs = state_dict[emb_weights]
+            new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
+            # copy over old weights
+            new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
+            # reset in state dict
+            state_dict[emb_weights] = new_embs
+
+        return state_dict
+
 
 class TransformerClassifierAgent(TorchClassifierAgent):
     """
diff --git a/parlai/core/dict.py b/parlai/core/dict.py
index d68d4841d7c..b13260b6402 100644
--- a/parlai/core/dict.py
+++ b/parlai/core/dict.py
@@ -20,6 +20,7 @@
 import json
 import re
 import parlai.utils.logging as logging
+from typing import List
 
 RETOK = re.compile(r'\w+|[^\w\s]|\n', re.UNICODE)
 
@@ -324,6 +325,38 @@ def __init__(self, opt: Opt, shared=None):
         if opt.get('dict_file'):
             self.save_path = opt['dict_file']
 
+    def add_additional_special_tokens(self, additional_special_tokens: List[str]):
+        """
+        Add additional special tokens to the dictionary.
+
+        Should only be called after initialization of the existing dictionary.
+        """
+        self.additional_special_tokens = additional_special_tokens
+
+        if (
+            self.additional_special_tokens
+            and not self.supports_additional_special_tokens()
+        ):
+            raise RuntimeError(
+                f'{self.tokenizer} does not currently support adding additional special tokens'
+            )
+
+        for tok in self.additional_special_tokens:
+            self.add_token(tok)
+
+        for i, tok in enumerate(self.additional_special_tokens):
+            self.freq[tok] = 1000000000 + 4 + i
+
+        if self.tokenizer == 'bytelevelbpe':
+            self.bpe.add_special_tokens(self, self.additional_special_tokens)
+
+    def supports_additional_special_tokens(self):
+        """
+        Indicates whether the dictionary supports additional special tokens.
+        """
+        # TODO: add to others
+        return self.tokenizer in ['bytelevelbpe', 'split', 'space']
+
     def is_prebuilt(self):
         """
         Indicates whether the dictionary is fixed, and does not require building.
@@ -708,9 +741,13 @@ def vec2txt(self, vector, delimiter=' '):
             text = self.bpe.decode(tokens, vector, delimiter)
         elif self.tokenizer == 'bytelevelbpe':
             # We add special tokens in the beginning of ParlAI dict but in the
-            # end of Hugging Face dict,there is an offset of 4 between them.
+            # end of Hugging Face dict, so there is an offset of #(extra tokens) between them.
+            extra_tokens = 4  # length of special tokens
             vector = [
-                idx + len(self.tok2ind) - 4 if idx < 4 else idx - 4 for idx in vector
+                self.bpe.special_tok_map[idx]
+                if idx in self.bpe.special_tok_map
+                else idx - extra_tokens
+                for idx in vector
             ]
             tokens = [self[int(idx)] for idx in vector]
             text = self.bpe.decode(tokens, vector, delimiter)
diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py
index 7004aff6a5d..29d25f5d2aa 100644
--- a/parlai/core/torch_agent.py
+++ b/parlai/core/torch_agent.py
@@ -653,6 +653,12 @@ def add_cmdline_args(cls, argparser):
             choices=[None, 'end'],
             help='Add special token to the end of history encoding.',
         )
+        agent.add_argument(
+            '--special-tok-lst',
+            type=str,
+            default=None,
+            help='Comma-separated list of special tokens',
+        )
         # GPU arguments
         # these gpu options are all mutually exclusive, and should error if the
         # user tries to present multiple of them
@@ -801,11 +807,28 @@ def build_dictionary(self):
         place to do it.
         """
         d = self.dictionary_class()(self.opt)
+        self.special_toks = self._get_special_tokens()
+        if self.special_toks:
+            d.add_additional_special_tokens(self.special_toks)
+
         if self.opt.get('person_tokens'):
             d[self.P1_TOKEN] = 999_999_999
             d[self.P2_TOKEN] = 999_999_998
         return d
 
+    def _resize_token_embeddings(self, state_dict, msg=None):
+        """
+        Must define this for your agent if you wish to add additional special tokens.
+
+        Must make a call to resize the token embeddings and load the model state dict
+        with the resized token embeddings.
+        """
+        raise NotImplementedError(
+            'If you are intending to add special tokens to an already pretrained model, '
+            'you must write the function `_resize_token_embeddings` for your specific '
+            'agent.'
+        )
+
     def _get_init_model(self, opt: Opt, shared):
         """
         Get model file to initialize with.
@@ -845,6 +868,16 @@ def _get_init_model(self, opt: Opt, shared):
 
         return init_model, is_finetune
 
+    def _get_special_tokens(self) -> List[str]:
+        """
+        Return list of special tokens.
+
+        Made easily overridable for special cases.
+        """
+        if self.opt.get('special_tok_lst') is not None:
+            return self.opt['special_tok_lst'].split(',')
+        return []
+
     @abstractmethod
     def build_model(self):
         """
@@ -878,6 +911,10 @@ def init_optim(self, params, optim_states=None, saved_optim_type=None):
             type of optimizer being loaded, if changed will skip loading optimizer
             states
         """
+        if hasattr(self, 'resized_embeddings') and self.resized_embeddings:
+            optim_states = None
+            logging.warn('Not loading optimizer due to resize in token embeddings')
+
         opt = self.opt
 
         # set up optimizer args
@@ -1810,14 +1847,19 @@ def load_state_dict(self, state_dict):
         except RuntimeError as msg:
             msg_ = str(msg)
             if 'size mismatch' in msg_ and 'embedding' in msg_:
-                raise RuntimeError(
-                    f'{msg_}\n'
-                    '-----------------\n'
-                    'Could not load the model due to a size mismatch in the '
-                    'embeddings. A common reason for this is trying to load '
-                    'a model trained with fp16 but loaded without fp16. Try '
-                    'adding --fp16 true or --force-fp16-tokens true.'
-                )
+                if hasattr(self, 'special_toks') and len(self.special_toks) > 0:
+                    state_dict = self._resize_token_embeddings(state_dict, msg_)
+                    self.model.load_state_dict(state_dict)
+                    self.resized_embeddings = True  # make note that we resized here
+                else:
+                    raise RuntimeError(
+                        f'{msg_}\n'
+                        '-----------------\n'
+                        'Could not load the model due to a size mismatch in the '
+                        'embeddings. A common reason for this is trying to load '
+                        'a model trained with fp16 but loaded without fp16. Try '
+                        'adding --fp16 true or --force-fp16-tokens true.'
+                    )
             else:
                 raise
 
diff --git a/parlai/utils/bpe.py b/parlai/utils/bpe.py
index b758dfdd71b..aea39426716 100644
--- a/parlai/utils/bpe.py
+++ b/parlai/utils/bpe.py
@@ -122,6 +122,13 @@ def add_cmdline_args(argparser):
             hidden=True,
             help='add prefix space before encoding',
         )
+        parser.add_argument(
+            '--hf-skip-special-tokens',
+            hidden=True,
+            type='bool',
+            default=True,
+            help='do not decode special tokens with bytelevelbpe',
+        )
         return parser
 
     @final
@@ -689,7 +696,9 @@ class HuggingFaceBpeHelper(BPEHelper):
     def __init__(self, opt: Opt, shared: TShared = None):
         super().__init__(opt, shared)
         # Default true for HF
+        self.special_tok_map = {}  # map from ParlAI token id to HF token id
         self.add_prefix_space = opt.get('bpe_add_prefix_space', True)
+        self.skip_special_tokens = opt.get('hf_skip_special_tokens', True)
         if self.add_prefix_space is None:
             self.add_prefix_space = True
         if opt.get('dict_loaded'):
@@ -769,9 +778,24 @@ def helper_decode(
         :return text:
             decoded text
         """
-        text = self.tokenizer.decode(token_ids)
+        text = self.tokenizer.decode(
+            token_ids, skip_special_tokens=self.skip_special_tokens
+        )
+
         return text
 
+    def add_special_tokens(self, dict_agent, special_tokens: List[str]):
+        """
+        Add special tokens to the tokenizer and dict_agent.
+        """
+        logging.info(f'adding the following special tokens: {special_tokens}')
+        self.tokenizer.add_special_tokens(special_tokens)  # add to HF
+
+        for tok in special_tokens:
+            parlai_key = dict_agent[tok]
+            hf_key = self.tokenizer.token_to_id(tok)
+            self.special_tok_map[parlai_key] = hf_key
+
     def sync_with_dict(self, dict_agent):
         """
         Sync the dictionary agent with Hugging Face tokenizer's BPE dict.
@@ -784,8 +808,9 @@
             dict_agent.end_token,
             dict_agent.unk_token,
         ]
-        self.tokenizer.add_special_tokens(special_tokens)
-        for i in range(self.tokenizer.get_vocab_size() - 4):
+        self.add_special_tokens(dict_agent, special_tokens)
+
+        for i in range(self.tokenizer.get_vocab_size() - len(special_tokens)):
             token = self.tokenizer.id_to_token(i)
             dict_agent.add_token(token)
         # We don't have access to the hugging face word frequency table,
diff --git a/parlai/utils/misc.py b/parlai/utils/misc.py
index 31d9662f1e1..2ac7db87f5e 100644
--- a/parlai/utils/misc.py
+++ b/parlai/utils/misc.py
@@ -10,6 +10,7 @@
 from collections import deque, OrderedDict
 from typing import Union, Optional, Set, Any, Dict, List, Tuple
 from datetime import timedelta
+import functools
 import math
 import time
 import re
@@ -752,3 +753,14 @@ def error_once(msg: str) -> None:
     if msg not in _seen_logs:
         _seen_logs.add(msg)
         logging.error(msg)
+
+
+def recursive_getattr(obj, attr, *args):
+    """
+    Recursive call to getattr for nested attributes.
+    """
+
+    def _getattr(obj, attr):
+        return getattr(obj, attr, *args)
+
+    return functools.reduce(_getattr, [obj] + attr.split('.'))
diff --git a/tests/test_dict.py b/tests/test_dict.py
index 9a4040e9e7b..9f9f72463ef 100644
--- a/tests/test_dict.py
+++ b/tests/test_dict.py
@@ -379,6 +379,31 @@ def test_save_reload(self):
         )
         assert da2.txt2vec("hello") == da.txt2vec("hello")
 
+    def test_add_special_tokens(self):
+        """
+        Add a list of special tokens to the dictionary.
+        """
+        special_toks_lst = ['MY', 'NAME', 'IS', 'EMILY']
+        # create Dictionary Agent
+        parser = ParlaiParser()
+        parser.set_params(
+            dict_tokenizer='bytelevelbpe',
+            bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
+            bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
+            hf_skip_special_tokens=False,
+        )
+        opt = parser.parse_args([], print_args=False)
+
+        agent = DictionaryAgent(opt)
+        agent.add_additional_special_tokens(special_toks_lst)
+
+        self.assertEqual(agent.additional_special_tokens, special_toks_lst)
+        phrases = ['Hi what is up EMILY', 'What IS your NAME', 'That is MY dog']
+        for phrase in phrases:
+            vec = agent.txt2vec(phrase)
+            text = agent.vec2txt(vec)
+            self.assertEqual(phrase, text)
+
 
 class TestBuildDict(unittest.TestCase):
     def _run_test(self, opt):
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 473fd36e964..fab953957f8 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -13,6 +13,8 @@
 import parlai.utils.testing as testing_utils
 from parlai.core.agents import create_agent
 from parlai.core.opt import Opt
+from tests.test_dict import DEFAULT_BYTELEVEL_BPE_VOCAB, DEFAULT_BYTELEVEL_BPE_MERGE
+from parlai.core.params import ParlaiParser
 
 
 class TestTransformerRanker(unittest.TestCase):
@@ -674,6 +676,50 @@ def test_temperature(self):
             )
         )
 
+    def test_resize_embeddings(self):
+        # train original model
+        with testing_utils.tempdir() as tmpdir:
+            model_file = os.path.join(tmpdir, 'model_file')
+            _, _ = testing_utils.train_model(
+                dict(
+                    model='transformer/generator',
+                    task='integration_tests:short_fixed',
+                    n_layers=1,
+                    n_encoder_layers=2,
+                    n_decoder_layers=4,
+                    num_epochs=1,
+                    dict_tokenizer='bytelevelbpe',
+                    bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
+                    bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
+                    bpe_add_prefix_space=False,
+                    model_file=model_file,
+                    save_after_valid=True,
+                )
+            )
+
+            # now create agent with special tokens
+            parser = ParlaiParser()
+            parser.set_params(
+                model='transformer/generator',
+                task='integration_tests:short_fixed',
+                n_layers=1,
+                n_encoder_layers=2,
+                n_decoder_layers=4,
+                dict_tokenizer='bytelevelbpe',
+                bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
+                bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
+                bpe_add_prefix_space=False,
+                model_file=model_file,
+                save_after_valid=True,
+                special_tok_lst='PARTY,PARROT',
+            )
+            opt = parser.parse_args([], print_args=False)
+            agent = create_agent(opt)
+            # assert that the embeddings were resized
+            assert agent.resized_embeddings
+            # assert model has special tokens
+            self.assertEqual(agent.special_toks, ['PARTY', 'PARROT'])
+
 
 class TestClassifier(unittest.TestCase):
     """