diff --git a/src/transformers/models/cohere/configuration_cohere.py b/src/transformers/models/cohere/configuration_cohere.py index 73973bfad60b93..3c1237e5113789 100644 --- a/src/transformers/models/cohere/configuration_cohere.py +++ b/src/transformers/models/cohere/configuration_cohere.py @@ -20,6 +20,7 @@ """Cohere model configuration""" from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation from ...utils import logging @@ -79,6 +80,43 @@ class CohereConfig(PretrainedConfig): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type + and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to the value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (> + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. 
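As a quick, concrete reading of the `rope_scaling` contract documented above, here is a minimal sketch of how a user could configure it once this change lands. The sizes and the choice of the 'linear' rope type are illustrative only, not taken from this PR:

```python
# Minimal sketch, assuming a transformers build that includes this change.
from transformers import CohereConfig

config = CohereConfig(
    max_position_embeddings=16384,  # the longer context the scaled RoPE should cover (default is 8192)
    rope_theta=10000.0,
    rope_scaling={
        "rope_type": "linear",  # 'default', 'linear', 'dynamic', 'yarn', 'longrope' or 'llama3'
        "factor": 2.0,          # roughly 2x the original pre-trained context length
    },
)

# The new `rope_config_validation(self)` call in `__init__` checks the dict at construction
# time, so a missing or malformed field for the chosen rope_type is reported immediately
# instead of surfacing later inside the rotary embedding.
```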
attention_dropout (`float`, *optional*, defaults to 0.0): @@ -121,6 +159,7 @@ def __init__( eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, + rope_scaling=None, attention_bias=False, attention_dropout=0.0, use_qk_norm=False, @@ -144,10 +183,14 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta + self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.use_qk_norm = use_qk_norm + # Validate the correctness of rotary position embeddings parameters + rope_config_validation(self) + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 4010d9ec3a4327..ae84a9ec2d1a43 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -37,6 +37,7 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -135,35 +136,97 @@ def forward(self, hidden_states): class CohereRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + # Note: the forward pass of this RoPE is slightly different from Llama's, resulting in different `sin`/`cos` for + # the same parameterization. The differences are highlighted with a comment. + + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[CohereConfig] = None, + ): super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`CohereRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.repeat_interleave(freqs, 2, dim=-1) + emb = torch.repeat_interleave(freqs, 2, dim=-1) # This line differs from Llama's implementation cos = emb.cos() sin = emb.sin() - return cos, sin + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): - # Split and rotate + # Split and rotate. Note that this function is different from e.g. Llama. 
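To make the `# This line differs from Llama's implementation` comment above concrete, here is a small standalone toy (not code from this PR) contrasting how the duplicated frequencies are laid out in the two models:

```python
# Standalone toy: Cohere interleaves the duplicated frequencies, Llama concatenates two half-blocks.
import torch

inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))  # 4 frequencies for head_dim = 8
positions = torch.arange(3).float()
freqs = torch.outer(positions, inv_freq)                          # (3, 4)

emb_cohere = torch.repeat_interleave(freqs, 2, dim=-1)  # f0, f0, f1, f1, f2, f2, f3, f3
emb_llama = torch.cat((freqs, freqs), dim=-1)           # f0, f1, f2, f3, f0, f1, f2, f3

# Same frequencies, different ordering of the cos/sin entries, which is why Cohere also needs
# the interleaved rotate_half below rather than Llama's split-half version.
print(emb_cohere[1])
print(emb_llama[1])
```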
x1 = x[..., ::2] x2 = x[..., 1::2] rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) @@ -272,17 +335,10 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - self._init_rope() - # Ignore copy - def _init_rope(self): - self.rotary_emb = CohereRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = CohereRotaryEmbedding(config=self.config) - # Ignore copy def forward( self, hidden_states: torch.Tensor, @@ -292,6 +348,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -310,7 +367,16 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -350,8 +416,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere -# TODO(joao): add me back asap :) +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere class CohereFlashAttention2(CohereAttention): """ Cohere flash attention module. 
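And the matching difference on the rotation side: a standalone sketch (illustrative, not part of the diff) of the two `rotate_half` conventions, the interleaved one used by Cohere above versus Llama's split-half variant:

```python
# Standalone sketch of the two rotate_half conventions.
import torch

def rotate_half_cohere(x):
    # interleaved pairs: (x0, x1, x2, x3) -> (-x1, x0, -x3, x2)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).flatten(-2)

def rotate_half_llama(x):
    # split halves: (x0, x1, x2, x3) -> (-x2, -x3, x0, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

x = torch.arange(4.0)
print(rotate_half_cohere(x))  # tensor([-1.,  0., -3.,  2.])
print(rotate_half_llama(x))   # tensor([-2., -3.,  0.,  1.])
```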
This module inherits from `CohereAttention` as the weights of the module stays @@ -377,6 +442,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -402,7 +468,16 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -418,7 +493,6 @@ def forward( dropout_rate = self.attention_dropout if self.training else 0.0 - # Ignore copy # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. @@ -465,8 +539,6 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention Llama->Cohere -# TODO(joao): add me back asap :) class CohereSdpaAttention(CohereAttention): """ Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -474,7 +546,6 @@ class CohereSdpaAttention(CohereAttention): SDPA API. """ - # Ignore copy def forward( self, hidden_states: torch.Tensor, @@ -484,6 +555,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -517,7 +589,16 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
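All of the deprecation warnings added to the attention classes point at the same new calling convention: the model computes cos/sin once with its shared rotary module and hands the tuple to every decoder layer as `position_embeddings`. A rough sketch of that flow, using internal modules directly with tiny illustrative sizes (an assumption for demonstration, not code from this PR):

```python
# Hedged sketch of the new position_embeddings flow; sizes are illustrative.
import torch
from transformers import CohereConfig
from transformers.models.cohere.modeling_cohere import CohereModel

config = CohereConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = CohereModel(config)

hidden_states = model.embed_tokens(torch.randint(0, 128, (1, 10)))
position_ids = torch.arange(10).unsqueeze(0)

# cos/sin are computed once by the model-level rotary module...
cos, sin = model.rotary_emb(hidden_states, position_ids)

# ...and each decoder layer receives the precomputed tuple instead of rebuilding it from
# position_ids, which is what avoids the transition warning above.
out = model.layers[0](hidden_states, position_ids=position_ids, position_embeddings=(cos, sin))
```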
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -587,6 +668,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -601,6 +683,11 @@ def forward( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. """ residual = hidden_states @@ -615,6 +702,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) # Fully Connected @@ -755,8 +843,7 @@ def _init_weights(self, module): "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, ) -# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere -# TODO(joao): add me back asap :) +# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE class CohereModel(CoherePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`] @@ -776,6 +863,7 @@ def __init__(self, config: CohereConfig): [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + self.rotary_emb = CohereRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -787,14 +875,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Ignore copy @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -823,30 +910,33 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 return_legacy_cache = False if ( use_cache and not isinstance(past_key_values, Cache) and not self.training ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - - # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -866,6 +956,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -876,6 +967,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 8db9f6e8b7d09f..43bac44ba1be20 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1066,7 +1066,7 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 085751cd9bc039..b14e0a4b3d8ca5 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -862,7 +862,7 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index ff10b6e6d875f9..876f5ed2a7c8da 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -839,7 +839,7 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. 
" "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 435f0091e06e70..a3667e06534564 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -192,7 +192,7 @@ def __init__( self.mlp_bias = mlp_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. + # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9a1d6c0749f932..c7017832b9324c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -951,7 +951,7 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c43418182c3881..ffe16b27203301 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -767,7 +767,7 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 2e23d06699087e..c7062e75b1085c 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1023,7 +1023,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. 
" "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 007e69570e7821..b4bda8e2db5251 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -873,7 +873,7 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index a53f1eeda61196..a33338365312db 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -1012,7 +1012,7 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 9fab09bdcc7877..ccaa2c7fd29aae 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -690,7 +690,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0d8be04af20d5c..648d1653a3b503 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -981,7 +981,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. 
" "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 273b6a8f505e79..ec395679ae6207 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1008,7 +1008,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 030c74b034b794..d0ea8ef0e376e0 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -920,7 +920,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index b196ed72a49b23..6f483e50cde065 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1084,7 +1084,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 27d0c856a61bd6..d91c0832ed33da 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -965,7 +965,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. 
" "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" ) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index c359c07c69c0b8..0be37c4e1fb91c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -894,7 +894,7 @@ def forward( use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" )