Fix Seq2Seq generation for XModelWithHeads classes (#275)
calpt committed Feb 9, 2022
1 parent 7412ef8 commit 131b12e
Showing 4 changed files with 227 additions and 25 deletions.
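
The WithHeads model classes could not be used with generate() for sequence-to-sequence tasks because they were missing the hooks the generation loop expects from an encoder-decoder model: get_encoder(), get_decoder(), prepare_inputs_for_generation(), prepare_decoder_input_ids_from_labels() and _reorder_cache(). This commit copies those hooks from the corresponding ForConditionalGeneration classes into the BART, MBART and T5 WithHeads classes, guards the eos-based pooling in the BART/MBART forward pass so it is skipped when the hidden-state length does not match input_ids (as happens during generation), and adds a dedicated seq2seq head test that also exercises generation. The snippet below is a minimal sketch of the behaviour this enables; the tiny config values and the head name are illustrative, and it assumes the BartModelWithHeads class exported by this fork.

import torch
from transformers import BartConfig, BartModelWithHeads

# Illustrative sketch, not part of the diff: a tiny randomly initialized model is enough
# to exercise the generation hooks added in this commit.
config = BartConfig(
    vocab_size=100,
    d_model=32,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
)
model = BartModelWithHeads(config)
model.add_seq2seq_lm_head("dummy")  # head name is arbitrary

input_ids = torch.randint(0, config.vocab_size, (1, 16))
# Before this fix, generate() failed because get_encoder()/prepare_inputs_for_generation()
# were not defined on the WithHeads class.
generated = model.generate(input_ids, max_length=20)
print(generated.shape)  # torch.Size([1, n]) with n <= 20
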
63 changes: 58 additions & 5 deletions src/transformers/models/bart/modeling_bart.py
@@ -1261,6 +1261,12 @@ def __init__(self, config: BartConfig, **kwargs):

self._init_head_modules()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -1315,11 +1321,14 @@ def forward(
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
(eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
if input_ids is not None and x.shape[1] == input_ids.shape[1]:
eos_mask = input_ids.eq(self.config.eos_token_id)
(eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
else:
cls_representation = x

head_outputs = self.forward_head(
outputs,
@@ -1332,6 +1341,50 @@

return head_outputs

# Copied from BartForConditionalGeneration
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

# Copied from BartForConditionalGeneration
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

# Copied from BartForConditionalGeneration
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past


@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
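
A note on the copied _reorder_cache: because it is a static method, its behaviour can be checked without instantiating a model. The toy tensors below are placeholders, and the class name BartModelWithHeads is assumed to be the WithHeads class this file defines. During beam search the self-attention key/value states (the first two entries of each layer tuple) are re-indexed along the beam dimension, while the cached cross-attention states are returned unchanged.

import torch
from transformers import BartModelWithHeads  # assumed class name of the WithHeads model above

# Toy illustration with placeholder tensors: three "beams" with recognizable values.
self_attn_state = torch.arange(3, dtype=torch.float32).view(3, 1, 1, 1)
cross_attn_state = torch.zeros(3, 1, 1, 1)
layer_past = (self_attn_state, self_attn_state.clone(), cross_attn_state, cross_attn_state.clone())
beam_idx = torch.tensor([2, 0, 1])

(reordered,) = BartModelWithHeads._reorder_cache((layer_past,), beam_idx)
print(reordered[0].flatten())         # tensor([2., 0., 1.]) - self-attention cache reordered
print(reordered[2] is layer_past[2])  # True - cross-attention cache passed through untouched
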
63 changes: 58 additions & 5 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -1261,6 +1261,12 @@ def __init__(self, config: MBartConfig, **kwargs):

self._init_head_modules()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -1315,11 +1321,14 @@ def forward(
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
(eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
if input_ids is not None and x.shape[1] == input_ids.shape[1]:
eos_mask = input_ids.eq(self.config.eos_token_id)
(eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
else:
cls_representation = x

head_outputs = self.forward_head(
outputs,
@@ -1332,6 +1341,50 @@

return head_outputs

# Copied from MBartForConditionalGeneration
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

# Copied from MBartForConditionalGeneration
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)

# Copied from MBartForConditionalGeneration
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past


@add_start_docstrings(
"The MBART Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING
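
The MBART variant of prepare_decoder_input_ids_from_labels differs from BART's in that MBART's shift_tokens_right takes only the pad token id: instead of prepending a fixed decoder start token, it rotates the last non-pad token of each label sequence (normally the language id) to position 0. A rough illustration under that assumption, using made-up token ids (2 stands in for the rotated final token) and the assumed class name MBartModelWithHeads:

import torch
from transformers import MBartConfig, MBartModelWithHeads  # assumed class name

# Tiny config just to get an instance; only prepare_decoder_input_ids_from_labels is used here.
model = MBartModelWithHeads(MBartConfig(
    vocab_size=100,
    d_model=32,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
))

labels = torch.tensor([[10, 11, 12, 2]])
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)
print(decoder_input_ids)  # expected: [[2, 10, 11, 12]] - the final token is rotated to the front
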
64 changes: 64 additions & 0 deletions src/transformers/models/t5/modeling_t5.py
@@ -1878,6 +1878,12 @@ def __init__(self, config):
self.model_parallel = False
self.device_map = None

def get_encoder(self):
return self.transformer.encoder

def get_decoder(self):
return self.transformer.decoder

def forward(
self,
input_ids=None,
@@ -1944,3 +1950,61 @@ def forward(
return head_outputs
else:
return model_output

# Copied from T5ForConditionalGeneration
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

# Copied from T5ForConditionalGeneration
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

# Copied from T5ForConditionalGeneration
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past

reordered_decoder_past = ()
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)

assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)

reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
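
For T5 the same hook delegates to _shift_right, which prepends decoder_start_token_id (the pad token, id 0, for T5) and drops the last label position, which is exactly what the new seq2seq head test relies on when it prepares decoder_input_ids "similar to how DataCollatorForSeq2Seq does it". A short sketch under the same assumptions as above; the class name and token ids are illustrative, and decoder_start_token_id is set explicitly because a bare T5Config may not define it:

import torch
from transformers import T5Config, T5ModelWithHeads  # assumed class name

model = T5ModelWithHeads(T5Config(
    vocab_size=100,
    d_model=32,
    d_kv=16,
    d_ff=64,
    num_layers=1,
    num_heads=2,
    decoder_start_token_id=0,  # set explicitly; pretrained T5 configs define this as 0
))

labels = torch.tensor([[10, 11, 12, 1]])  # 1 is T5's default eos token id
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)
print(decoder_input_ids)  # expected: [[0, 10, 11, 12]] - labels shifted right behind the start token
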
62 changes: 47 additions & 15 deletions tests/test_adapter_heads.py
@@ -85,7 +85,7 @@ def test_tagging_head(self):
label_dict = {}
label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
self.run_prediction_head_test(
model1, model2, "dummy", output_shape=(1, self.seq_length, 2), label_dict=label_dict
model1, model2, "dummy", output_shape=(self.batch_size, self.seq_length, 2), label_dict=label_dict
)

def test_qa_head(self):
@@ -99,30 +99,58 @@ def test_qa_head(self):
label_dict["start_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
label_dict["end_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
self.run_prediction_head_test(
model1, model2, "dummy", output_shape=(1, self.seq_length), label_dict=label_dict
model1, model2, "dummy", output_shape=(self.batch_size, self.seq_length), label_dict=label_dict
)

def test_causal_or_seq2seq_lm_head(self):
def test_causal_lm_head(self):
if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_causal_lm_head"):
if hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"):
seq2seq_head = True
else:
self.skipTest("No causal or seq2seq language model head")
else:
seq2seq_head = False
self.skipTest("No causal language model head")

model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
model1.add_causal_lm_head("dummy")

if seq2seq_head:
model1.add_seq2seq_lm_head("dummy")
else:
model1.add_causal_lm_head("dummy")
label_dict = {}
label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)

self.run_prediction_head_test(
model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict
model1,
model2,
"dummy",
output_shape=(self.batch_size, self.seq_length, model1.config.vocab_size),
label_dict=label_dict,
)

def test_seq2seq_lm_head(self):
if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"):
self.skipTest("No seq2seq language model head")

model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
model1.add_seq2seq_lm_head("dummy")

label_dict = {}
# Use a different length for the seq2seq output
seq_output_length = 32
label_dict["labels"] = torch.zeros((self.batch_size, seq_output_length), dtype=torch.long, device=torch_device)

# prepare decoder_input_ids similar to how DataCollatorForSeq2Seq does it
if hasattr(model1, "prepare_decoder_input_ids_from_labels"):
decoder_input_ids = model1.prepare_decoder_input_ids_from_labels(labels=label_dict["labels"])
label_dict["decoder_input_ids"] = decoder_input_ids

self.run_prediction_head_test(
model1,
model2,
"dummy",
output_shape=(self.batch_size, seq_output_length, model1.config.vocab_size),
label_dict=label_dict,
)

# Finally, also check if generation works properly
input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
input_ids = input_ids.to(torch_device)
generated = model1.generate(input_ids, max_length=seq_output_length)
self.assertEqual(generated.shape, (1, seq_output_length))

def test_masked_lm_head(self):
if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_masked_lm_head"):
self.skipTest("No causal or seq2seq language model head")
@@ -133,7 +161,11 @@ def test_masked_lm_head(self):
label_dict = {}
label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
self.run_prediction_head_test(
model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict
model1,
model2,
"dummy",
output_shape=(self.batch_size, self.seq_length, model1.config.vocab_size),
label_dict=label_dict,
)

def test_dependency_parsing_head(self):
