From c518ee382e51fa94b117a01b7d2964d7dfbb3cfa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 9 Feb 2024 07:04:16 +0100 Subject: [PATCH 01/30] wow I was scared! --- src/transformers/models/llama/modeling_llama.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c30be2a2da4f..e88513ba9b54 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -589,6 +589,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, @@ -683,6 +684,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -717,6 +719,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + cache_position=cache_position, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -918,6 +921,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -981,6 +985,7 @@ def forward( hidden_states, causal_mask, position_ids, + cache_position, past_key_values, output_attentions, use_cache, @@ -990,6 +995,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, + cache_position=cache_position, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -1104,6 +1110,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1149,6 +1156,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + cache_position=cache_position, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, From 0b20058afdf2ce58f75073a649e9516e6882ce87 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 9 Feb 2024 11:58:58 +0100 Subject: [PATCH 02/30] fix everything --- src/transformers/cache_utils.py | 2 +- src/transformers/models/llama/modeling_llama.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index abdc3c7c0707..ab447bd5dfcb 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -396,7 +396,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. 
`layer_idx` kept for BC""" - return self.seen_tokens + return 0 def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: return self.seen_tokens diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e88513ba9b54..7f8324612f80 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -300,6 +300,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, @@ -432,6 +433,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) @@ -995,8 +999,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, - cache_position=cache_position, position_ids=position_ids, + cache_position=cache_position, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, From 1c5b2c03ed5dcd0cf624ab40a898a3d6baa2419d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 9 Feb 2024 12:33:20 +0100 Subject: [PATCH 03/30] nits --- src/transformers/models/llama/modeling_llama.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7f8324612f80..d56860c907d2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -433,9 +433,6 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) From 1acc62f08b81b3d1e36fa06ff211dcf625a16e32 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 12 Feb 2024 02:49:34 +0100 Subject: [PATCH 04/30] make it BC? 
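
The BC concern here is the legacy `tuple(tuple(torch.FloatTensor))` cache format: the guard narrows `to_legacy_cache()` to the dynamic-style caches (`DynamicCache`, `SinkCache`) so a `StaticCache` is not flattened. A minimal sketch of that round trip, assuming the `DynamicCache.from_legacy_cache` / `to_legacy_cache` helpers that appear later in this series, with toy tensor shapes and two layers:

    import torch
    from transformers import DynamicCache

    batch, heads, seq, dim = 1, 4, 3, 8
    # legacy format: one (key, value) tuple per layer
    legacy = tuple(
        (torch.randn(batch, heads, seq, dim), torch.randn(batch, heads, seq, dim))
        for _ in range(2)
    )

    cache = DynamicCache.from_legacy_cache(legacy)  # tuples -> Cache object
    roundtrip = cache.to_legacy_cache()             # Cache object -> tuples again

    assert len(roundtrip) == len(legacy)
    assert torch.equal(roundtrip[0][0], legacy[0][0])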
--- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d56860c907d2..a6fb98ad4c88 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1019,7 +1019,7 @@ def forward( all_hidden_states += (hidden_states,) next_cache = None - if use_cache: + if use_cache and isinstance(next_decoder_cache, (DynamicCache, SinkCache)): next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) From bd93ac7379bdffa27f7e219f6c1f3fc6479602d3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 12 Feb 2024 03:49:01 +0100 Subject: [PATCH 05/30] nits --- src/transformers/models/llama/modeling_llama.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a6fb98ad4c88..92a5192fdee7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -402,6 +402,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, @@ -428,6 +429,8 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} @@ -1019,7 +1022,8 @@ def forward( all_hidden_states += (hidden_states,) next_cache = None - if use_cache and isinstance(next_decoder_cache, (DynamicCache, SinkCache)): + # if use_cache and isinstance(next_decoder_cache, (DynamicCache, SinkCache)): + if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) From 9ef722ff8feec56bcb56f188563824efb5ad9f7f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 12 Feb 2024 14:18:04 +0900 Subject: [PATCH 06/30] is_tracing should still be used to pass tracing tests --- src/transformers/models/llama/modeling_llama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 92a5192fdee7..4a4504318940 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1256,9 +1256,13 @@ def prepare_inputs_for_generation( # same goes for position ids. Could also help with continued generation. 
cache_position = kwargs.get("cache_position", None) if cache_position is None: +<<<<<<< HEAD cache_position = torch.arange( past_length, past_length + position_ids.shape[-1], device=position_ids.device ) +======= + cache_position = torch.arange(past_length, past_length + input_ids.shape[1]) +>>>>>>> 651c4bd80 (is_tracing should still be used to pass tracing tests) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: From 5078b37afe6f7776cbb64df1d51e1784140f57d5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 12 Feb 2024 15:55:16 +0900 Subject: [PATCH 07/30] nits --- src/transformers/models/llama/modeling_llama.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4a4504318940..92a5192fdee7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1256,13 +1256,9 @@ def prepare_inputs_for_generation( # same goes for position ids. Could also help with continued generation. cache_position = kwargs.get("cache_position", None) if cache_position is None: -<<<<<<< HEAD cache_position = torch.arange( past_length, past_length + position_ids.shape[-1], device=position_ids.device ) -======= - cache_position = torch.arange(past_length, past_length + input_ids.shape[1]) ->>>>>>> 651c4bd80 (is_tracing should still be used to pass tracing tests) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: From 10cc68f3d68c6a8ced4d349475bdb427e751eefb Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 12 Feb 2024 09:50:46 +0100 Subject: [PATCH 08/30] some nits to make sure genration works with static cache uncompiled --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ab447bd5dfcb..abdc3c7c0707 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -396,7 +396,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" - return 0 + return self.seen_tokens def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: return self.seen_tokens From f83592eff8f8ce802f9953fcb8e5fa61c9706bcc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 10:00:02 +0900 Subject: [PATCH 09/30] fix FA2 for both static and dynamic in a better way? 
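
The change below drops the attention mask before calling flash attention when it contains no padding (no `0.0` entries) and the cached key length does not exceed the query length. A toy illustration of that condition, with a hypothetical `maybe_drop_mask` helper standing in for the inline guard (not the real code path):

    import torch

    def maybe_drop_mask(attention_mask, key_states, q_len):
        # mirrors the guard added below: all-ones mask + no extra cached keys -> no mask needed
        if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= q_len:
            return None
        return attention_mask

    key_states = torch.zeros(1, 4, 3, 8)            # [batch, kv_heads, kv_len, head_dim]
    no_padding = torch.ones(1, 3)                   # every position is a real token
    with_padding = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding

    print(maybe_drop_mask(no_padding, key_states, q_len=3))    # None: FA2 can skip the mask
    print(maybe_drop_mask(with_padding, key_states, q_len=3))  # original mask is kept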
--- src/transformers/models/llama/modeling_llama.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 92a5192fdee7..200b0c6b6601 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -470,6 +470,9 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) + if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= q_len: + attention_mask = None + attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) From 87631c81bde6683cfdf848a11a49e26330af88fb Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 10:21:47 +0900 Subject: [PATCH 10/30] fix sequential beam searcg --- src/transformers/generation/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dd8fa604d63e..b508a5c18220 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4776,9 +4776,8 @@ def _split_model_inputs( # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k=="cache_position"] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"] # we split the tensors and tuples of tensors data_split_list = [ From c3f3c0bad90c47f86067f8595350c84258682384 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 10:22:00 +0900 Subject: [PATCH 11/30] style --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b508a5c18220..8677fb09eab9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4776,7 +4776,7 @@ def _split_model_inputs( # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. 
# bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k=="cache_position"] + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"] # we split the tensors and tuples of tensors From 561fa32e94355eb6483f67475e47ab331125a7d3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 10:25:35 +0900 Subject: [PATCH 12/30] use `keys_to_ignore` --- src/transformers/generation/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8677fb09eab9..dd8fa604d63e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4777,7 +4777,8 @@ def _split_model_inputs( # ModelOutput object. # bool should not be split but replicated for each split bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"] + keys_to_ignore = ["cache_position", "encoder_outputs"] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] # we split the tensors and tuples of tensors data_split_list = [ From ed11a7548d1e334281857e3757b245ae68ee9076 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 10:53:54 +0900 Subject: [PATCH 13/30] nit --- src/transformers/models/llama/modeling_llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 200b0c6b6601..d5a2b493b40c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -436,6 +436,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= q_len: + attention_mask = None + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. 
query_states = query_states.transpose(1, 2) @@ -470,8 +473,6 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= q_len: - attention_mask = None attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate From d623190f63f9370a778583997981a9b0628002b0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 11:06:14 +0900 Subject: [PATCH 14/30] correct dtype inference when init --- src/transformers/cache_utils.py | 11 +++++++++++ src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index abdc3c7c0707..70e54da91324 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -344,15 +344,26 @@ class StaticCache(Cache): The default `dtype` to use when initializing the layer. """ +<<<<<<< HEAD def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: +======= + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None + ) -> None: +>>>>>>> 9aa667e03 (correct dtype inference when init) super().__init__() self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.head_dim = config.hidden_size // config.num_attention_heads +<<<<<<< HEAD self.dtype = dtype if dtype is not None else torch.float32 self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) +======= + self.num_heads = config.num_attention_heads + self.dtype = dtype if dtype is not None else torch.float32 +>>>>>>> 9aa667e03 (correct dtype inference when init) cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d5a2b493b40c..a96cde013453 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -436,7 +436,7 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= q_len: + if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= query_states.shape[2]: attention_mask = None # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache From c51cc75bc910ac44ffd93f97dc9126521bc79e9c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 11:11:00 +0900 Subject: [PATCH 15/30] :( the fix for FA2 is still not optimal to investigate! 
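
This reverts the `StaticCache.__init__` changes (including the stray conflict markers) back to the version where `dtype` falls back to `torch.float32` only when none is passed, and drops the extra FA2 mask slicing again. A minimal sketch of constructing that cache, using the constructor signature shown in the `cache_utils.py` diff in this series and a small placeholder `LlamaConfig` (the exact config values only matter for the shapes printed):

    import torch
    from transformers import LlamaConfig, StaticCache

    config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
    cache = StaticCache(config, max_batch_size=1, max_cache_len=256, device="cpu", dtype=torch.float16)

    print(cache.key_cache.shape)  # torch.Size([1, 4, 256, 16]) -> (batch, kv_heads, max_cache_len, head_dim)
    print(cache.key_cache.dtype)  # torch.float16; would be torch.float32 if dtype were omitted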
--- src/transformers/cache_utils.py | 11 ----------- src/transformers/models/llama/modeling_llama.py | 2 -- 2 files changed, 13 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 70e54da91324..abdc3c7c0707 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -344,26 +344,15 @@ class StaticCache(Cache): The default `dtype` to use when initializing the layer. """ -<<<<<<< HEAD def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: -======= - def __init__( - self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None - ) -> None: ->>>>>>> 9aa667e03 (correct dtype inference when init) super().__init__() self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.head_dim = config.hidden_size // config.num_attention_heads -<<<<<<< HEAD self.dtype = dtype if dtype is not None else torch.float32 self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) -======= - self.num_heads = config.num_attention_heads - self.dtype = dtype if dtype is not None else torch.float32 ->>>>>>> 9aa667e03 (correct dtype inference when init) cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a96cde013453..49e0213e702b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -436,8 +436,6 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if attention_mask is not None and 0.0 not in attention_mask and key_states.shape[2] <= query_states.shape[2]: - attention_mask = None # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. From 8d9e9f46c17928c389f48bd159b98059a62e1804 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 11:26:31 +0900 Subject: [PATCH 16/30] styling --- src/transformers/models/llama/modeling_llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 49e0213e702b..92a5192fdee7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -436,7 +436,6 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. 
query_states = query_states.transpose(1, 2) @@ -471,7 +470,6 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) From 417669460754616c7a7918d19819e7695aef4a7e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 12:01:48 +0900 Subject: [PATCH 17/30] nits --- src/transformers/models/llama/modeling_llama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 92a5192fdee7..7048042fdd29 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -435,6 +435,8 @@ def forward( # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if cache_position is not None: + key_states, value_states = key_states[:, :, :cache_position[-1]+1, :], value_states[:, :, :cache_position[-1]+1, :] # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -1038,9 +1040,11 @@ def forward( def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None + # since the static cache is padded, you have to pass the attention mask raw. + # similar to https://github.com/facebookresearch/llama/commit/e9077bd24177a74aa79f406bef7d4b57fe393157 + if attention_mask is not None and 0.0 not in attention_mask: + return None + return attention_mask batch_size, seq_length = input_tensor.shape[:2] dtype = input_tensor.dtype From 1936cf8d1e04f8db8f5e37bc2326b3370fde65be Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 12:04:04 +0900 Subject: [PATCH 18/30] nit --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7048042fdd29..12038fd8c70d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1042,7 +1042,7 @@ def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": # since the static cache is padded, you have to pass the attention mask raw. 
# similar to https://github.com/facebookresearch/llama/commit/e9077bd24177a74aa79f406bef7d4b57fe393157 - if attention_mask is not None and 0.0 not in attention_mask: + if input_tensor.shape[1] == 1: return None return attention_mask From c476ad3230c5ee5994ec99a055d860ea4c6de824 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 12:08:30 +0900 Subject: [PATCH 19/30] this might work better --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 12038fd8c70d..7048042fdd29 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1042,7 +1042,7 @@ def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": # since the static cache is padded, you have to pass the attention mask raw. # similar to https://github.com/facebookresearch/llama/commit/e9077bd24177a74aa79f406bef7d4b57fe393157 - if input_tensor.shape[1] == 1: + if attention_mask is not None and 0.0 not in attention_mask: return None return attention_mask From cfbcf6ab216aeb58a825515fa91b62c2ebaf2aec Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 12:10:10 +0900 Subject: [PATCH 20/30] add comment --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7048042fdd29..dcfe085a85cf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -435,7 +435,7 @@ def forward( # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cache_position is not None: + if cache_position is not None: # we slice for static kv cache to be supported in FA2. Not sure it's a must as compile fails key_states, value_states = key_states[:, :, :cache_position[-1]+1, :], value_states[:, :, :cache_position[-1]+1, :] # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache From 3a2a785e3cc9ffe68db8436aefa5d21babeb2750 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 14 Feb 2024 04:17:50 +0100 Subject: [PATCH 21/30] Update src/transformers/models/llama/modeling_llama.py --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dcfe085a85cf..ef97718984da 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1042,7 +1042,7 @@ def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": # since the static cache is padded, you have to pass the attention mask raw. 
# similar to https://github.com/facebookresearch/llama/commit/e9077bd24177a74aa79f406bef7d4b57fe393157 - if attention_mask is not None and 0.0 not in attention_mask: + if attention_mask is not None and 0.0 in attention_mask: return None return attention_mask From ed6c60dcdbee456ec101ed36fe067ee124950a80 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 13 Feb 2024 18:25:55 +0000 Subject: [PATCH 22/30] tmp commit --- src/transformers/__init__.py | 4 +- src/transformers/cache_utils.py | 220 ++++++++++-------- src/transformers/generation/utils.py | 22 +- .../models/llama/modeling_llama.py | 138 ++++++----- tests/test_cache_utils.py | 51 ++-- 5 files changed, 254 insertions(+), 181 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 84a664580227..b3671f16c20d 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1338,7 +1338,7 @@ _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] - _import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"] + _import_structure["cache_utils"] = ["Cache", "DynamicCache", "ModelCache", "SinkCache", "StaticCache"] _import_structure["data.datasets"] = [ "GlueDataset", "GlueDataTrainingArguments", @@ -6086,7 +6086,7 @@ # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments - from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache + from .cache_utils import Cache, DynamicCache, ModelCache, SinkCache, StaticCache from .data.datasets import ( GlueDataset, GlueDataTrainingArguments, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index abdc3c7c0707..e312a69eb3b0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -6,6 +6,90 @@ from .configuration_utils import PretrainedConfig +class ModelCache(): + """ + A standalone class that holds multiple `Cache` instances, behaving exactly like the legacy cache format. Designed + mostly for backwards compatibility purposes, it is used to set up the cache for models, or as an output type for + the `forward` method of models that use caches. + + Parameters: + caches (`List[Cache]`): + A list of `Cache` instances. + """ + + def __init__(self, caches: List["Cache"]) -> None: + self.caches = caches + self.cache_cls = caches[0].__class__ + if not all(isinstance(cache, self.cache_cls) for cache in caches): + raise ValueError(f"All caches in `caches` must be of the same type. Got: {[cache.__class__ for cache in caches]}") + + def __len__(self) -> int: + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.caches) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.caches[layer_idx].key_cache, self.caches[layer_idx].value_cache) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.caches[layer_idx].key_cache, self.caches[layer_idx].value_cache) + + @property + def seen_tokens(self) -> int: + """Returns the total number of tokens seen by the cache.""" + return self.caches[0].seen_tokens + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.caches[0].get_seq_length() + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.caches[0].get_max_length() + + def get_usable_length(self, new_seq_length: int) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + return self.caches[0].get_usable_length(new_seq_length) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for cache in self.caches: + cache.reorder_cache(beam_idx) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `ModelCache` instance into the its equivalent in the legacy cache format.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.caches[layer_idx].key_cache, self.caches[layer_idx].value_cache),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `ModelCache`.""" + caches = [] + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + layer_cache = DynamicCache() + layer_cache.update(key_states, value_states) + caches.append(layer_cache) + return cls(caches) + + @dataclass class Cache: """ @@ -16,19 +100,16 @@ def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states` for a given layer. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. These are specific to each subclass and allow new types of cache to be created. @@ -38,81 +119,55 @@ def update( """ raise NotImplementedError("Make sure to implement `update` in a subclass.") - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states, if there is any.""" raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + def get_usable_length(self, new_seq_length: int) -> int: """Given the sequence length of the new inputs, returns the usable length of the cache.""" # Cache without size limit -> all cache is usable # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache # length, we will need to evict part of the cache (and thus not all cache is usable) max_length = self.get_max_length() - previous_seq_length = self.get_seq_length(layer_idx) + previous_seq_length = self.get_seq_length() if max_length is not None and previous_seq_length + new_seq_length > max_length: return max_length - new_seq_length return previous_seq_length + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + raise NotImplementedError("Make sure to implement `reorder_cache` in a subclass.") + class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. + A cache that grows dynamically as more tokens are generated. This is the default for generative models. The + expected shape for each cached tensor is `[batch_size, num_heads, seq_len, head_dim]`. """ - def __init__(self) -> None: - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] + def __init__(self, **unused_kwargs) -> None: + self.key_cache: Optional[torch.Tensor] = None + self.value_cache: Optional[torch.Tensor] = None self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen - def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states` for a given layer. 
Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. @@ -120,24 +175,23 @@ def update( A tuple containing the updated key and value states. """ # Update the number of seen tokens - if layer_idx == 0: - self.seen_tokens += key_states.shape[-2] + self.seen_tokens += key_states.shape[-2] # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) + if self.key_cache is None: + self.key_cache = key_states + self.value_cache = value_states else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache= torch.cat([self.key_cache, key_states], dim=-2) + self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) - return self.key_cache[layer_idx], self.value_cache[layer_idx] + return self.key_cache, self.value_cache - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: + if self.key_cache is None: return 0 - return self.key_cache[layer_idx].shape[-2] + return self.key_cache.shape[-2] def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" @@ -145,28 +199,10 @@ def get_max_length(self) -> Optional[int]: def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache + device = self.key_cache.device + self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) + device = self.value_cache.device + self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) class SinkCache(Cache): @@ -185,7 +221,7 @@ class SinkCache(Cache): The number of sink tokens. See the original paper for more information. 
""" - def __init__(self, window_length: int, num_sink_tokens: int) -> None: + def __init__(self, window_length: int, num_sink_tokens: int, **unused_kwargs) -> None: self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] self.window_length = window_length @@ -344,15 +380,18 @@ class StaticCache(Cache): The default `dtype` to use when initializing the layer. """ - def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32, **unused_kwargs + ) -> None: super().__init__() self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.head_dim = config.hidden_size // config.num_attention_heads - self.dtype = dtype if dtype is not None else torch.float32 - self.num_key_value_heads = ( - config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads - ) + if hasattr(config, "num_key_value_heads") and config.num_key_value_heads is not None: + self.num_key_value_heads = config.num_key_value_heads + else: + self.num_key_value_heads = config.num_attention_heads + self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) @@ -363,11 +402,10 @@ def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states` for a given layer. It is VERY important to index using a tensor, otherwise you introduce a copy to the device. Parameters: @@ -375,8 +413,6 @@ def update( The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. Kept for backward compatibility cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len` to know how much of the cache it should overwrite. @@ -394,15 +430,15 @@ def update( self.seen_tokens += key_states.shape[2] return k_out, v_out - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" return self.seen_tokens def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: return self.seen_tokens def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. 
DynamicCache does not have a maximum length.""" + """Returns the maximum sequence length of the cached states.""" return self.max_cache_len def reorder_cache(self, beam_idx: torch.LongTensor): @@ -411,7 +447,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) device = self.value_cache.device self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self): - """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it""" - return None diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dd8fa604d63e..6ad49faee9c3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1493,7 +1493,7 @@ def generate( ) # 12. run assisted generate - return self.assisted_decoding( + generate_output = self.assisted_decoding( input_ids, candidate_generator=candidate_generator, do_sample=generation_config.do_sample, @@ -1510,7 +1510,7 @@ def generate( ) if generation_mode == GenerationMode.GREEDY_SEARCH: # 11. run greedy search - return self.greedy_search( + generate_output = self.greedy_search( input_ids, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, @@ -1527,7 +1527,7 @@ def generate( if not model_kwargs["use_cache"]: raise ValueError("Contrastive search requires `use_cache=True`") - return self.contrastive_search( + generate_output = self.contrastive_search( input_ids, top_k=generation_config.top_k, penalty_alpha=generation_config.penalty_alpha, @@ -1556,7 +1556,7 @@ def generate( ) # 13. run sample - return self.sample( + generate_output = self.sample( input_ids, logits_processor=prepared_logits_processor, logits_warper=logits_warper, @@ -1589,7 +1589,7 @@ def generate( **model_kwargs, ) # 13. run beam search - return self.beam_search( + generate_output = self.beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, @@ -1627,7 +1627,7 @@ def generate( ) # 14. run beam sample - return self.beam_sample( + generate_output = self.beam_sample( input_ids, beam_scorer, logits_processor=prepared_logits_processor, @@ -1661,7 +1661,7 @@ def generate( **model_kwargs, ) # 13. run beam search - return self.group_beam_search( + generate_output = self.group_beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, @@ -1734,7 +1734,7 @@ def typeerror(): **model_kwargs, ) # 13. 
run beam search - return self.constrained_beam_search( + generate_output = self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=prepared_logits_processor, @@ -1747,6 +1747,12 @@ def typeerror(): **model_kwargs, ) + # Finally, reset the model cache if has one + if hasattr(self, "_reset_cache"): + self._reset_cache() + + return generate_output + @torch.no_grad() def contrastive_search( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ef97718984da..46ffa8aee631 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -20,7 +20,7 @@ """ PyTorch LLaMA model.""" import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, ModelCache, StaticCache from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -235,17 +235,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: LlamaConfig): super().__init__() self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -255,6 +247,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self.past_key_value = None if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -339,7 +332,7 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) - if past_key_value is not None: + if self.past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -379,7 +372,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, self.past_key_value class LlamaFlashAttention2(LlamaAttention): @@ -482,7 +475,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, self.past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None @@ -632,7 +625,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) - if past_key_value is not None: + if self.past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -664,7 +657,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, self.past_key_value LLAMA_ATTENTION_CLASSES = { @@ -675,11 +668,11 @@ def forward( class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, layer_idx: int): + def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -706,10 +699,6 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ if "padding_mask" in kwargs: warnings.warn( @@ -721,7 +710,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -745,9 +734,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -793,27 +779,49 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): + def _setup_cache( + self, + external_cache: Optional[ModelCache] = None, + cache_cls: Optional[Cache] = None, + cache_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Sets the key value cache for each layer. If `external_cache` is provided, the cache is copied. Otherwise, + instantiates a fresh cache given `cache_cls` and `cache_kwargs`. + """ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) - if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device: - causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device) - self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + model = getattr(self, "model", self) + if external_cache is not None: + for layer_idx, layer in enumerate(model.layers): + layer.self_attn.past_key_value = external_cache.caches[layer_idx] + elif cache_cls is not None: + if cache_kwargs is None: + cache_kwargs = {} + for layer in model.layers: + layer.self_attn.past_key_value = cache_cls(**cache_kwargs) + else: + raise ValueError("Error setting up the cache: `cache_cls` and `external_cache` are both None, one of them must be defined") - for layer in self.model.layers: - weights = layer.self_attn.o_proj.weight - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype - ) + max_cache_len = model.layers[0].self_attn.past_key_value.get_max_length() + if max_cache_len is not None: + if max_cache_len > model.causal_mask.shape[-1] or self.device != model.causal_mask.device: + causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) def _reset_cache(self): - for layer in self.model.layers: + model = getattr(self, "model", self) + for layer in model.layers: layer.self_attn.past_key_value = None + def _has_cache(self): + model = getattr(self, "model", self) + return model.layers[0].self_attn.past_key_value is not None + LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -850,23 +858,23 @@ def _reset_cache(self): config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + past_key_values (`ModelCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + blocks) that can be used to set up a cache from previously computed values and speed up sequential + decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of + decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - - a [`~cache_utils.Cache`] instance; + - a [`~cache_utils.ModelCache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. + The model will output a [`~cache_utils.ModelCache`] instance holding the updated cache. Unless the legacy + cache format is used, this updated cache is also set internally as an attribute of the model. When a model + has a set cache, the user doesn't need to feed in `past_key_values` to continue a generation, and can + optionally input only the last `input_ids` (those that don't have their past key value states cached in the + model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -903,9 +911,7 @@ def __init__(self, config: LlamaConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -954,6 +960,21 @@ def forward( ) use_cache = False + # Cache setting priority, when `use_cache` is `True` + # 1. if `past_key_values` is passed, setup the cache using it. Note that if it is not a `ModelCache`, it is + # assumed the user expect the legacy API, where the model instance does NOT hold the cache. + # 2. if `past_key_values` is not passed, use a previously set up cache + # 3. 
otherwise, set up a dynamic cache + legacy_cache = False + if use_cache: + if past_key_values is not None: + if not isinstance(past_key_values, ModelCache): + past_key_values = ModelCache.from_legacy_cache(past_key_values) + legacy_cache = True + self._setup_cache(external_cache=past_key_values) + elif not self._has_cache(): + self._setup_cache(cache_cls=DynamicCache) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -979,7 +1000,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1011,9 +1031,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1026,9 +1043,10 @@ def forward( next_cache = None # if use_cache and isinstance(next_decoder_cache, (DynamicCache, SinkCache)): if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) + next_cache = ModelCache([layer.self_attn.past_key_value for layer in self.layers]) + if legacy_cache: + self._reset_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1214,7 +1232,7 @@ def prepare_inputs_for_generation( ): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): + if isinstance(past_key_values, ModelCache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() @@ -1248,7 +1266,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + position_ids = position_ids[:, -input_ids.shape[1]:] if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): # generation with static cache @@ -1328,7 +1346,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor] | ModelCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1443,7 +1461,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor] | ModelCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 5f3af2acf572..80db47535d05 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -37,6 +37,7 @@ DynamicCache, LlamaConfig, LlamaForCausalLM, + ModelCache, SinkCache, StaticCache, ) @@ -44,22 +45,23 @@ @require_torch class CacheTest(unittest.TestCase): - def test_dynamic_cache_retrocompatibility(self): - """Tests that we can convert back and forth between 
the legacy cache format and DynamicCache""" + def test_model_cache_retrocompatibility(self): + """Tests that we can convert back and forth between the legacy cache format and ModelCache""" + num_layers = 10 legacy_cache = () - new_cache = DynamicCache() + new_cache = ModelCache([DynamicCache() for _ in range(num_layers)]) - # Creates a new cache with 10 layers in both formats - for layer_idx in range(10): + # Creates a new cache with `num_layers` layers in both formats + for layer_idx in range(num_layers): new_key = torch.rand((2, 4, 8, 16)) new_value = torch.rand((2, 4, 8, 16)) - new_cache.update(new_key, new_value, layer_idx) + new_cache.caches[layer_idx].update(new_key, new_value) legacy_cache += ((new_key, new_value),) # Sanity check 1: they must have the same shapes self.assertTrue(len(legacy_cache), len(new_cache)) - for layer_idx in range(10): - self.assertTrue(len(legacy_cache[layer_idx]), len(legacy_cache[layer_idx])) + for layer_idx in range(num_layers): + self.assertTrue(len(legacy_cache[layer_idx]), len(new_cache[layer_idx])) for key_value_idx in range(2): self.assertTrue( legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape @@ -70,15 +72,15 @@ def test_dynamic_cache_retrocompatibility(self): self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8) # Sanity check 3: they must be equal, and both support indexing - for layer_idx in range(10): + for layer_idx in range(num_layers): for key_value_idx in range(2): self.assertTrue( torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx]) ) # Test 1: We can convert from legacy to new with no changes - from_legacy = DynamicCache.from_legacy_cache(legacy_cache) - for layer_idx in range(10): + from_legacy = ModelCache.from_legacy_cache(legacy_cache) + for layer_idx in range(num_layers): for key_value_idx in range(2): self.assertTrue( torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx]) @@ -86,7 +88,7 @@ def test_dynamic_cache_retrocompatibility(self): # Test 2: We can convert from new to legacy with no changes to_legacy = new_cache.to_legacy_cache() - for layer_idx in range(10): + for layer_idx in range(num_layers): for key_value_idx in range(2): self.assertTrue( torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx]) @@ -96,14 +98,15 @@ def test_reorder_cache_retrocompatibility(self): """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function + num_layers = 10 legacy_cache = () - new_cache = DynamicCache() + new_cache = ModelCache([DynamicCache() for _ in range(num_layers)]) # Creates a new cache with 10 layers in both formats - for layer_idx in range(10): + for layer_idx in range(num_layers): new_key = torch.rand((4, 4, 8, 16)) new_value = torch.rand((4, 4, 8, 16)) - new_cache.update(new_key, new_value, layer_idx) + new_cache.caches[layer_idx].update(new_key, new_value) legacy_cache += ((new_key, new_value),) # Let's create some dummy beam indices. 
From the shape above, it is equivalent to the case where num_beams=4 @@ -114,7 +117,7 @@ def test_reorder_cache_retrocompatibility(self): new_cache.reorder_cache(beam_idx) # Let's check that the results are the same - for layer_idx in range(10): + for layer_idx in range(num_layers): for key_value_idx in range(2): self.assertTrue( torch.allclose( @@ -143,7 +146,11 @@ def _random_kvs(config): mha_config = LlamaConfig(num_attention_heads=32) mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( +<<<<<<< HEAD *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)} +======= + *_random_kvs(mha_config), cache_kwargs={"position_ids": torch.arange(1)} +>>>>>>> 7e8652f6d (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 32, 10, 128)) self.assertTrue(cached_values.shape == (1, 32, 10, 128)) @@ -151,7 +158,11 @@ def _random_kvs(config): gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( +<<<<<<< HEAD *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} +======= + *_random_kvs(gqa_config), cache_kwargs={"position_ids": torch.arange(1)} +>>>>>>> 7e8652f6d (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 4, 10, 128)) self.assertTrue(cached_values.shape == (1, 4, 10, 128)) @@ -159,7 +170,11 @@ def _random_kvs(config): mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( +<<<<<<< HEAD *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} +======= + *_random_kvs(mqa_config), cache_kwargs={"position_ids": torch.arange(1)} +>>>>>>> 7e8652f6d (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) self.assertTrue(cached_values.shape == (1, 1, 10, 128)) @@ -179,7 +194,8 @@ def test_dynamic_cache_hard(self): set_seed(0) gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256) set_seed(0) - gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache()) + external_dynamic_cache = ModelCache([DynamicCache() for _ in range(model.config.num_hidden_layers)]) + gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=external_dynamic_cache) self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist()) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) @@ -195,6 +211,7 @@ def test_dynamic_cache_hard(self): "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to " "clean their litter box.\nCats are also very independent. 
They don't" ) + print(decoded[0]) self.assertEqual(decoded[0], expected_text) def test_dynamic_cache_batched(self): From 3f0f207bbdfa5d7dd9a9eb5c4398b3f8bf8b1ff4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 13 Feb 2024 18:35:36 +0000 Subject: [PATCH 23/30] wip --- src/transformers/cache_utils.py | 16 ++++++++++++---- src/transformers/models/llama/modeling_llama.py | 6 ++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index e312a69eb3b0..ec1c8027fcd6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -6,7 +6,7 @@ from .configuration_utils import PretrainedConfig -class ModelCache(): +class ModelCache: """ A standalone class that holds multiple `Cache` instances, behaving exactly like the legacy cache format. Designed mostly for backwards compatibility purposes, it is used to set up the cache for models, or as an output type for @@ -21,7 +21,9 @@ def __init__(self, caches: List["Cache"]) -> None: self.caches = caches self.cache_cls = caches[0].__class__ if not all(isinstance(cache, self.cache_cls) for cache in caches): - raise ValueError(f"All caches in `caches` must be of the same type. Got: {[cache.__class__ for cache in caches]}") + raise ValueError( + f"All caches in `caches` must be of the same type. Got: {[cache.__class__ for cache in caches]}" + ) def __len__(self) -> int: """ @@ -182,7 +184,7 @@ def update( self.key_cache = key_states self.value_cache = value_states else: - self.key_cache= torch.cat([self.key_cache, key_states], dim=-2) + self.key_cache = torch.cat([self.key_cache, key_states], dim=-2) self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) return self.key_cache, self.value_cache @@ -381,7 +383,13 @@ class StaticCache(Cache): """ def __init__( - self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32, **unused_kwargs + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int, + device, + dtype=torch.float32, + **unused_kwargs, ) -> None: super().__init__() self.max_batch_size = max_batch_size diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 46ffa8aee631..7c7e114b650e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -805,7 +805,9 @@ def _setup_cache( for layer in model.layers: layer.self_attn.past_key_value = cache_cls(**cache_kwargs) else: - raise ValueError("Error setting up the cache: `cache_cls` and `external_cache` are both None, one of them must be defined") + raise ValueError( + "Error setting up the cache: `cache_cls` and `external_cache` are both None, one of them must be defined" + ) max_cache_len = model.layers[0].self_attn.past_key_value.get_max_length() if max_cache_len is not None: @@ -1266,7 +1268,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1]:] + position_ids = position_ids[:, -input_ids.shape[1] :] if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): # generation with static cache From 93c9e2e0f510c390224304e3fd24f90dcd9242e7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Feb 2024 12:20:46 +0000 Subject: [PATCH 24/30] tmp commit --- src/transformers/cache_utils.py | 2 +- 
src/transformers/generation/utils.py | 10 +- .../models/llama/modeling_llama.py | 111 ++++++++---------- tests/test_cache_utils.py | 31 +++-- 4 files changed, 80 insertions(+), 74 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ec1c8027fcd6..5297e0fbd07e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -442,7 +442,7 @@ def get_seq_length(self) -> int: """Returns the sequence length of the cached states that were seen by the model.""" return self.seen_tokens - def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: + def get_usable_length(self, new_sequence_length=None) -> int: return self.seen_tokens def get_max_length(self) -> Optional[int]: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6ad49faee9c3..1bbfe857f157 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,7 @@ import torch.distributed as dist from torch import nn -from ..cache_utils import Cache, DynamicCache, StaticCache +from ..cache_utils import Cache, DynamicCache, ModelCache, StaticCache from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -2741,17 +2741,17 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx): # Exception 1: code path for models using the legacy cache format if isinstance(past_key_values, (tuple, list)): past_key_values = self._reorder_cache(past_key_values, beam_idx) - # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their + # Exception 2: models with different cache formats. These are limited to `DynamicCache` caches until their # cache format is standardized, to avoid adding complexity to the codebase. elif "bloom" in model_class or "gptbigcode" in model_class: - if not isinstance(past_key_values, DynamicCache): + if not isinstance(past_key_values.caches[0], DynamicCache): raise ValueError( f"Using an unsupported cache format with {model_class}. 
Currently, it only supports the " "legacy tuple format or `DynamicCache`" ) past_key_values = self._reorder_cache(past_key_values, beam_idx) - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # Standard code path: use the `Cache.reorder_cache` + past_key_values = ModelCache.from_legacy_cache(past_key_values) + # Standard code path: use the cache's `.reorder_cache` else: past_key_values.reorder_cache(beam_idx) return past_key_values diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7c7e114b650e..12892b738a35 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -294,7 +294,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -328,14 +327,12 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) - if self.past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -396,7 +393,6 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -420,16 +416,17 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) - past_key_value = getattr(self, "past_key_value", past_key_value) - - past_key_value = getattr(self, "past_key_value", past_key_value) - - if past_key_value is not None: + if self.past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if cache_position is not None: # we slice for static kv cache to be supported in FA2. Not sure it's a must as compile fails - key_states, value_states = key_states[:, :, :cache_position[-1]+1, :], value_states[:, :, :cache_position[-1]+1, :] + key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs) + if ( + cache_position is not None + ): # we slice for static kv cache to be supported in FA2. 
Not sure it's a must as compile fails + key_states, value_states = ( + key_states[:, :, : cache_position[-1] + 1, :], + value_states[:, :, : cache_position[-1] + 1, :], + ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -589,7 +586,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -623,12 +619,10 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) - past_key_value = getattr(self, "past_key_value", past_key_value) - if self.past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -684,7 +678,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -715,7 +708,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, - past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -815,14 +807,38 @@ def _setup_cache( causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device) self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + def _setup_cache_from_past_key_values(self, past_key_values: List[torch.FloatTensor] | ModelCache): + """ + Sets the model cache from the `past_key_values` argument, if needed. + + Cache setting priority, when `use_cache` is `True` + 1. if `past_key_values` is passed, setup the cache using it. Note that if it is not a `ModelCache`, it is + assumed the user expect the legacy API, where the model instance does NOT hold the cache. + 2. if `past_key_values` is not passed, use a previously set up cache when it exists + 3. 
otherwise, set up a new dynamic cache + """ + if past_key_values is not None: + if not isinstance(past_key_values, ModelCache): + past_key_values = ModelCache.from_legacy_cache(past_key_values) + self._setup_cache(external_cache=past_key_values) + elif self._get_cache() is None: + self._setup_cache(cache_cls=DynamicCache) + def _reset_cache(self): model = getattr(self, "model", self) for layer in model.layers: layer.self_attn.past_key_value = None - def _has_cache(self): + def _get_cache(self, all_layers: bool = False) -> Cache | ModelCache: + """ + Returns the `Cache` from the first layer or, if `all_layers` is `True`, a `ModelCache` instance containing + all layers' caches + """ model = getattr(self, "model", self) - return model.layers[0].self_attn.past_key_value is not None + if all_layers: + return ModelCache([layer.self_attn.past_key_value for layer in model.layers]) + else: + return model.layers[0].self_attn.past_key_value LLAMA_INPUTS_DOCSTRING = r""" @@ -936,7 +952,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor] | ModelCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -962,30 +978,15 @@ def forward( ) use_cache = False - # Cache setting priority, when `use_cache` is `True` - # 1. if `past_key_values` is passed, setup the cache using it. Note that if it is not a `ModelCache`, it is - # assumed the user expect the legacy API, where the model instance does NOT hold the cache. - # 2. if `past_key_values` is not passed, use a previously set up cache - # 3. 
otherwise, set up a dynamic cache - legacy_cache = False + legacy_cache = past_key_values is not None and not isinstance(past_key_values, ModelCache) + past_seen_tokens = 0 if use_cache: - if past_key_values is not None: - if not isinstance(past_key_values, ModelCache): - past_key_values = ModelCache.from_legacy_cache(past_key_values) - legacy_cache = True - self._setup_cache(external_cache=past_key_values) - elif not self._has_cache(): - self._setup_cache(cache_cls=DynamicCache) + self._setup_cache_from_past_key_values(past_key_values) + past_seen_tokens = self._get_cache().get_seq_length() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device @@ -1014,7 +1015,6 @@ def forward( causal_mask, position_ids, cache_position, - past_key_values, output_attentions, use_cache, cache_position, @@ -1025,7 +1025,6 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, cache_position=cache_position, - past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1043,10 +1042,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = None - # if use_cache and isinstance(next_decoder_cache, (DynamicCache, SinkCache)): if use_cache: - next_cache = ModelCache([layer.self_attn.past_key_value for layer in self.layers]) - if legacy_cache: + next_cache = self._get_cache(all_layers=True) + if legacy_cache: # Legacy behavior: the model does NOT hold the cache between forward passes self._reset_cache() if not return_dict: @@ -1232,16 +1230,13 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, ModelCache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + self._setup_cache_from_past_key_values(past_key_values) + cache = self._get_cache() + cache_length = cache.get_seq_length() # number of valid tokens in the cache + past_length = cache.seen_tokens # number of tokens that went through the model (may be > than `cache_length`) + max_cache_length = cache.get_max_length() # cache maximum length, if it is a limited size cache + if past_length > 0: # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as @@ -1270,12 +1265,6 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): - # generation with static cache - past_length = past_key_value.get_seq_length() - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. 
# same goes for position ids. Could also help with continued generation. cache_position = kwargs.get("cache_position", None) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 80db47535d05..a45dab06b621 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -146,11 +146,15 @@ def _random_kvs(config): mha_config = LlamaConfig(num_attention_heads=32) mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( +<<<<<<< HEAD <<<<<<< HEAD *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)} ======= *_random_kvs(mha_config), cache_kwargs={"position_ids": torch.arange(1)} >>>>>>> 7e8652f6d (tmp commit) +======= + *_random_kvs(mha_config), cache_kwargs={"cache_position": torch.arange(1)} +>>>>>>> 1f88b6d90 (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 32, 10, 128)) self.assertTrue(cached_values.shape == (1, 32, 10, 128)) @@ -158,11 +162,15 @@ def _random_kvs(config): gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( +<<<<<<< HEAD <<<<<<< HEAD *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} ======= *_random_kvs(gqa_config), cache_kwargs={"position_ids": torch.arange(1)} >>>>>>> 7e8652f6d (tmp commit) +======= + *_random_kvs(gqa_config), cache_kwargs={"cache_position": torch.arange(1)} +>>>>>>> 1f88b6d90 (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 4, 10, 128)) self.assertTrue(cached_values.shape == (1, 4, 10, 128)) @@ -170,11 +178,15 @@ def _random_kvs(config): mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( +<<<<<<< HEAD <<<<<<< HEAD *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} ======= *_random_kvs(mqa_config), cache_kwargs={"position_ids": torch.arange(1)} >>>>>>> 7e8652f6d (tmp commit) +======= + *_random_kvs(mqa_config), cache_kwargs={"cache_position": torch.arange(1)} +>>>>>>> 1f88b6d90 (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) self.assertTrue(cached_values.shape == (1, 1, 10, 128)) @@ -224,7 +236,8 @@ def test_dynamic_cache_batched(self): model.device ) - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()) + external_dynamic_cache = ModelCache([DynamicCache() for _ in range(model.config.num_hidden_layers)]) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=external_dynamic_cache) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] self.assertListEqual(decoded, expected_text) @@ -259,8 +272,10 @@ def test_sink_cache_hard(self): # Set up the SinkCache. Using a small window length to contain computational complexity. 
If this example is run # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of") - cache = SinkCache(window_length=508, num_sink_tokens=4) - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache) + external_sink_cache = ModelCache( + [SinkCache(window_length=508, num_sink_tokens=4) for _ in range(model.config.num_hidden_layers)] + ) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=external_sink_cache) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) @@ -276,7 +291,9 @@ def test_sink_cache_iterative_prompts(self): ) # Prepare generation settings - cache = SinkCache(window_length=256, num_sink_tokens=4) + external_sink_cache = ModelCache( + [SinkCache(window_length=256, num_sink_tokens=4) for _ in range(model.config.num_hidden_layers)] + ) input_ids = torch.tensor([], device=model.device, dtype=torch.int) for _ in range(3): # Tokenize the prompt with the correct chat template @@ -288,12 +305,12 @@ def test_sink_cache_iterative_prompts(self): # Perform the generation gen_out = model.generate( - input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True + input_ids, do_sample=False, max_new_tokens=100, past_key_values=external_sink_cache, use_cache=True ) input_ids = gen_out # We went well beyond the cache length - self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) + self.assertTrue(input_ids.shape[1] > external_sink_cache.get_max_length() * 1.5) # And it still produces a coherent english decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) @@ -361,7 +378,7 @@ def test_static_cache_greedy_sampling_pad_right(self, attn_implementation): "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16, attn_implementation=attn_implementation, - ).to("cuda:1") + ).to(torch_device) inputs = tokenizer( ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" ).to(model.device) From e69eec220d070cbc219953cc8d316a78ae9fe116 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Feb 2024 14:05:52 +0000 Subject: [PATCH 25/30] tmp --- src/transformers/models/llama/modeling_llama.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 12892b738a35..3ba6fa87314f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1262,8 +1262,7 @@ def prepare_inputs_for_generation( # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + position_ids = position_ids[:, -input_ids.shape[1] :] # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. 
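
For reference, a minimal sketch of how the reworked cache flow is expected to be driven from user code, assuming the `ModelCache`/`DynamicCache` API and the stateful `_setup_cache`/`_reset_cache` helpers land as drafted in these patches; the checkpoint name, prompt, and generation settings are placeholders borrowed from the tests, not a prescribed recipe:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, ModelCache

    model_id = "NousResearch/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    inputs = tokenizer("The best color is", return_tensors="pt").to(model.device)

    # Explicit path: one Cache per layer, wrapped in a ModelCache and passed via `past_key_values`.
    external_cache = ModelCache([DynamicCache() for _ in range(model.config.num_hidden_layers)])
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=external_cache)

    # Implicit path: with no `past_key_values`, the model sets up and holds a DynamicCache itself.
    # `_reset_cache()` clears any previously held cache before starting an unrelated generation.
    model._reset_cache()
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
    print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])

    # The legacy tuple format is still accepted and converted on the fly; when it is used, the model
    # does not keep the cache between forward passes.
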
@@ -1274,7 +1273,7 @@ def prepare_inputs_for_generation( ) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} @@ -1283,22 +1282,12 @@ def prepare_inputs_for_generation( { "position_ids": position_ids, "cache_position": cache_position, - "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @add_start_docstrings( """ From 5b38bf7712a14ba8b9c8d807b8654590c8ba923f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 Feb 2024 15:25:53 +0000 Subject: [PATCH 26/30] merge errors --- .../models/llama/modeling_llama.py | 19 --------------- src/transformers/models/phi/modeling_phi.py | 1 - tests/test_cache_utils.py | 24 ------------------- 3 files changed, 44 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3ba6fa87314f..b5032f08ab9d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -295,8 +295,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -394,8 +392,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -587,8 +583,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -600,9 +594,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, cache_position=cache_position, ) @@ -679,8 +671,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -709,8 +699,6 @@ def forward( position_ids=position_ids, cache_position=cache_position, output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, **kwargs, ) hidden_states = residual + hidden_states @@ -958,7 +946,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1016,8 +1003,6 @@ def forward( position_ids, cache_position, output_attentions, - use_cache, - cache_position, ) else: layer_outputs = decoder_layer( @@ -1026,8 +1011,6 @@ def forward( position_ids=position_ids, cache_position=cache_position, output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1145,7 +1128,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1190,7 +1172,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) hidden_states = outputs[0] diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 799fe02c8f48..2ee9a150b7bc 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1142,7 +1142,6 @@ def prepare_inputs_for_generation( return model_inputs @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index a45dab06b621..4a581dc8adda 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -146,15 +146,7 @@ def _random_kvs(config): mha_config = LlamaConfig(num_attention_heads=32) mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( -<<<<<<< HEAD -<<<<<<< HEAD - *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)} -======= - *_random_kvs(mha_config), cache_kwargs={"position_ids": torch.arange(1)} ->>>>>>> 7e8652f6d (tmp commit) -======= *_random_kvs(mha_config), cache_kwargs={"cache_position": torch.arange(1)} ->>>>>>> 1f88b6d90 (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 32, 10, 128)) self.assertTrue(cached_values.shape == (1, 32, 10, 128)) @@ -162,15 +154,7 @@ def _random_kvs(config): gqa_config = 
LlamaConfig(num_attention_heads=32, num_key_value_heads=4) gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( -<<<<<<< HEAD -<<<<<<< HEAD - *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} -======= - *_random_kvs(gqa_config), cache_kwargs={"position_ids": torch.arange(1)} ->>>>>>> 7e8652f6d (tmp commit) -======= *_random_kvs(gqa_config), cache_kwargs={"cache_position": torch.arange(1)} ->>>>>>> 1f88b6d90 (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 4, 10, 128)) self.assertTrue(cached_values.shape == (1, 4, 10, 128)) @@ -178,15 +162,7 @@ def _random_kvs(config): mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( -<<<<<<< HEAD -<<<<<<< HEAD - *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)} -======= - *_random_kvs(mqa_config), cache_kwargs={"position_ids": torch.arange(1)} ->>>>>>> 7e8652f6d (tmp commit) -======= *_random_kvs(mqa_config), cache_kwargs={"cache_position": torch.arange(1)} ->>>>>>> 1f88b6d90 (tmp commit) ) self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) self.assertTrue(cached_values.shape == (1, 1, 10, 128)) From bc84704d06cb4e40e53eef27c66d33e29f621be4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 Feb 2024 16:55:52 +0000 Subject: [PATCH 27/30] reduce diff --- src/transformers/models/llama/modeling_llama.py | 15 ++++----------- tests/test_cache_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b5032f08ab9d..76f8de70038d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -293,8 +293,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -390,8 +390,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -416,13 +416,6 @@ def forward( # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs) - if ( - cache_position is not None - ): # we slice for static kv cache to be supported in FA2. Not sure it's a must as compile fails - key_states, value_states = ( - key_states[:, :, : cache_position[-1] + 1, :], - value_states[:, :, : cache_position[-1] + 1, :], - ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. 
We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -581,8 +574,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -669,8 +662,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 4a581dc8adda..17f2fabcffe5 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -36,7 +36,7 @@ AutoTokenizer, DynamicCache, LlamaConfig, - LlamaForCausalLM, + MistralForCausalLM, ModelCache, SinkCache, StaticCache, @@ -96,7 +96,7 @@ def test_model_cache_retrocompatibility(self): def test_reorder_cache_retrocompatibility(self): """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" - legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function + legacy_reorder_fn = MistralForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function num_layers = 10 legacy_cache = () From 958875e1457c117deca2dbfbb09060b735acc034 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 Feb 2024 17:01:42 +0000 Subject: [PATCH 28/30] smaller llama diff --- src/transformers/models/llama/modeling_llama.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 76f8de70038d..9fc585307645 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -932,13 +932,13 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor] | ModelCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -994,16 +994,16 @@ def forward( hidden_states, causal_mask, position_ids, - cache_position, output_attentions, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, - cache_position=cache_position, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1034,11 +1034,9 @@ def forward( def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == 
"flash_attention_2": - # since the static cache is padded, you have to pass the attention mask raw. - # similar to https://github.com/facebookresearch/llama/commit/e9077bd24177a74aa79f406bef7d4b57fe393157 if attention_mask is not None and 0.0 in attention_mask: - return None - return attention_mask + return attention_mask + return None batch_size, seq_length = input_tensor.shape[:2] dtype = input_tensor.dtype @@ -1113,7 +1111,6 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1121,6 +1118,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1158,13 +1156,13 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - cache_position=cache_position, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] From 104b208a86b542d1717cffb4fdfcd7e8539539a5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 Feb 2024 17:56:01 +0000 Subject: [PATCH 29/30] make fixup --- docs/source/en/internal/generation_utils.md | 21 ++++-- src/transformers/cache_utils.py | 67 +++++++++---------- .../models/llama/modeling_llama.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 7 ++ 4 files changed, 56 insertions(+), 41 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 452921d88c0e..d624ddf5e745 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -359,21 +359,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens ## Caches -[[autodoc]] Cache +[[autodoc]] ModelCache - update + - get_seq_length + - get_max_length + - get_usable_length + - reorder_cache + - to_legacy_cache + - from_legacy_cache + +[[autodoc]] Cache [[autodoc]] DynamicCache - update - get_seq_length + - get_max_length + - get_usable_length - reorder_cache - - to_legacy_cache - - from_legacy_cache [[autodoc]] SinkCache - update - get_seq_length + - get_max_length + - get_usable_length - reorder_cache [[autodoc]] StaticCache - update - - get_seq_length \ No newline at end of file + - get_seq_length + - get_max_length + - get_usable_length + - reorder_cache diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5297e0fbd07e..d8e0b29dc552 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -8,9 +8,9 @@ class ModelCache: """ - A standalone class that holds multiple `Cache` instances, behaving exactly like the legacy cache format. Designed - mostly for backwards compatibility purposes, it is used to set up the cache for models, or as an output type for - the `forward` method of models that use caches. + A standalone class that holds multiple `Cache` instances, behaving exactly like the legacy cache format. 
+ Designed mostly for backwards compatibility purposes, it is used to set up the cache for models, or as an output + type for the `forward` method of models that use caches. Parameters: caches (`List[Cache]`): @@ -224,8 +224,8 @@ class SinkCache(Cache): """ def __init__(self, window_length: int, num_sink_tokens: int, **unused_kwargs) -> None: - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] + self.key_cache: Optional[torch.Tensor] = None + self.value_cache: Optional[torch.Tensor] = None self.window_length = window_length self.num_sink_tokens = num_sink_tokens self.cos_sin_cache = {} @@ -265,12 +265,12 @@ def _get_rerotation_cos_sin( ) return self.cos_sin_cache[key_states.shape[-2]] - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + # Copied from transformers.cache_utils.DynamicCache.get_seq_length + def get_seq_length(self) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length - if len(self.key_cache) <= layer_idx: + if self.key_cache is None: return 0 - return self.key_cache[layer_idx].shape[-2] + return self.key_cache.shape[-2] def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" @@ -280,19 +280,16 @@ def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states` for a given layer. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, `cos` and `partial_rotation_size`. 
These arguments are used with models using RoPE, to recompute the @@ -309,25 +306,22 @@ def update( using_rope = cos is not None and sin is not None # Update the number of seen tokens - if layer_idx == 0: - self.seen_tokens += key_states.shape[-2] + self.seen_tokens += key_states.shape[-2] # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: + if self.key_cache is None: # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) + self.key_cache = key_states + self.value_cache = value_states - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + elif key_states.shape[-2] + self.get_seq_length() < self.window_length: # Growing cache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache = torch.cat([self.key_cache, key_states], dim=-2) + self.value_cache = torch.cat([self.value_cache, value_states], dim=-2) else: # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : - ] + keys_to_keep = self.key_cache[:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :] # On RoPE models, we need to recompute the Key rotation as the tokens are shifted if using_rope: @@ -344,29 +338,29 @@ def update( keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + sink_keys = self.key_cache[:, :, : self.num_sink_tokens] + self.key_cache = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ + sink_values = self.value_cache[:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[ :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : ] - self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) + self.value_cache = torch.cat([sink_values, values_to_keep, value_states], dim=-2) - return self.key_cache[layer_idx], self.value_cache[layer_idx] + return self.key_cache, self.value_cache + # Copied from transformers.cache_utils.DynamicCache.reorder_cache def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.key_cache.device + self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) + device = self.value_cache.device + self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) class StaticCache(Cache): """ - Static Cache class to be used with `torch.compile(model)`. + Static cache class to be used with `torch.compile(model)`. 
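
A hedged sketch of how a single per-layer static cache behaves under this draft, following the shapes asserted in tests/test_cache_utils.py; the head-dim arithmetic assumes the default `LlamaConfig` hidden size, and the snippet is illustrative only:

    import torch
    from transformers import LlamaConfig, StaticCache

    config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
    head_dim = config.hidden_size // config.num_attention_heads   # 4096 // 32 = 128
    cache = StaticCache(config=config, max_batch_size=1, max_cache_len=10, device="cpu", dtype=torch.float32)

    # A single decoding step writes into the pre-allocated buffers at the given position.
    key = torch.rand((1, config.num_key_value_heads, 1, head_dim))
    value = torch.rand((1, config.num_key_value_heads, 1, head_dim))
    cached_keys, cached_values = cache.update(key, value, cache_kwargs={"cache_position": torch.arange(1)})

    # The returned tensors always have the full pre-allocated length.
    assert cached_keys.shape == (1, config.num_key_value_heads, 10, head_dim)
    assert cache.get_max_length() == 10
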
Parameters: config (`PretrainedConfig): @@ -449,6 +443,7 @@ def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" return self.max_cache_len + # Copied from transformers.cache_utils.DynamicCache.reorder_cache def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" device = self.key_cache.device diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9fc585307645..28d7a2c8d7b0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1298,7 +1298,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor] | ModelCache] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 3b8316ba5472..cb01ae1a91dd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -30,6 +30,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ModelCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SinkCache(metaclass=DummyObject): _backends = ["torch"] From 425c6edebe333f3ccade359debba8d711be7a971 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 Feb 2024 18:45:36 +0000 Subject: [PATCH 30/30] nearly all tests passing --- src/transformers/generation/utils.py | 10 +++++++++- src/transformers/models/llama/modeling_llama.py | 12 +++++++----- src/transformers/models/mistral/modeling_mistral.py | 3 ++- tests/test_cache_utils.py | 6 +++--- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1bbfe857f157..38f00a45bc39 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1431,7 +1431,15 @@ def generate( "The `generation_config` defines a `cache_implementation` that is not compatible with this model." " Make sure it has a `_setup_cache` function." 
                 )
-            self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
+            self._setup_cache(
+                cache_cls=cache_cls,
+                cache_kwargs={
+                    "max_batch_size": batch_size,
+                    "max_cache_len": generation_config.max_length,
+                    "config": self.config,
+                    "device": self.device,
+                },
+            )
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 28d7a2c8d7b0..1d90b033c70d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -768,19 +768,21 @@ def _setup_cache(
                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
             )
 
+        if external_cache is not None and (cache_cls is not None or cache_kwargs is not None):
+            raise ValueError(
+                "Error setting up the cache: (`external_cache`) and (`cache_cls`, `cache_kwargs`) are simultaneously "
+                "defined; only one of the two can be defined at once."
+            )
+
         model = getattr(self, "model", self)
         if external_cache is not None:
             for layer_idx, layer in enumerate(model.layers):
                 layer.self_attn.past_key_value = external_cache.caches[layer_idx]
-        elif cache_cls is not None:
+        else:  # cache_cls is not None
             if cache_kwargs is None:
                 cache_kwargs = {}
             for layer in model.layers:
                 layer.self_attn.past_key_value = cache_cls(**cache_kwargs)
-        else:
-            raise ValueError(
-                "Error setting up the cache: `cache_cls` and `external_cache` are both None, one of them must be defined"
-            )
 
         max_cache_len = model.layers[0].self_attn.past_key_value.get_max_length()
         if max_cache_len is not None:
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index f4251b98304c..21007e2f4947 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -135,7 +135,8 @@ def rotate_half(x):
 
 
 # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-# TODO @Arthur no longer copied from LLama after static cache
+
+
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
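The two hunks above change how the cache is wired up: `generate()` now hands `_setup_cache` a single `cache_kwargs` dict, and the model-side `_setup_cache` treats `external_cache` and the (`cache_cls`, `cache_kwargs`) pair as mutually exclusive, attaching one cache object per decoder layer. The snippet below is a minimal, self-contained sketch of that dispatch; `ToyModel`, `ToyLayer`, `ToyAttention`, and `ToyStaticCache` are illustrative stand-ins rather than transformers classes, and only the control flow is borrowed from the patch.

```python
from typing import Any, Dict, List, Optional, Type


class ToyStaticCache:
    """Stand-in for a per-layer cache (e.g. a static cache); it only records its configuration."""

    def __init__(self, max_batch_size: int, max_cache_len: int) -> None:
        self.max_batch_size = max_batch_size
        self.max_cache_len = max_cache_len


class ToyAttention:
    def __init__(self) -> None:
        self.past_key_value: Optional[Any] = None  # filled in by setup_cache


class ToyLayer:
    def __init__(self) -> None:
        self.self_attn = ToyAttention()


class ToyModel:
    def __init__(self, num_layers: int) -> None:
        self.layers: List[ToyLayer] = [ToyLayer() for _ in range(num_layers)]

    def setup_cache(
        self,
        cache_cls: Optional[Type] = None,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        external_cache: Optional[Any] = None,
    ) -> None:
        # The two configuration paths are mutually exclusive, mirroring the check in the hunk above.
        if external_cache is not None and (cache_cls is not None or cache_kwargs is not None):
            raise ValueError("Pass either `external_cache` or (`cache_cls`, `cache_kwargs`), not both.")
        if external_cache is not None:
            # Reuse pre-built per-layer caches (a ModelCache-style container exposing a `.caches` list).
            for layer_idx, layer in enumerate(self.layers):
                layer.self_attn.past_key_value = external_cache.caches[layer_idx]
        else:
            # Build one fresh cache per layer from the class and its kwargs, which is what
            # generate() now requests by passing a single `cache_kwargs` dict.
            for layer in self.layers:
                layer.self_attn.past_key_value = cache_cls(**(cache_kwargs or {}))


# Usage: mirrors the shape of the generate()-side call, with toy values.
model = ToyModel(num_layers=2)
model.setup_cache(cache_cls=ToyStaticCache, cache_kwargs={"max_batch_size": 1, "max_cache_len": 16})
assert all(isinstance(layer.self_attn.past_key_value, ToyStaticCache) for layer in model.layers)
```

Keeping the two paths exclusive means a user-supplied `external_cache` can never be silently overridden by a separately configured cache class.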
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 17f2fabcffe5..b449a8f70b63 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -301,7 +301,7 @@ def test_sink_cache_iterative_prompts(self):
 
     @require_torch_gpu
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
-    def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
+    def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color is the one that complements the subject you are photograph",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -341,14 +341,14 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
 
     @require_torch_gpu
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
-    def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
+    def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color is\n\n\n\n\n\n\n\n\n\n",
             "We should not undermind the issues at hand, but address them head on.\nI think",
         ]
 
         tokenizer = AutoTokenizer.from_pretrained(
-            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token=""
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token=""
         )
         model = AutoModelForCausalLM.from_pretrained(
             "NousResearch/Llama-2-7b-chat-hf",
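The renamed tests above drive greedy decoding through the static cache with left- and right-padded batches. As a rough usage sketch, not part of the patch, the same path can be exercised through `generate()` by selecting the static cache implementation on the generation config; the checkpoint name and prompts come from the test file, while the padding setup, dtype, and generation arguments are assumptions for illustration (a CUDA device is assumed, matching `@require_torch_gpu`).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint and prompts are taken from the tests above; the rest is an illustrative outline.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # stand-in for the pad token the tests configure

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to("cuda")

# Route generation through the static cache set up by `_setup_cache`
# (see the `cache_implementation` check in generate() earlier in this patch).
model.generation_config.cache_implementation = "static"

prompts = ["The best color is", "We should not undermind the issues at hand"]
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
```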