diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index fb579be19e..b5e7c48b98 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,13 +75,22 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2", "gptj"]: + if self.model.config.model_type in [ + "llama", + "mistral", + "falcon", + "phi", + "mixtral", + "qwen2", + "gptj", + "starcoder2", + ]: self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, } ) - if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon"]: + if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon", "starcoder2"]: if self.model.config.model_type != "falcon": self.model_inputs.update( { diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 5b247f45f4..3ad2cc2352 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -381,7 +381,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = deepspeed.init_inference(model, **ds_inference_kwargs) model = model.module - if model.config.model_type in ["llama", "falcon", "qwen2"]: + if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2"]: patch_scoped_linear_all_reduce(model) if args.quant_config: diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 2ae3f9a65c..e22fe6facc 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -94,6 +94,7 @@ "starcoder2", "persimmon", "qwen2", + "starcoder2", "llava", "llava_next", "stablelm", @@ -435,7 +436,7 @@ def create_pad_arg(pad_amount, i, j): else: assert False elif model_kwargs["past_key_values"][0][0].dim() == 4: - return (0, 0, 0, pad_amount) # llama, falcon, qwen2 + return (0, 0, 0, pad_amount) # llama, falcon, qwen2, starcoder2 else: assert False, "Unknown case, please handle, or dont use bucketing" @@ -860,7 +861,8 @@ def generate( "phi", "qwen2", "gptj", - ], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2 and gptj at the moment" + "starcoder2", + ], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2 and starcoder2 at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 @@ -1016,7 +1018,7 @@ def generate( model_kwargs["kv_cache_len"] = calculated_max_length model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens - if self.config.model_type in ["llama", "falcon", "mistral", "qwen2", "gptj"]: + if self.config.model_type in ["llama", "falcon", "mistral", "qwen2", "gptj", "starcoder2"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 5d6e4f149b..5e8c390a88 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -88,8 +88,10 @@ GaudiQwen2Model, GaudiStableLmDecoderLayer, GaudiStableLmForCausalLM, + GaudiStarcoder2Attention, GaudiStarcoder2DecoderLayer, GaudiStarcoder2ForCausalLM, + GaudiStarcoder2Model, LlamaConfig, MistralConfig, MixtralConfig, @@ -175,8 +177,6 @@ gaudi_SpeechT5DecoderLayer_forward, gaudi_stablelm_attention_forward, gaudi_stablelm_model_forward, - gaudi_starcoder2_attention_forward, - gaudi_starcoder2_model_forward, gaudi_swin_get_attn_mask, gaudi_t5_layernorm_forward, gaudi_T5Attention_forward, @@ -517,8 +517,8 @@ def adapt_transformers_to_gaudi(): # Optimization for starcoder2 on Gaudi transformers.models.starcoder2.modeling_starcoder2.Starcoder2ForCausalLM = GaudiStarcoder2ForCausalLM - transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model.forward = gaudi_starcoder2_model_forward - transformers.models.starcoder2.modeling_starcoder2.Starcoder2Attention.forward = gaudi_starcoder2_attention_forward + transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model = GaudiStarcoder2Model + transformers.models.starcoder2.modeling_starcoder2.Starcoder2Attention = GaudiStarcoder2Attention transformers.models.starcoder2.modeling_starcoder2.Starcoder2DecoderLayer = GaudiStarcoder2DecoderLayer # Optimization for qwen2 on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 0826881146..5fe87144bd 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -191,10 +191,10 @@ gaudi_stablelm_model_forward, ) from .starcoder2 import ( + GaudiStarcoder2Attention, GaudiStarcoder2DecoderLayer, GaudiStarcoder2ForCausalLM, - gaudi_starcoder2_attention_forward, - gaudi_starcoder2_model_forward, + GaudiStarcoder2Model, ) from .swin import gaudi_swin_get_attn_mask from .t5 import ( diff --git a/optimum/habana/transformers/models/starcoder2/__init__.py b/optimum/habana/transformers/models/starcoder2/__init__.py index c4ac4dfc22..749e2a6943 100644 --- a/optimum/habana/transformers/models/starcoder2/__init__.py +++ b/optimum/habana/transformers/models/starcoder2/__init__.py @@ -1,6 +1,6 @@ from .modeling_starcoder2 import ( + GaudiStarcoder2Attention, GaudiStarcoder2DecoderLayer, GaudiStarcoder2ForCausalLM, - gaudi_starcoder2_attention_forward, - gaudi_starcoder2_model_forward, + GaudiStarcoder2Model, ) diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py index cd8f986765..4ca9793dd3 100644 --- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py +++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py @@ -1,19 +1,38 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################### +# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company +############################################################################### + import math import warnings from typing import List, Optional, Tuple, Union import torch -from torch import nn -from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +import torch.utils.checkpoint from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.starcoder2.configuration_starcoder2 import Starcoder2Config from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, - Starcoder2Config, Starcoder2DecoderLayer, Starcoder2ForCausalLM, + Starcoder2MLP, + Starcoder2Model, apply_rotary_pos_emb, - repeat_kv, ) from transformers.utils import logging @@ -31,147 +50,352 @@ try: from habana_frameworks.torch.hpex.kernels import FusedSDPA except ImportError: - print("Not using HPU fused sdpa kernel ") + print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +import habana_frameworks.torch.core as htcore + + logger = logging.get_logger(__name__) -def gaudi_starcoder2_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = False, - flash_attention_recompute: Optional[bool] = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Copied from Starcoder2Attention.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py - The only differences are: - - add new args token_idx - - optimize KV cache - - add new args use_flash_attention - - add new arg flash_attention_recompute - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - norm_factor = 1.0 / math.sqrt(self.head_dim) +class GaudiStarcoder2MLP(Starcoder2MLP): + def pre_mlp_forward(self, x): + inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + output = self.down_proj(inputs) + return output - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + def mlp_all_reduce(self, x): + if hasattr(self.down_proj, "all_reduce"): + self.down_proj.all_reduce(x) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + def post_mlp_forward(self, x): + if hasattr(self.down_proj, "post_all_reduce"): + return self.down_proj.post_all_reduce(x) + return x - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0: - # When token_idx is used, static seq len = (input token len + max output token len) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + +def gaudi_starcoder2_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask + + +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training) - - if past_key_value is not None: - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev else: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + return torch.cat((prev, cur), dim=dim) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * norm_factor + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" +class GaudiStarcoder2Attention(Starcoder2Attention): + def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.block_size = 4096 + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + + def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size): + """ + Gaudi version of Flash Attention V1 to support long sequence at prompt phase + Causal mask is not supported in this optimization + """ + q_len = query_layer.size(-2) + q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) + q_padding = q_tiles * q_block_size - q_len + query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) + if attention_mask is not None: + attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) + + row_o_list = [] + for i in range(q_tiles): + s, e = i * q_block_size, (i + 1) * q_block_size + row_q = query_layer[:, :, s:e, :] + row_mask = attention_mask[:, :, s:e, :] + attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None) + row_o_list.append(attn_output_partial) + attn_output = torch.cat(row_o_list, dim=-2) + + if q_padding != 0: + attn_output = attn_output[:, :, :-q_padding, :] + + return attn_output + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - attn_weights = attn_weights + attention_mask + bsz, q_len, _ = hidden_states.size() - query_length = q_len if past_key_value is None else q_len + past_key_value.key_cache[self.layer_idx].shape[2] - # Taken from mpt: https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/mpt/modeling_mpt.py - if use_flash_attention and FusedSDPA: - import habana_frameworks.torch.hpu as ht + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - if query_length == 1: - # next token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) - else: - # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - if query_length > 16384: - attn_output = self.gaudi_flash_attn_v1( - query_states, key_states, value_states, attention_mask, 0.0, self.block_size - ) - ht.mark_step() + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, self.training + ) + + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None + + if use_flash_attention and FusedSDPA: + import habana_frameworks.torch.hpu as ht + + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, 0.0, False, None ) - attn_weights = None - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + else: + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + if q_len > 16384: + attn_output = self.gaudi_flash_attn_v1( + query_states, key_states, value_states, attention_mask, 0.0, self.block_size + ) + htcore.mark_step() + else: + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + + else: + query_states, key_states, value_states, attention_mask = gaudi_starcoder2_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) + if not output_attentions: + attn_weights = None - if not output_attentions: - attn_weights = None + return attn_output, attn_weights, past_key_value - return attn_output, attn_weights, past_key_value + def attention_all_reduce(self, attn_output): + if hasattr(self.o_proj, "all_reduce"): + self.o_proj.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.o_proj, "post_all_reduce"): + self.o_proj.post_all_reduce(attn_output) + return attn_output class GaudiStarcoder2DecoderLayer(Starcoder2DecoderLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__(config, layer_idx) - self.self_attn = Starcoder2Attention(config, layer_idx) + self.hidden_size = config.hidden_size + self.self_attn = GaudiStarcoder2Attention(config, layer_idx) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) def forward( self, @@ -181,34 +405,31 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, - **kwargs, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from Starcoder2DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py - The only differences are: - - add new args token_idx - - add new args use_flash_attention - - add new arg flash_attention_recompute - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, + hidden_states, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, ) hidden_states = residual + hidden_states @@ -228,159 +449,257 @@ def forward( return outputs + def pre_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = self.input_layernorm(hidden_states) + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + cache_idx=cache_idx, + ) + return hidden_states, attn_weights, present_key_value -def gaudi_starcoder2_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = False, - flash_attention_recompute: Optional[bool] = False, -) -> Union[Tuple, BaseModelOutputWithPast]: - """ - Copied from Starcoder2Model.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py - The only differences are: - - add new args token_idx - - replace _prepare_4d_causal_attention_mask with _gaudi_prepare_4d_causal_attention_mask - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual - past_key_values_length = 0 + hidden_states = self.post_attention_layernorm(hidden_states) - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if token_idx is None: - past_key_values_length = past_key_values.get_usable_length(seq_length) + hidden_states = self.mlp.pre_mlp_forward(hidden_states) + return hidden_states, residual - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + return hidden_states - # 4d mask is passed through the layers - attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - hidden_states = inputs_embeds - hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) +class GaudiStarcoder2Model(Starcoder2Model): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - None, - use_flash_attention, - flash_attention_recompute, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - hidden_states = layer_outputs[0] + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + self._attn_implementation = "eager" - if output_attentions: - all_self_attns += (layer_outputs[1],) + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - hidden_states = self.norm(hidden_states) + use_new_cache = False # Ignoring new Cache path for HPU + past_seen_tokens = 0 - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + if past_key_values is not None and use_cache: # kept for BC (cache positions) + if reuse_cache: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_usable_length(seq_length) + else: + past_seen_tokens = past_key_values[0][0].shape[2] - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + # HPU specific mask generation + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if not use_new_cache else None + + if lazy_mode: + htcore.mark_step() + + for layer_idx, decoder_layer in enumerate(self.layers): + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + None, + attn_softmax_bf16, + False, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) class GaudiStarcoder2ForCausalLM(Starcoder2ForCausalLM): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + def forward( self, input_ids: torch.LongTensor = None, @@ -393,21 +712,28 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + trim_logits: Optional[bool] = False, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: - """ - Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py - The only differences are: - - add new args token_idx - """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not hasattr(self.config, "_attn_implementation"): + setattr(self.config, "_attn_implementation", "eager") + else: + self.config._attn_implementation = "eager" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -419,14 +745,26 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] + + logits = self.lm_head(hidden_states).float() loss = None if labels is not None: @@ -434,11 +772,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device + # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: @@ -454,20 +792,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs ): - """ - Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py - The only differences are: - - add new args token_idx - - add token_idx into model_inputs - - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx - """ - token_idx = kwargs.get("token_idx", None) - # Omit tokens covered by past_key_values + past_length = 0 + + reuse_cache = kwargs.get("reuse_cache") if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -495,8 +828,10 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -509,6 +844,8 @@ def prepare_inputs_for_generation( else: position_ids = position_ids[:, -input_ids.shape[1] :] + cache_position = None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -517,11 +854,20 @@ def prepare_inputs_for_generation( model_inputs.update( { - "position_ids": position_ids, + "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "trim_logits": kwargs.get("trim_logits"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 5966790d48..83e77460d7 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -955,9 +955,9 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - # attn_softmax_bf16 and use_flash_attention is enabled only for llama and qwen2 + # attn_softmax_bf16 and use_flash_attention is enabled only for llama, qwen2 and starcoder2 if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type in ["llama", "qwen2"]: + if self.model.config.model_type in ["llama", "qwen2", "starcoder2"]: if self.model.generation_config.attn_softmax_bf16: inputs["attn_softmax_bf16"] = True if self.model.generation_config.use_flash_attention: @@ -1835,9 +1835,9 @@ def evaluation_loop( if batch_size is None: batch_size = observed_batch_size - # attn_softmax_bf16 and use_flash_attention are enabled only for llama and qwen2 + # attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2 and starcoder2 if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type in ["llama", "qwen2"]: + if self.model.config.model_type in ["llama", "qwen2", "starcoder2"]: if self.model.generation_config.attn_softmax_bf16: inputs["attn_softmax_bf16"] = True if self.model.generation_config.use_flash_attention: