From ca51a941bb7ce961657e28ad0e8ab567bba4e933 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 27 Mar 2023 15:48:23 +0100 Subject: [PATCH] Generate: support for left-padding on GPTNeoX and Llama (#22382) --- .../models/gpt_neox/modeling_gpt_neox.py | 80 +++++++--- src/transformers/models/gptj/modeling_gptj.py | 2 +- .../models/llama/modeling_llama.py | 148 ++++++------------ .../models/gpt_neox/test_modeling_gpt_neox.py | 2 +- 4 files changed, 108 insertions(+), 124 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 62ad98972df0b0..cb132e721a0a5c 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -100,12 +100,13 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask, - head_mask=None, - layer_past=None, - use_cache=False, - output_attentions=False, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, ): has_layer_past = layer_past is not None @@ -132,12 +133,10 @@ def forward( # Compute token offset for rotary embeddings (when decoding) seq_len = key.shape[-2] - offset = 0 if has_layer_past: - offset = layer_past[0].shape[-2] - seq_len += offset + seq_len += layer_past[0].shape[-2] cos, sin = self.rotary_emb(value, seq_len=seq_len) - query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query = torch.cat((query, query_pass), dim=-1) key = torch.cat((key, key_pass), dim=-1) @@ -275,9 +274,11 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): - cos = cos[..., offset : q.shape[-2] + offset, :] - sin = sin[..., offset : q.shape[-2] + offset, :] +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -308,16 +309,18 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - use_cache=False, - layer_past=None, - output_attentions=False, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + layer_past: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, ): attention_layer_outputs = self.attention( self.input_layernorm(hidden_states), attention_mask=attention_mask, + position_ids=position_ids, layer_past=layer_past, head_mask=head_mask, use_cache=use_cache, @@ -374,6 +377,11 @@ def forward( - 0 for tokens that are **masked**. 
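Editorial note: the core mechanic of this patch is visible in the new `apply_rotary_pos_emb` above. Instead of slicing the cached cos/sin tables with a single scalar `offset`, each sequence gathers its own rows by `position_ids`, so a left-padded row can start its rotary positions wherever its real tokens begin. A minimal, self-contained sketch of that gather, using dummy shapes and values rather than the library code verbatim:

import torch

def rotate_half(x):
    # Same split-and-swap as the patch: (-x2, x1) over the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin: [1, 1, max_seq_len, head_dim]; q/k: [bs, num_heads, seq_len, head_dim]
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

bs, heads, seq, dim, max_seq = 2, 4, 3, 8, 16
q = torch.randn(bs, heads, seq, dim)
k = torch.randn(bs, heads, seq, dim)
cos = torch.randn(1, 1, max_seq, dim)
sin = torch.randn(1, 1, max_seq, dim)
# Row 0 is left-padded by one token (the pad slot holds the dummy value 1,
# real tokens count from 0); row 1 is unpadded.
position_ids = torch.tensor([[1, 0, 1], [0, 1, 2]])
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == (bs, heads, seq, dim)

Because the per-sequence positions are now explicit, the old scalar `offset` bookkeeping in the attention forward is no longer needed, which is exactly what the removed lines above show.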
[What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: @@ -430,6 +438,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, @@ -467,7 +476,17 @@ def forward( batch_size, seq_length = input_shape if past_key_values is None: + past_length = 0 past_key_values = tuple([None] * self.config.num_hidden_layers) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() # Attention mask. if attention_mask is not None: @@ -527,12 +546,14 @@ def custom_forward(*inputs): create_custom_forward(layer), hidden_states, attention_mask, + position_ids, head_mask[i], ) else: outputs = layer( hidden_states, attention_mask=attention_mask, + position_ids=position_ids, head_mask=head_mask[i], layer_past=layer_past, use_cache=use_cache, @@ -587,6 +608,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, @@ -640,6 +662,7 @@ def forward( outputs = self.gpt_neox( input_ids, attention_mask=attention_mask, + position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, @@ -672,20 +695,29 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs): input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used if past_key_values and past_key_values[0] is not None: input_ids = input_ids[:, -1:] + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + return { "input_ids": 
input_ids, "attention_mask": attention_mask, + "position_ids": position_ids, "past_key_values": past_key_values, } diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 6cd756bc1a33f0..c5e6cc564f0922 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -192,7 +192,7 @@ def _get_embed_positions(self, position_ids): def forward( self, - hidden_states: Optional[torch.FloatTensor], + hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f341eafb51b2b5..cc9e01e4aa5aeb 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -38,6 +38,7 @@ _CONFIG_FOR_DOC = "LlamaConfig" +# Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): """ Make causal mask used for bi-directional self-attention. @@ -53,6 +54,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) +# Copied from transformers.models.bart.modeling_bart._expand_mask def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. @@ -126,9 +128,11 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): - cos = cos[..., offset : q.shape[-2] + offset, :] - sin = sin[..., offset : q.shape[-2] + offset, :] +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -197,13 +201,12 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -211,12 +214,10 @@ def forward( value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - offset = 0 if past_key_value is not None: - offset = past_key_value[0].shape[-2] - kv_seq_len += offset + kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, 
offset=offset) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bsz, nh, t, hd] if past_key_value is not None: @@ -283,9 +284,10 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -308,8 +310,9 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) @@ -406,7 +409,11 @@ def _set_gradient_checkpointing(self, module, value=False): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape @@ -488,10 +495,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em return combined_attention_mask + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -499,49 +508,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. 
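Editorial note: for left padding to work, `_prepare_decoder_attention_mask` (above) combines the causal mask with the expanded padding mask additively, so padded key columns stay blocked for every query row. A toy illustration of that combination, with hand-rolled masks standing in for the `_make_causal_mask`/`_expand_mask` helpers:

import torch

dtype = torch.float32
neg = torch.finfo(dtype).min
seq = 3
# Causal part: query i may only attend to keys j <= i.
causal = torch.triu(torch.full((seq, seq), neg, dtype=dtype), diagonal=1)
# Padding part: attention_mask [0, 1, 1] marks the first key as left padding.
attention_mask = torch.tensor([[0.0, 1.0, 1.0]])
padding = (1.0 - attention_mask)[:, None, None, :] * neg  # [bs, 1, 1, seq]
combined = causal[None, None] + padding                   # [bs, 1, seq, seq]
print(combined[0, 0])
# Column 0 holds the dtype minimum in every row: after softmax, the pad token
# receives (near-)zero attention regardless of query position.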
- - [What are attention masks?](../glossary#attention-mask) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -559,11 +525,23 @@ def forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + seq_length_with_past = seq_length past_key_values_length = 0 + if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions @@ -608,12 +586,14 @@ def custom_forward(*inputs): create_custom_forward(decoder_layer), hidden_states, attention_mask, + position_ids, None, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, + position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -674,11 +654,13 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -689,52 +671,10 @@ def forward( ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional - tensors are only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
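Editorial note: the model now fills in `position_ids` two ways — consecutive `arange` positions starting after the cache when the caller passes none (above), and mask-aware positions built in `prepare_inputs_for_generation` during generation (below). A small sketch of both paths on a dummy left-padded batch:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# Default path (no mask information): consecutive positions after the cache.
past_len, seq_len = 0, attention_mask.shape[-1]
default_ids = torch.arange(past_len, past_len + seq_len).unsqueeze(0)
# Left-padding-aware path: cumsum over the mask so real tokens count from 0,
# and pad slots are filled with the dummy value 1 (they are masked out anyway).
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])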
- - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Returns: @@ -765,6 +705,7 @@ def forward( outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, + position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -807,6 +748,14 @@ def prepare_inputs_for_generation( if past_key_values: input_ids = input_ids[:, -1:] + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -815,6 +764,7 @@ def prepare_inputs_for_generation( model_inputs.update( { + "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, @@ -868,6 +818,7 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -886,8 +837,9 @@ def forward( transformer_outputs = self.model( input_ids, - past_key_values=past_key_values, attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 1798f01358ff8c..519b10a040ff30 100644 --- 
a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -237,7 +237,7 @@ def test_feed_forward_chunking(self): @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase): @slow - def test_lm_generate_codegen(self): + def test_lm_generate_gptneox(self): tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped") for checkpointing in [True, False]: model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
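Editorial note: to close, a hedged sketch of the behavior this PR unlocks — batched generation over prompts of different lengths via left padding. The checkpoint matches the one used in the renamed test above; the prompts and generation settings are illustrative assumptions, not part of the patch:

import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # shorter prompts get padded on the left
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")

inputs = tokenizer(["Hello", "The quick brown fox"], return_tensors="pt", padding=True)
# With this patch, position_ids are derived from the attention mask inside
# prepare_inputs_for_generation, so the left-padded row sees the same rotary
# positions it would see unpadded.
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))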