From a2adbe146a30acf24462342fe682250133d5e6de Mon Sep 17 00:00:00 2001
From: Jing
Date: Tue, 15 Dec 2020 20:33:37 -0500
Subject: [PATCH] fix dialogpt dual usage of END_IDX (#3256)

* fix dialogpt dual usage of endoftext
* add null_idx = -1
* dialog bs test
* Set null_idx in model and decoder, add to dialogpt test
* small formats
* accidental delete old test
* reviewer comment
---
 parlai/agents/hugging_face/dialogpt.py |  22 ++++++
 parlai/agents/hugging_face/gpt2.py     |   4 +-
 tests/nightly/gpu/test_dialogpt.py     | 105 ++++++++++++++++++++++++-
 tests/nightly/gpu/test_gpt2.py         |   2 +
 4 files changed, 130 insertions(+), 3 deletions(-)

diff --git a/parlai/agents/hugging_face/dialogpt.py b/parlai/agents/hugging_face/dialogpt.py
index 4023b0af8d7..5f95ec59e3c 100644
--- a/parlai/agents/hugging_face/dialogpt.py
+++ b/parlai/agents/hugging_face/dialogpt.py
@@ -26,6 +26,24 @@ class DialoGPTDecoder(GPT2Decoder):
     This decoder is initialized with the pretrained model from Hugging Face.
     """
 
+    def __init__(self, opt, dict):
+        super().__init__(opt, dict)
+        self.NULL_IDX, self.START_IDX, self.END_IDX = self._get_special_tokens(
+            opt, dict
+        )
+
+    @staticmethod
+    def _get_special_tokens(opt, dict):
+        null_idx = dict.null_idx
+        if (
+            opt.get('batchsize', 1) == 1
+            and not opt['add_special_tokens']
+            and null_idx == dict.end_idx
+        ):
+            # work around the dual usage of end_idx, which would otherwise mask the end token during the forward pass
+            null_idx = -1
+        return null_idx, dict.start_idx, dict.end_idx
+
     def _init_from_pretrained(self, opt):
         # load model
         model_sz = opt['gpt2_size']
@@ -38,6 +56,10 @@ class DialoGPTModel(HFGPT2Model):
     Hugging Face DialoGPT Model.
     """
 
+    def _get_special_tokens(self, opt, dict):
+        # keep start_idx, end_idx, and null_idx consistent between DialoGPTModel and DialoGPTDecoder
+        return DialoGPTDecoder._get_special_tokens(opt, dict)
+
     def _get_decoder(self, opt, dict):
         return DialoGPTDecoder(opt, dict)
 
diff --git a/parlai/agents/hugging_face/gpt2.py b/parlai/agents/hugging_face/gpt2.py
index 0ef175fac16..1297cde9015 100644
--- a/parlai/agents/hugging_face/gpt2.py
+++ b/parlai/agents/hugging_face/gpt2.py
@@ -85,7 +85,8 @@ def forward(self, input, encoder_state, incr_state=None):
                 and int(input[0][0]) == self.START_IDX
             ):
                 # generating: ignore the start token
-                model_input = encoder_state
+                # clone here: without a copy, the in-place clamp_ below would reset the padding idx (-1) in encoder_state to 0
+                model_input = encoder_state.clone()
             else:
                 # forced decoding: concatenate the context
                 # with the labels
@@ -108,6 +109,7 @@ def forward(self, input, encoder_state, incr_state=None):
             model_input = input[:, -1:]
             attention_mask = torch.cat([encoder_state, input], dim=-1) != self.NULL_IDX
 
+        model_input = model_input.clamp_(min=0)
         transformer_outputs = self.transformer(
             model_input,
             past=incr_state,
diff --git a/tests/nightly/gpu/test_dialogpt.py b/tests/nightly/gpu/test_dialogpt.py
index 7f72ffed228..12f5034a1cd 100644
--- a/tests/nightly/gpu/test_dialogpt.py
+++ b/tests/nightly/gpu/test_dialogpt.py
@@ -6,18 +6,119 @@
 
 import unittest
 import parlai.utils.testing as testing_utils
+from parlai.core.agents import create_agent
 
 
 @testing_utils.skipUnlessGPU
 class TestDialogptModel(unittest.TestCase):
     """
     Test of DialoGPT model.
-
-    Checks that DialoGPT gets a certain performance on the integration test task.
     """
 
+    def _test_batchsize(self, batchsize, add_start_token):
+        utterances = [
+            'How is your day so far?',
+            'I hope you you have a good day.',
+            "Nice to meet you. My name is John. ",
+            "I've got a feeling we're not in Kansas anymore.",
+        ]
+        opt = {
+            'model': 'hugging_face/dialogpt',
+            'gpt2_size': 'small',
+            'text_truncate': 100,
+            'label_truncate': 20,
+            'beam_min_length': 1,
+            'inference': 'beam',
+            'beam_size': 1,
+            'add_special_tokens': True,
+            'batchsize': batchsize,
+            'add_start_token': add_start_token,
+        }
+        dialogpt = create_agent(opt)
+
+        results_single = []
+        agents = [dialogpt.clone() for _ in utterances]
+        for u, a in zip(utterances, agents):
+            a.observe({'text': u, 'episode_done': True})
+            generation = a.act()['text']
+            results_single.append(generation)
+
+        results_batched = []
+        for idx in range(len(utterances) // batchsize):
+            agents = [dialogpt.clone() for _ in range(batchsize)]
+            batch = utterances[idx * batchsize : (idx + 1) * batchsize]
+            obs = []
+            for i, a in enumerate(agents):
+                obs.append(a.observe({'text': batch[i], 'episode_done': True}))
+            generations = [x['text'] for x in dialogpt.batch_act(obs)]
+            results_batched += generations
+
+        assert results_single == results_batched
+
+    def test_batchsize(self):
+        """
+        Ensures dialogpt provides the same generation results regardless of batchsize.
+        """
+        # Test that a RuntimeError is thrown with add_special_tokens = False and batchsize > 1
+        with self.assertRaises(RuntimeError):
+            create_agent(
+                {
+                    'model': 'hugging_face/dialogpt',
+                    'add_special_tokens': False,
+                    'batchsize': 2,
+                }
+            )
+
+        for batchsize in [1, 2, 4]:
+            for add_start_token in [True, False]:
+                with self.subTest(
+                    f'test_batchsize with bs={batchsize} and add_start_token={add_start_token}'
+                ):
+                    self._test_batchsize(batchsize, add_start_token)
+
+    def test_start_token(self):
+        """
+        Test that a RuntimeError is thrown when add_start_token = True but add_special_tokens = False.
+        """
+        with self.assertRaises(RuntimeError):
+            create_agent(
+                {
+                    'model': 'hugging_face/dialogpt',
+                    'add_special_tokens': False,
+                    'add_start_token': True,
+                }
+            )
+
+    def test_nospecialtok(self):
+        """
+        Test generation consistency for off-the-shelf dialogpt models.
+        """
+        test_cases = [
+            ("What a nice weather!", "I'm in the UK and it's raining here."),
+            ("Nice to meet you!", "Hello! I'm from the future!"),
+        ]
+        opt = {
+            'model': 'hugging_face/dialogpt',
+            'gpt2_size': 'small',
+            'text_truncate': 100,
+            'label_truncate': 20,
+            'beam_min_length': 1,
+            'inference': 'beam',
+            'beam_size': 1,
+            'add_special_tokens': False,
+            'batchsize': 1,
+        }
+        dialogpt = create_agent(opt)
+        for text, label in test_cases:
+            dialogpt.observe({'text': text, 'episode_done': True})
+            response = dialogpt.act()
+            assert response['text'] == label
+
     @testing_utils.retry(ntries=3, log_retry=True)
     def test_dialogpt(self):
+        """
+        Checks that DialoGPT gets a certain performance on the integration test task.
+        """
         valid, test = testing_utils.train_model(
             dict(
                 task='integration_tests:overfit',
diff --git a/tests/nightly/gpu/test_gpt2.py b/tests/nightly/gpu/test_gpt2.py
index 1e98c061f39..1c5b8739eb7 100644
--- a/tests/nightly/gpu/test_gpt2.py
+++ b/tests/nightly/gpu/test_gpt2.py
@@ -36,6 +36,8 @@ def test_custom_special_tokens(self):
 
 
 class TestGpt2(unittest.TestCase):
+    # If your changes affect DialoGPT, did you add a corresponding test for it too?
+
     def _test_batchsize(self, batchsize, add_start_token):
         utterances = [
             'Just keep swimming -',