Commit 0946eb7

Fix GPT2DoubleHeadsModel to work with model.generate() (huggingface#6601)

* Fix passing token_type_ids during GPT2DoubleHeadsModel.generate() if used, and for GPT2LMHeadModel too

* Update tests to check token_type_ids usage in GPT2 models
LSinev authored and Zhylkaaa committed Nov 17, 2020
1 parent 10f7239 commit 0946eb7
Showing 3 changed files with 121 additions and 2 deletions.
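
For orientation, here is a minimal usage sketch of what this commit enables: passing token_type_ids through model.generate() so they are forwarded (and kept in sync with input_ids) at every decoding step. The checkpoint name matches the tests below; the snippet illustrates post-fix behavior and is not part of the commit.

from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

encoded = tokenizer("Hello, my dog is a little", return_tensors="pt", return_token_type_ids=True)

# Before this fix, token_type_ids passed to generate() were silently dropped by
# prepare_inputs_for_generation(); with it, they are forwarded and extended each step.
output_ids = model.generate(
    input_ids=encoded.input_ids,
    token_type_ids=encoded.token_type_ids,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))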
9 changes: 9 additions & 0 deletions src/transformers/generation_utils.py
@@ -144,6 +144,10 @@ def _expand_inputs_for_generation(
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

@@ -194,6 +198,11 @@ def _update_model_kwargs_for_generation(
        else:
            model_kwargs["past"] = None

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
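
Each decoding step appends one generated token to input_ids, so the per-token tensors have to grow with it. The new block keeps token_type_ids the same length by repeating its last column, i.e. it assumes the generated continuation belongs to the same segment as the final prompt token. A small standalone sketch of that update:

import torch

token_type_ids = torch.tensor([[0, 0, 1, 1]])  # segment ids for one prompt

# one generation step: duplicate the last segment id so the length tracks input_ids
token_type_ids = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
print(token_type_ids)  # tensor([[0, 0, 1, 1, 1]])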
22 changes: 22 additions & 0 deletions src/transformers/modeling_gpt2.py
@@ -708,9 +708,12 @@ def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
@@ -729,6 +732,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
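
When past (the cached key/value states) is set, GPT-2 only needs the newest token as input, so prepare_inputs_for_generation slices input_ids down to its last column; the new lines apply the same slicing to token_type_ids so both tensors keep matching shapes. A sketch of that step in isolation (have_past is an illustrative stand-in for the if past: check):

import torch

input_ids = torch.tensor([[10, 11, 12, 13]])
token_type_ids = torch.tensor([[0, 0, 1, 1]])
have_past = True  # stands in for `if past:` above

if have_past:
    input_ids = input_ids[:, -1].unsqueeze(-1)            # tensor([[13]])
    token_type_ids = token_type_ids[:, -1].unsqueeze(-1)  # tensor([[1]])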
@@ -836,14 +840,32 @@ def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
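
The position_ids handling added to GPT2DoubleHeadsModel.prepare_inputs_for_generation mirrors what GPT2LMHeadModel already does for batched, left-padded generation: positions are the running count of real tokens (cumsum of the attention mask minus one), and padding slots get an arbitrary valid index because they are masked out anyway. A standalone sketch with an illustrative mask:

import torch

# left-padded batch: 0 marks padding, 1 marks real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1  # count only real tokens
position_ids.masked_fill_(attention_mask == 0, 1)    # dummy value on (masked) padding
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])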
92 changes: 90 additions & 2 deletions tests/test_modeling_gpt2.py
@@ -469,10 +469,84 @@ def test_batch_generation(self):
        ]

        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch_device)
        token_type_ids = torch.cat(
            [
                input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
                input_ids.new_full((input_ids.shape[0], 1), 500),
            ],
            dim=-1,
        )

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
        )

        outputs_tt = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
            token_type_ids=token_type_ids,
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
            "Today, I'm going to be doing a lot of research on this. I",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

    @slow
    def test_batch_generation_2heads(self):
        model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
        model.to(torch_device)
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        tokenizer.padding_side = "left"

        # This tokenizer has no pad token, so we have to set it in some way
        # Define PAD Token = EOS Token = 50256
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

        # use different length sentences to test batching
        sentences = [
            "Hello, my dog is a little",
            "Today, I",
        ]

        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch_device)
        token_type_ids = torch.cat(
            [
                input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
                input_ids.new_full((input_ids.shape[0], 1), 500),
            ],
            dim=-1,
        )

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
        )

        outputs_tt = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
            token_type_ids=token_type_ids,
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
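
Both batched tests rely on the same setup: GPT-2 ships without a pad token, so the tokenizer reuses the EOS token for padding, and padding_side is set to "left" because a decoder-only model continues generating from the rightmost position. A sketch of that setup on its own (same checkpoint as the tests; not part of the commit):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"            # keep real tokens flush against the generation side
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token

batch = tokenizer(["Hello, my dog is a little", "Today, I"], return_tensors="pt", padding=True)
print(batch["attention_mask"])  # the shorter sentence is left-padded with zeros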
@@ -483,6 +557,7 @@ def test_batch_generation(self):
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

@@ -491,6 +566,7 @@ def test_batch_generation(self):
            "Today, I'm going to be doing a lot of research on this. I",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

    @slow
@@ -540,11 +616,23 @@ def test_gpt2_sample(self):
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)
        output_ids = model.generate(input_ids, do_sample=True)
        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        token_type_ids = tokenized.token_type_ids.to(torch_device)
        output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
        output_seq_tt = model.generate(
            input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
        )
        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)

        EXPECTED_OUTPUT_STR = (
            "Today is a nice day and if you don't know anything about the state of play during your holiday"
        )
        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
        self.assertTrue(
            all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
        )  # token_type_ids should change output
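
The inequality assertions above rely on how GPT-2 consumes segment ids: token_type_ids are embedded through the same wte table as regular tokens and added to the input embeddings, so a nonzero id such as the 500 used in the batched tests shifts the hidden states and, in practice, the generated text. A sketch of that effect on a single forward pass, comparing logits rather than sampled text since sampling also depends on the RNG seed; this is an illustration, not part of the test suite:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids
with torch.no_grad():
    plain = model(input_ids, return_dict=True).logits
    typed = model(input_ids, token_type_ids=torch.full_like(input_ids, 500), return_dict=True).logits

print(torch.allclose(plain, typed))  # expected False: segment ids feed into the embeddings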
