diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 87d24c6cf66351..2ed663b26256ed 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -4,6 +4,10 @@
 import torch
 
 from .configuration_utils import PretrainedConfig
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
 
 
 @dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
             return max_length - new_seq_length
         return previous_seq_length
 
+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
 
 class DynamicCache(Cache):
     """
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
     def __init__(self) -> None:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         """
@@ -121,7 +136,7 @@ def update(
         """
         # Update the number of seen tokens
        if layer_idx == 0:
-            self.seen_tokens += key_states.shape[-2]
+            self._seen_tokens += key_states.shape[-2]
 
         # Update the cache
         if len(self.key_cache) <= layer_idx:
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
         self.cos_sin_cache = {}
-        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
     def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(
 
         # Update the number of seen tokens
         if layer_idx == 0:
-            self.seen_tokens += key_states.shape[-2]
+            self._seen_tokens += key_states.shape[-2]
 
         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
@@ -398,16 +413,11 @@ def update(
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-        raise ValueError(
-            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-        )
-
-    def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
-        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-        raise ValueError(
-            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-        )
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+        # https://github.com/pytorch/pytorch/issues/120248 is fixed
+        return (self.key_cache[0, 0].any(dim=-1)).sum()
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
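
The cache_utils.py changes above make the counter private (`_seen_tokens`), keep `seen_tokens` only as a deprecation shim, and let `StaticCache.get_seq_length` infer occupancy from the cache tensor itself. A minimal sketch of the resulting `DynamicCache` behaviour, with made-up tensor shapes:

```python
import torch

from transformers.cache_utils import DynamicCache

cache = DynamicCache()

# Pretend layer 0 just processed a 5-token prompt: [batch, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 4, 5, 8)
value_states = torch.randn(1, 4, 5, 8)
cache.update(key_states, value_states, layer_idx=0)

print(cache.get_seq_length())  # 5
print(cache.seen_tokens)       # 5, but now emits a one-time deprecation warning

# Preferred replacement: track positions explicitly and feed them to the model through
# the new `cache_position` input instead of reading the cache's private counter.
cache_position = torch.arange(5)
```
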
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 68f1b45ae0a0df..a0c58749c33743 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -633,7 +633,6 @@ def _update_model_kwargs_for_generation(
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
-        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )
 
-        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
 
         return model_kwargs
 
@@ -1931,10 +1931,15 @@ def _contrastive_search(
         )
 
         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         this_peer_finished = False  # used by synced_gpus only
-        batch_size = input_ids.shape[0]
 
         while True:
             if synced_gpus:
@@ -1975,7 +1980,6 @@ def _contrastive_search(
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 standardize_cache_format=True,
-                model_inputs=model_inputs,
             )
             if not sequential:
                 # Expands model inputs top_k times, for batched forward passes (akin to beam search).
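
The two `generate`-loop hunks above seed `cache_position` with the prompt length and then let `_update_model_kwargs_for_generation` advance it by one slot per step. A standalone sketch of that bookkeeping, using placeholder tensors:

```python
import torch

# Placeholder tensors standing in for the real decoding-loop state.
input_ids = torch.tensor([[5, 8, 13, 2]])
attention_mask = torch.ones_like(input_ids)

# Prefill: one position per prompt token, taken from the attention mask when available.
batch_size, cur_len = attention_mask.shape if attention_mask is not None else input_ids.shape
cache_position = torch.arange(cur_len)
print(cache_position)  # tensor([0, 1, 2, 3])

# Decoding: each step keeps only the next position, mirroring `cache_position[-1:] + 1`.
for _ in range(3):
    cache_position = cache_position[-1:] + 1
    print(cache_position)  # tensor([4]), then tensor([5]), then tensor([6])
```
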
@@ -2170,7 +2174,9 @@ def _contrastive_search(
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -2389,7 +2395,13 @@ def _greedy_search(
         )
 
         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         this_peer_finished = False  # used by synced_gpus only
         while True:
@@ -2459,7 +2471,6 @@ def _greedy_search(
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
-                model_inputs=model_inputs,
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -2688,7 +2699,13 @@ def _sample(
         )
 
         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
@@ -2758,7 +2775,9 @@ def _sample(
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -3003,6 +3022,7 @@ def _beam_search(
         num_beams = beam_scorer.num_beams
 
         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -3156,7 +3176,9 @@ def _beam_search(
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3397,6 +3419,7 @@ def _beam_sample(
         num_beams = beam_scorer.num_beams
 
         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
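
In the beam-based methods above, `cur_len` is read from the already-expanded `input_ids`, so `cache_position` stays a single 1-D index tensor shared by every beam (all beams sit at the same step). A small illustration with invented sizes:

```python
import torch

batch_size, num_beams, cur_len = 2, 3, 4  # invented sizes
input_ids = torch.zeros(batch_size * num_beams, cur_len, dtype=torch.long)

batch_beam_size, cur_len = input_ids.shape
cache_position = torch.arange(cur_len)  # shared by all beams; not expanded per beam
print(batch_beam_size, cache_position)  # 6 tensor([0, 1, 2, 3])
```
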
@@ -3510,7 +3533,9 @@ def _beam_sample(
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3747,6 +3772,7 @@ def _group_beam_search(
         device = input_ids.device
 
         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         if return_dict_in_generate and output_scores:
             beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -3916,7 +3942,9 @@ def _group_beam_search(
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4155,6 +4183,7 @@ def _constrained_beam_search(
         num_beams = constrained_beam_scorer.num_beams
 
         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -4275,7 +4304,9 @@ def _constrained_beam_search(
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4511,7 +4542,13 @@ def _assisted_decoding(
         )
 
         # keep track of which sequences are already finished
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
@@ -4555,6 +4592,14 @@ def _assisted_decoding(
                 candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
             )
             candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+            if "cache_position" in candidate_kwargs:
+                candidate_kwargs["cache_position"] = torch.cat(
+                    (
+                        candidate_kwargs["cache_position"],
+                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
+                    ),
+                    dim=0,
+                )
 
             model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
 
@@ -4673,7 +4718,9 @@ def _assisted_decoding(
             )
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
            )
 
             # if eos_token was found in one sentence, set sentence to finished
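
The assisted-decoding hunk extends `cache_position` by `candidate_length` so the target model can score all of the assistant's draft tokens in one forward pass. A sketch with invented values:

```python
import torch

cur_len = 7           # tokens already accepted into the target model's cache
candidate_length = 4  # draft tokens proposed by the assistant model
cache_position = torch.tensor([cur_len - 1])  # value carried over from the previous step

cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)  # tensor([ 6,  7,  8,  9, 10])
```
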
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 7156fe8e5a9f45..90b63a5ec81429 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -256,7 +256,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -343,7 +343,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -542,7 +542,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -791,6 +791,10 @@ def _reset_cache(self):
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
 """
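
The docstring added above distinguishes `cache_position` from `position_ids`: the former indexes physical cache slots and ignores padding. A small illustration (not code from this patch) for a left-padded prompt:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two pad tokens, three real tokens

# `position_ids`, as built in `prepare_inputs_for_generation`, is padding-aware.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1, 2]])

# `cache_position` just enumerates cache slots, regardless of padding.
cache_position = torch.arange(attention_mask.shape[1])
print(cache_position)  # tensor([0, 1, 2, 3, 4])
```
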
""" @@ -1128,14 +1132,26 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1168,22 +1184,6 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if self.generation_config.cache_implementation == "static": - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - position_ids = position_ids.contiguous() if position_ids is not None else None - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1193,6 +1193,15 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
@@ -1193,6 +1193,15 @@
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        else:
+            cache_position = cache_position[-input_length:]
+
+        if has_static_cache:
+            past_key_values = None
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index eed75b3522b0a9..0023fd2014940d 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1557,10 +1557,12 @@ def _update_model_kwargs_for_generation(
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
-        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         model_kwargs = super()._update_model_kwargs_for_generation(
-            outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
+            outputs,
+            model_kwargs,
+            is_encoder_decoder,
+            standardize_cache_format,
         )
 
         if "image_attention_mask" in model_kwargs:
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5e2a9c2e5be7d2..7b9cca330e99ae 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -361,7 +361,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -451,7 +451,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -650,7 +650,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -903,6 +903,10 @@ def _reset_cache(self):
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
 """
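
The attention hunks forward `cache_position` to `Cache.update` because a pre-allocated cache can use it to write the new key/value states into fixed slots, keeping shapes static for `torch.compile`. The following is only an illustration of that idea, not the actual `StaticCache` implementation:

```python
import torch

max_cache_len, num_heads, head_dim = 16, 4, 8
k_out = torch.zeros(1, num_heads, max_cache_len, head_dim)  # pre-allocated key buffer

key_states = torch.randn(1, num_heads, 1, head_dim)  # one new token
cache_position = torch.tensor([5])                   # target slot, supplied by `generate`

k_out[:, :, cache_position] = key_states             # in-place write at a fixed index
print(bool(k_out[0, 0, 5].abs().sum() > 0))          # True
```
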
""" @@ -1240,14 +1244,26 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1280,22 +1296,6 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if self.generation_config.cache_implementation == "static": - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - position_ids = position_ids.contiguous() if position_ids is not None else None - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1305,6 +1305,15 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + else: + cache_position = cache_position[-input_length:] + + if has_static_cache: + past_key_values = None + model_inputs.update( { "position_ids": position_ids,