fix dialogpt dual usage of END_IDX (#3256)
* fix dialogpt dual usage of endoftext

* add null_idx = -1

* dialog bs test

* Set null_idx in model and decoder, add to dialogpt test

* small formats

* accidental delete old test

* reviewer comment
Jing authored Dec 16, 2020
1 parent 21f4dd3 commit a2adbe1
Showing 4 changed files with 130 additions and 3 deletions.
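For context: off-the-shelf DialoGPT has no dedicated pad token, so <|endoftext|> is used both to end sequences and to pad them, and the decoder builds its attention mask by comparing token ids against NULL_IDX (see gpt2.py below). A minimal sketch of the collision this commit fixes — 50256 is GPT-2's real <|endoftext|> id; the other ids are illustrative:

    import torch

    END_IDX = 50256     # <|endoftext|> in the GPT-2 vocabulary
    NULL_IDX = END_IDX  # dual usage: the same id both pads and ends sequences

    # a sequence that legitimately ends with <|endoftext|>
    tokens = torch.tensor([[15496, 995, END_IDX]])

    # building the attention mask against NULL_IDX hides the real end token
    buggy_mask = tokens != NULL_IDX  # tensor([[True, True, False]])

    # with the workaround below, -1 becomes the padding id in this
    # configuration, so the end token stays visible
    fixed_mask = tokens != -1        # tensor([[True, True, True]])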
22 changes: 22 additions & 0 deletions parlai/agents/hugging_face/dialogpt.py
@@ -26,6 +26,24 @@ class DialoGPTDecoder(GPT2Decoder):
 
     This decoder is initialized with the pretrained model from Hugging Face.
     """
 
+    def __init__(self, opt, dict):
+        super().__init__(opt, dict)
+        self.NULL_IDX, self.START_IDX, self.END_IDX = self._get_special_tokens(
+            opt, dict
+        )
+
+    @staticmethod
+    def _get_special_tokens(opt, dict):
+        null_idx = dict.null_idx
+        if (
+            opt.get('batchsize', 1) == 1
+            and not opt['add_special_tokens']
+            and null_idx == dict.end_idx
+        ):
+            # work around the dual usage of end_idx, which would otherwise mask the end token during the forward pass
+            null_idx = -1
+        return null_idx, dict.start_idx, dict.end_idx
+
     def _init_from_pretrained(self, opt):
         # load model
         model_sz = opt['gpt2_size']
@@ -38,6 +56,10 @@ class DialoGPTModel(HFGPT2Model):
 
     Hugging Face DialoGPT Model.
     """
 
+    def _get_special_tokens(self, opt, dict):
+        # keep start_idx, end_idx, and null_idx consistent between DialoGPTModel and DialoGPTDecoder
+        return DialoGPTDecoder._get_special_tokens(opt, dict)
+
     def _get_decoder(self, opt, dict):
         return DialoGPTDecoder(opt, dict)
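A hedged sketch of what the workaround yields in the affected configuration (batchsize 1, add_special_tokens False); FakeDict is a hypothetical stand-in for ParlAI's dictionary, exposing only the three fields read above:

    from parlai.agents.hugging_face.dialogpt import DialoGPTDecoder

    class FakeDict:
        # hypothetical: without added special tokens, the pad/start/end ids
        # all fall back to <|endoftext|> (50256 in the GPT-2 vocabulary)
        null_idx = 50256
        start_idx = 50256
        end_idx = 50256

    opt = {'batchsize': 1, 'add_special_tokens': False}
    null_idx, start_idx, end_idx = DialoGPTDecoder._get_special_tokens(opt, FakeDict)
    assert null_idx == -1 and end_idx == 50256  # padding no longer collides with END_IDX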
4 changes: 3 additions & 1 deletion parlai/agents/hugging_face/gpt2.py
@@ -85,7 +85,8 @@ def forward(self, input, encoder_state, incr_state=None):
                 and int(input[0][0]) == self.START_IDX
             ):
                 # generating: ignore the start token
-                model_input = encoder_state
+                # without a deep copy, the padding idx (-1) in encoder_state can be reset to 0 by the in-place clamp_ operation below
+                model_input = encoder_state.clone()
             else:
                 # forced decoding: concatenate the context
                 # with the labels
@@ -108,6 +109,7 @@ def forward(self, input, encoder_state, incr_state=None):
             model_input = input[:, -1:]
             attention_mask = torch.cat([encoder_state, input], dim=-1) != self.NULL_IDX
 
+        model_input = model_input.clamp_(min=0)
         transformer_outputs = self.transformer(
             model_input,
             past=incr_state,
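Why both the clone and the clamp: clamp_ mutates its tensor in place, and the attention mask is recomputed from encoder_state on later decoding steps (see above), so any -1 padding markers must survive in encoder_state while the transformer itself needs non-negative token ids. A small sketch, with illustrative tensor values:

    import torch

    NULL_IDX = -1  # sentinel padding id from the DialoGPT workaround
    encoder_state = torch.tensor([[NULL_IDX, 15496, 995]])  # left-padded context

    model_input = encoder_state.clone()  # clone first: clamp_ mutates in place
    model_input.clamp_(min=0)            # -1 -> 0, a valid embedding index

    attention_mask = encoder_state != NULL_IDX  # padding still masked
    print(model_input)     # tensor([[    0, 15496,   995]])
    print(attention_mask)  # tensor([[False,  True,  True]])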
105 changes: 103 additions & 2 deletions tests/nightly/gpu/test_dialogpt.py
@@ -6,18 +6,119 @@
 
 import unittest
 import parlai.utils.testing as testing_utils
+from parlai.core.agents import create_agent
 
 
 @testing_utils.skipUnlessGPU
 class TestDialogptModel(unittest.TestCase):
     """
     Test of DialoGPT model.
-
-    Checks that DialoGPT gets a certain performance on the integration test task.
     """
 
+    def _test_batchsize(self, batchsize, add_start_token):
+        utterances = [
+            'How is your day so far?',
+            'I hope you have a good day.',
+            "Nice to meet you. My name is John. ",
+            "I've got a feeling we're not in Kansas anymore.",
+        ]
+        opt = {
+            'model': 'hugging_face/dialogpt',
+            'gpt2_size': 'small',
+            'text_truncate': 100,
+            'label_truncate': 20,
+            'beam_min_length': 1,
+            'inference': 'beam',
+            'beam_size': 1,
+            'add_special_tokens': True,
+            'batchsize': batchsize,
+            'add_start_token': add_start_token,
+        }
+        dialogpt = create_agent(opt)
+
+        results_single = []
+        agents = [dialogpt.clone() for _ in utterances]
+        for u, a in zip(utterances, agents):
+            a.observe({'text': u, 'episode_done': True})
+            generation = a.act()['text']
+            results_single.append(generation)
+
+        results_batched = []
+        for idx in range(len(utterances) // batchsize):
+            agents = [dialogpt.clone() for _ in range(batchsize)]
+            batch = utterances[idx * batchsize : (idx + 1) * batchsize]
+            obs = []
+            for i, a in enumerate(agents):
+                obs.append(a.observe({'text': batch[i], 'episode_done': True}))
+            generations = [x['text'] for x in dialogpt.batch_act(obs)]
+            results_batched += generations
+
+        assert results_single == results_batched
+
+    def test_batchsize(self):
+        """
+        Ensure that DialoGPT produces the same generations regardless of batch size.
+        """
+        # check that a RuntimeError is raised when add_special_tokens = False
+        # and batchsize > 1
+        with self.assertRaises(RuntimeError):
+            create_agent(
+                {
+                    'model': 'hugging_face/dialogpt',
+                    'add_special_tokens': False,
+                    'batchsize': 2,
+                }
+            )
+
+        for batchsize in [1, 2, 4]:
+            for add_start_token in [True, False]:
+                with self.subTest(
+                    f'test_batchsize with bs={batchsize} and add_start_token={add_start_token}'
+                ):
+                    self._test_batchsize(batchsize, add_start_token)
+
+    def test_start_token(self):
+        """
+        Test that a RuntimeError is raised when add_start_token = True but
+        add_special_tokens = False.
+        """
+        with self.assertRaises(RuntimeError):
+            create_agent(
+                {
+                    'model': 'hugging_face/dialogpt',
+                    'add_special_tokens': False,
+                    'add_start_token': True,
+                }
+            )
+
+    def test_nospecialtok(self):
+        """
+        Test generation consistency for off-the-shelf DialoGPT models.
+        """
+        test_cases = [
+            ("What a nice weather!", "I'm in the UK and it's raining here."),
+            ("Nice to meet you!", "Hello! I'm from the future!"),
+        ]
+        opt = {
+            'model': 'hugging_face/dialogpt',
+            'gpt2_size': 'small',
+            'text_truncate': 100,
+            'label_truncate': 20,
+            'beam_min_length': 1,
+            'inference': 'beam',
+            'beam_size': 1,
+            'add_special_tokens': False,
+            'batchsize': 1,
+        }
+        dialogpt = create_agent(opt)
+        for text, label in test_cases:
+            dialogpt.observe({'text': text, 'episode_done': True})
+            response = dialogpt.act()
+            assert response['text'] == label
+
     @testing_utils.retry(ntries=3, log_retry=True)
     def test_dialogpt(self):
+        """
+        Checks that DialoGPT gets a certain performance on the integration test task.
+        """
         valid, test = testing_utils.train_model(
             dict(
                 task='integration_tests:overfit',
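To run the new checks locally (they require a GPU per the skipUnlessGPU decorator; the command assumes a ParlAI checkout with pytest installed):

    pytest tests/nightly/gpu/test_dialogpt.py -k "batchsize or nospecialtok or start_token"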
2 changes: 2 additions & 0 deletions tests/nightly/gpu/test_gpt2.py
@@ -36,6 +36,8 @@ def test_custom_special_tokens(self):
 
 
 class TestGpt2(unittest.TestCase):
+    # Did you also implement a test for DialoGPT, if your changes affect it?
+
     def _test_batchsize(self, batchsize, add_start_token):
         utterances = [
             'Just keep swimming -',
