diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index c60ec7d697..89fe016ff1 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1260,6 +1260,12 @@ def __init__(self, config: BartConfig, **kwargs):
 
         self._init_head_modules()
 
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1314,11 +1320,14 @@ def forward(
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_id)
-        (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
-            raise ValueError("All examples must have the same number of tokens.")
-        cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        if input_ids is not None and x.shape[1] == input_ids.shape[1]:
+            eos_mask = input_ids.eq(self.config.eos_token_id)
+            (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
+            if len(torch.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of tokens.")
+            cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
+        else:
+            cls_representation = x
 
         head_outputs = self.forward_head(
             outputs,
@@ -1331,6 +1340,50 @@ def forward(
 
         return head_outputs
 
+    # Copied from BartForConditionalGeneration
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    # Copied from BartForConditionalGeneration
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+    # Copied from BartForConditionalGeneration
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
 
 @add_start_docstrings(
     "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
Can be used for summarization.", BART_START_DOCSTRING diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 3175428b64..71003e99ca 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1260,6 +1260,12 @@ def __init__(self, config: MBartConfig, **kwargs): self._init_head_modules() + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1314,11 +1320,14 @@ def forward( ) # sequence classification based on last token in sequence x = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) - (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask) - if len(torch.unique(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] + if input_ids is not None and x.shape == input_ids.shape: + eos_mask = input_ids.eq(self.config.eos_token_id) + (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask) + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] + else: + cls_representation = x head_outputs = self.forward_head( outputs, @@ -1331,6 +1340,50 @@ def forward( return head_outputs + # Copied from MBartForConditionalGeneration + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + # Copied from MBartForConditionalGeneration + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + + # Copied from MBartForConditionalGeneration + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + @add_start_docstrings( "The MBART Model with a language modeling head. 
Can be used for summarization.", MBART_START_DOCSTRING diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 859aaf5ec4..9874b7b01d 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1877,6 +1877,12 @@ def __init__(self, config): self.model_parallel = False self.device_map = None + def get_encoder(self): + return self.transformer.encoder + + def get_decoder(self): + return self.transformer.decoder + def forward( self, input_ids=None, @@ -1943,3 +1949,61 @@ def forward( return head_outputs else: return model_output + + # Copied from T5ForConditionalGeneration + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + # Copied from T5ForConditionalGeneration + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + # Copied from T5ForConditionalGeneration + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 36e225411f..e09fefc41b 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -370,12 +370,11 @@ class T5AdapterTestBase(AdapterTestBase): config = make_config( T5Config, d_model=16, - encoder_layers=2, - decoder_layers=2, - encoder_attention_heads=4, - decoder_attention_heads=4, - encoder_ffn_dim=4, - decoder_ffn_dim=4, + num_layers=2, + num_decoder_layers=2, + num_heads=4, + d_ff=4, + d_kv=16 // 4, tie_word_embeddings=False, decoder_start_token_id=0, ) diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index 3dee3d9bb1..3ed4f1485f 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -85,7 +85,7 @@ def test_tagging_head(self): label_dict = {} label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) self.run_prediction_head_test( - model1, model2, "dummy", output_shape=(1, self.seq_length, 2), label_dict=label_dict + model1, model2, "dummy", output_shape=(self.batch_size, self.seq_length, 
         )
 
     def test_qa_head(self):
@@ -99,30 +99,58 @@ def test_qa_head(self):
         label_dict["start_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
         label_dict["end_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length), label_dict=label_dict
+            model1, model2, "dummy", output_shape=(self.batch_size, self.seq_length), label_dict=label_dict
         )
 
-    def test_causal_or_seq2seq_lm_head(self):
+    def test_causal_lm_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_causal_lm_head"):
-            if hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"):
-                seq2seq_head = True
-            else:
-                self.skipTest("No causal or seq2seq language model head")
-        else:
-            seq2seq_head = False
+            self.skipTest("No causal language model head")
 
         model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
+        model1.add_causal_lm_head("dummy")
 
-        if seq2seq_head:
-            model1.add_seq2seq_lm_head("dummy")
-        else:
-            model1.add_causal_lm_head("dummy")
         label_dict = {}
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
+
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict
+            model1,
+            model2,
+            "dummy",
+            output_shape=(self.batch_size, self.seq_length, model1.config.vocab_size),
+            label_dict=label_dict,
         )
 
+    def test_seq2seq_lm_head(self):
+        if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"):
+            self.skipTest("No seq2seq language model head")
+
+        model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
+        model1.add_seq2seq_lm_head("dummy")
+
+        label_dict = {}
+        # Use a different length for the seq2seq output
+        seq_output_length = 32
+        label_dict["labels"] = torch.zeros((self.batch_size, seq_output_length), dtype=torch.long, device=torch_device)
+
+        # prepare decoder_input_ids similar to how DataCollatorForSeq2Seq does it
+        if hasattr(model1, "prepare_decoder_input_ids_from_labels"):
+            decoder_input_ids = model1.prepare_decoder_input_ids_from_labels(labels=label_dict["labels"])
+            label_dict["decoder_input_ids"] = decoder_input_ids
+
+        self.run_prediction_head_test(
+            model1,
+            model2,
+            "dummy",
+            output_shape=(self.batch_size, seq_output_length, model1.config.vocab_size),
+            label_dict=label_dict,
+        )
+
+        # Finally, also check if generation works properly
+        input_ids = self.get_input_samples((1, self.seq_length), config=model1.config)["input_ids"]
+        input_ids = input_ids.to(torch_device)
+        generated = model1.generate(input_ids, max_length=seq_output_length)
+        self.assertEqual(generated.shape, (1, seq_output_length))
+
     def test_masked_lm_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_masked_lm_head"):
             self.skipTest("No causal or seq2seq language model head")
@@ -133,7 +161,11 @@ def test_masked_lm_head(self):
         label_dict = {}
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
         self.run_prediction_head_test(
-            model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict
+            model1,
+            model2,
+            "dummy",
+            output_shape=(self.batch_size, self.seq_length, model1.config.vocab_size),
+            label_dict=label_dict,
         )
 
     def test_dependency_parsing_head(self):
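Together, these changes let the `*ModelWithHeads` classes run `generate()` once a seq2seq LM head is added. A minimal usage sketch, mirroring the new `test_seq2seq_lm_head` test and assuming the adapter-transformers `AutoModelWithHeads` API used in the tests; the checkpoint, head name, and generation settings are illustrative and not part of the patch:

    from transformers import AutoModelWithHeads, AutoTokenizer

    # Any checkpoint whose ModelWithHeads class gained prepare_inputs_for_generation above
    # (BART, mBART, T5) should work; "facebook/bart-base" is just an example.
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    model = AutoModelWithHeads.from_pretrained("facebook/bart-base")

    # generation requires a seq2seq LM head, as in the new test
    model.add_seq2seq_lm_head("summarization")

    batch = tokenizer("Some long input text to summarize.", return_tensors="pt")

    # beam search exercises the newly added _reorder_cache; greedy decoding
    # (num_beams=1) works as well
    generated = model.generate(**batch, max_length=32, num_beams=2)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))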