Fix Seq2Seq generation for XModelWithHeads classes #275

Merged · 2 commits · Feb 8, 2022
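This PR makes `generate()` work for the `ModelWithHeads` classes of the seq2seq models (BART, MBart, T5). It adds the `get_encoder()`/`get_decoder()` accessors plus the `prepare_inputs_for_generation()`, `prepare_decoder_input_ids_from_labels()`, and `_reorder_cache()` methods that the generation loop relies on (copied from the respective `ForConditionalGeneration` classes), guards the BART/MBart `<eos>` pooling so `forward()` also accepts the decoder sequences produced during generation, and extends the prediction-head tests accordingly.

A minimal sketch of the flow this enables, assuming the adapter-transformers `ModelWithHeads` API; the checkpoint and head names are placeholders:

```python
from transformers import AutoModelWithHeads, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint
model = AutoModelWithHeads.from_pretrained("facebook/bart-base")
model.add_seq2seq_lm_head("dummy")  # placeholder head name

input_ids = tokenizer("Some input text.", return_tensors="pt").input_ids
# Before this fix, generate() failed on the ModelWithHeads classes because
# get_encoder() / prepare_inputs_for_generation() were missing.
generated = model.generate(input_ids, max_length=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```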
src/transformers/models/bart/modeling_bart.py (58 additions, 5 deletions)
@@ -1260,6 +1260,12 @@ def __init__(self, config: BartConfig, **kwargs):
 
         self._init_head_modules()
 
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1314,11 +1320,14 @@ def forward(
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_id)
-        (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
-            raise ValueError("All examples must have the same number of <eos> tokens.")
-        cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
+            eos_mask = input_ids.eq(self.config.eos_token_id)
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_representation = x
 
         head_outputs = self.forward_head(
             outputs,
@@ -1331,6 +1340,50 @@ def forward(
 
         return head_outputs
 
+    # Copied from BartForConditionalGeneration
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    # Copied from BartForConditionalGeneration
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+    # Copied from BartForConditionalGeneration
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
 
 
 @add_start_docstrings(
     "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
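Context for the new guard in `forward()` above: the sequence-classification heads pool the hidden state at the last `<eos>` position of each example, which only makes sense while the hidden states are aligned with `input_ids`. During generation, `forward()` is called with decoder sequences whose length differs from the encoder `input_ids` (or with `input_ids=None` once `encoder_outputs` are cached), so the pooling is now skipped in that case. A runnable sketch of the pooling itself, with made-up token ids:

```python
import torch

eos_token_id = 2
x = torch.randn(2, 5, 8)                    # (batch, seq_len, hidden): last hidden state
input_ids = torch.tensor([[5, 6, 2, 1, 1],  # 2 = <eos>, 1 = padding
                          [7, 8, 9, 10, 2]])

eos_mask = input_ids.eq(eos_token_id)       # True at the <eos> positions

# gather hidden states at the <eos> positions, keep the last one per example
cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
print(cls_representation.shape)  # torch.Size([2, 8])
```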
src/transformers/models/mbart/modeling_mbart.py (58 additions, 5 deletions)
@@ -1260,6 +1260,12 @@ def __init__(self, config: MBartConfig, **kwargs):
 
         self._init_head_modules()
 
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1314,11 +1320,14 @@ def forward(
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_id)
-        (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
-            raise ValueError("All examples must have the same number of <eos> tokens.")
-        cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
+            eos_mask = input_ids.eq(self.config.eos_token_id)
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_representation = x
 
         head_outputs = self.forward_head(
             outputs,
@@ -1331,6 +1340,50 @@ def forward(
 
         return head_outputs
 
+    # Copied from MBartForConditionalGeneration
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    # Copied from MBartForConditionalGeneration
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id)
+
+    # Copied from MBartForConditionalGeneration
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
 
 
 @add_start_docstrings(
     "The MBART Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING
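One detail worth noting in the MBart variant: `shift_tokens_right` is called without a `decoder_start_token_id`, because MBart decoder inputs start with the language code that sits at the end of the labels rather than with a fixed start token. A from-scratch sketch of that behavior (the same idea as the library helper, not the helper itself; details may differ):

```python
import torch

def mbart_style_shift(labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Wrap the last non-pad token (the language code) around to position 0
    # and shift everything else one step to the right.
    shifted = labels.clone()
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 is the loss ignore index
    last_nonpad = (shifted.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    start_tokens = shifted.gather(1, last_nonpad).squeeze(-1)
    shifted[:, 1:] = shifted[:, :-1].clone()
    shifted[:, 0] = start_tokens
    return shifted

labels = torch.tensor([[10, 11, 12, 2, 250004]])  # ... <eos> <language code>
print(mbart_style_shift(labels, pad_token_id=1))
# tensor([[250004, 10, 11, 12, 2]])
```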
src/transformers/models/t5/modeling_t5.py (64 additions, 0 deletions)
@@ -1877,6 +1877,12 @@ def __init__(self, config):
         self.model_parallel = False
         self.device_map = None
 
+    def get_encoder(self):
+        return self.transformer.encoder
+
+    def get_decoder(self):
+        return self.transformer.decoder
+
     def forward(
         self,
         input_ids=None,
@@ -1943,3 +1949,61 @@ def forward(
             return head_outputs
         else:
             return model_output
+
+    # Copied from T5ForConditionalGeneration
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "decoder_input_ids": input_ids,
+            "past_key_values": past,
+            "encoder_outputs": encoder_outputs,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,
+        }
+
+    # Copied from T5ForConditionalGeneration
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return self._shift_right(labels)
+
+    # Copied from T5ForConditionalGeneration
+    def _reorder_cache(self, past, beam_idx):
+        # if decoder past is not included in output
+        # speedy decoding is disabled and no need to reorder
+        if past is None:
+            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
+            return past
+
+        reordered_decoder_past = ()
+        for layer_past_states in past:
+            # get the correct batch idx from layer past batch dim
+            # batch dim of `past` is at 2nd position
+            reordered_layer_past_states = ()
+            for layer_past_state in layer_past_states:
+                # need to set correct `past` for each of the four key / value states
+                reordered_layer_past_states = reordered_layer_past_states + (
+                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
+                )
+
+            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
+            assert len(reordered_layer_past_states) == len(layer_past_states)
+
+            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+        return reordered_decoder_past
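For reference, what `_reorder_cache` is for: during beam search the surviving hypotheses can descend from different parent beams after each step, so every cached key/value tensor (beams on dim 0; T5 keeps four such tensors per layer, self- and cross-attention keys and values) must be re-gathered to match. A standalone illustration of the `index_select` at the core of it:

```python
import torch

num_beams, num_heads, seq_len, head_dim = 4, 2, 5, 8
past_state = torch.randn(num_beams, num_heads, seq_len, head_dim)  # one cached tensor

# beam_idx[i] = parent beam that the new beam i continues from
beam_idx = torch.tensor([0, 0, 2, 3])  # new beams 0 and 1 both extend old beam 0

reordered = past_state.index_select(0, beam_idx)
assert torch.equal(reordered[1], past_state[0])  # beam 1 carries beam 0's cache
```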
tests/test_adapter.py (5 additions, 6 deletions)
@@ -370,12 +370,11 @@ class T5AdapterTestBase(AdapterTestBase):
     config = make_config(
         T5Config,
         d_model=16,
-        encoder_layers=2,
-        decoder_layers=2,
-        encoder_attention_heads=4,
-        decoder_attention_heads=4,
-        encoder_ffn_dim=4,
-        decoder_ffn_dim=4,
+        num_layers=2,
+        num_decoder_layers=2,
+        num_heads=4,
+        d_ff=4,
+        d_kv=16 // 4,
         tie_word_embeddings=False,
         decoder_start_token_id=0,
     )
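The config fix above matters because `T5Config` does not use the BART-style parameter names: unknown kwargs such as `encoder_layers` are simply stored on the config and never read by the T5 model, so the old settings had no effect and the test silently built a model with the default depth and width for those fields. With T5's own names the tiny test model comes out as intended:

```python
from transformers import T5Config

config = T5Config(
    d_model=16,
    num_layers=2,          # encoder layers
    num_decoder_layers=2,  # decoder layers
    num_heads=4,
    d_ff=4,                # feed-forward size
    d_kv=16 // 4,          # per-head key/value dimension
)
print(config.num_layers, config.num_heads)  # 2 4
```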
tests/test_adapter_heads.py (47 additions, 15 deletions)
@@ -85,7 +85,7 @@ def test_tagging_head(self):
         label_dict = {}
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length, 2), label_dict=label_dict
+            model1, model2, "dummy", output_shape=(self.batch_size, self.seq_length, 2), label_dict=label_dict
         )
 
     def test_qa_head(self):
@@ -99,30 +99,58 @@ def test_qa_head(self):
         label_dict["start_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
         label_dict["end_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length), label_dict=label_dict
+            model1, model2, "dummy", output_shape=(self.batch_size, self.seq_length), label_dict=label_dict
         )
 
-    def test_causal_or_seq2seq_lm_head(self):
+    def test_causal_lm_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_causal_lm_head"):
-            if hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"):
-                seq2seq_head = True
-            else:
-                self.skipTest("No causal or seq2seq language model head")
-        else:
-            seq2seq_head = False
+            self.skipTest("No causal language model head")
 
         model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
+        model1.add_causal_lm_head("dummy")
 
-        if seq2seq_head:
-            model1.add_seq2seq_lm_head("dummy")
-        else:
-            model1.add_causal_lm_head("dummy")
         label_dict = {}
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
 
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict
+            model1,
+            model2,
+            "dummy",
+            output_shape=(self.batch_size, self.seq_length, model1.config.vocab_size),
+            label_dict=label_dict,
         )
 
+    def test_seq2seq_lm_head(self):
+        if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"):
+            self.skipTest("No seq2seq language model head")
+
+        model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
+        model1.add_seq2seq_lm_head("dummy")
+
+        label_dict = {}
+        # Use a different length for the seq2seq output
+        seq_output_length = 32
+        label_dict["labels"] = torch.zeros((self.batch_size, seq_output_length), dtype=torch.long, device=torch_device)
+
+        # prepare decoder_input_ids similar to how DataCollatorForSeq2Seq does it
+        if hasattr(model1, "prepare_decoder_input_ids_from_labels"):
+            decoder_input_ids = model1.prepare_decoder_input_ids_from_labels(labels=label_dict["labels"])
+            label_dict["decoder_input_ids"] = decoder_input_ids
+
+        self.run_prediction_head_test(
+            model1,
+            model2,
+            "dummy",
+            output_shape=(self.batch_size, seq_output_length, model1.config.vocab_size),
+            label_dict=label_dict,
+        )
+
+        # Finally, also check if generation works properly
+        input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
+        input_ids = input_ids.to(torch_device)
+        generated = model1.generate(input_ids, max_length=seq_output_length)
+        self.assertEqual(generated.shape, (1, seq_output_length))
+
     def test_masked_lm_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_masked_lm_head"):
             self.skipTest("No causal or seq2seq language model head")
@@ -133,7 +161,11 @@ def test_masked_lm_head(self):
         label_dict = {}
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict
+            model1,
+            model2,
+            "dummy",
+            output_shape=(self.batch_size, self.seq_length, model1.config.vocab_size),
+            label_dict=label_dict,
         )
 
     def test_dependency_parsing_head(self):
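The `prepare_decoder_input_ids_from_labels` call in the new seq2seq test mirrors what `DataCollatorForSeq2Seq` does during training. For BART-style models it boils down to a right shift of the labels; a from-scratch sketch of that idea (not the library helper itself):

```python
import torch

def shift_right(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # Teacher forcing: the decoder predicts labels[t] from labels[:t],
    # so its inputs are the labels shifted one position to the right.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 is the loss ignore index
    return shifted

labels = torch.tensor([[10, 11, 12, 2]])  # ... <eos>
print(shift_right(labels, pad_token_id=1, decoder_start_token_id=0))
# tensor([[ 0, 10, 11, 12]])
```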