
Commit

tmp commit
gante committed Feb 14, 2024
1 parent f6ff005 commit 1f88b6d
Showing 4 changed files with 71 additions and 76 deletions.
2 changes: 1 addition & 1 deletion src/transformers/cache_utils.py
@@ -442,7 +442,7 @@ def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
return self.seen_tokens

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
def get_usable_length(self, new_sequence_length=None) -> int:
return self.seen_tokens

def get_max_length(self) -> Optional[int]:
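The only change in this file drops the `layer_idx` argument from `get_usable_length`, which only makes sense once each attention layer owns its own cache object. As a point of reference, a minimal sketch of the per-layer length interface this implies (the class name and constructor are illustrative; only the three methods and `seen_tokens` come from the hunk):

```python
from typing import Optional


class ToyLayerCache:
    """Illustrative per-layer cache exposing the length queries from the hunk above."""

    def __init__(self, max_length: Optional[int] = None):
        self.seen_tokens = 0          # incremented by the (omitted) update() method
        self._max_length = max_length

    def get_seq_length(self) -> int:
        """Sequence length of the cached states that were seen by the model."""
        return self.seen_tokens

    def get_usable_length(self, new_sequence_length=None) -> int:
        # No layer_idx parameter: the cache is already scoped to a single layer.
        return self.seen_tokens

    def get_max_length(self) -> Optional[int]:
        # None means the cache can grow without bound (dynamic cache).
        return self._max_length
```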
10 changes: 5 additions & 5 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, ModelCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -2741,17 +2741,17 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx):
# Exception 1: code path for models using the legacy cache format
if isinstance(past_key_values, (tuple, list)):
past_key_values = self._reorder_cache(past_key_values, beam_idx)
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# Exception 2: models with different cache formats. These are limited to `DynamicCache` caches until their
# cache format is standardized, to avoid adding complexity to the codebase.
elif "bloom" in model_class or "gptbigcode" in model_class:
if not isinstance(past_key_values, DynamicCache):
if not isinstance(past_key_values.caches[0], DynamicCache):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`"
)
past_key_values = self._reorder_cache(past_key_values, beam_idx)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# Standard code path: use the `Cache.reorder_cache`
past_key_values = ModelCache.from_legacy_cache(past_key_values)
# Standard code path: use the cache's `.reorder_cache`
else:
past_key_values.reorder_cache(beam_idx)
return past_key_values
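For beam search, `_temporary_reorder_cache` now has to handle three shapes of `past_key_values`. The following self-contained sketch mirrors that branching with simplified stand-in classes; `ToyDynamicCache`, `ToyModelCache` and `reorder_legacy` are assumptions that only reproduce the attributes the diff relies on (`caches`, `from_legacy_cache`, `reorder_cache`), and the round-trip through the legacy format in the second branch is an approximation of the bloom/gptbigcode path, not a transcription of it:

```python
from typing import List, Tuple

import torch


class ToyDynamicCache:
    """Per-layer cache holding key/value tensors of shape (batch, heads, seq, dim)."""

    def __init__(self, key: torch.Tensor, value: torch.Tensor):
        self.key, self.value = key, value

    def reorder(self, beam_idx: torch.LongTensor) -> None:
        self.key = self.key.index_select(0, beam_idx)
        self.value = self.value.index_select(0, beam_idx)


class ToyModelCache:
    """One cache object per decoder layer."""

    def __init__(self, caches: List[ToyDynamicCache]):
        self.caches = caches

    @classmethod
    def from_legacy_cache(cls, legacy: Tuple[Tuple[torch.Tensor, torch.Tensor], ...]) -> "ToyModelCache":
        return cls([ToyDynamicCache(k, v) for k, v in legacy])

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        for cache in self.caches:
            cache.reorder(beam_idx)


def reorder_legacy(legacy, beam_idx):
    # Model-specific `_reorder_cache` equivalent for the legacy tuple format.
    return tuple((k.index_select(0, beam_idx), v.index_select(0, beam_idx)) for k, v in legacy)


def temporary_reorder_cache(past_key_values, beam_idx, model_class: str):
    # 1) Legacy tuple/list format: reorder with the model-specific helper.
    if isinstance(past_key_values, (tuple, list)):
        return reorder_legacy(past_key_values, beam_idx)
    # 2) Models with non-standard cache formats: only DynamicCache-style layers are supported.
    if "bloom" in model_class or "gptbigcode" in model_class:
        if not isinstance(past_key_values.caches[0], ToyDynamicCache):
            raise ValueError(f"Unsupported cache format for {model_class}")
        legacy = tuple((c.key, c.value) for c in past_key_values.caches)
        return ToyModelCache.from_legacy_cache(reorder_legacy(legacy, beam_idx))
    # 3) Standard path: the cache knows how to reorder itself in place.
    past_key_values.reorder_cache(beam_idx)
    return past_key_values


# Example: swap beams 0 and 1 for a single-layer cache.
cache = ToyModelCache([ToyDynamicCache(torch.zeros(2, 4, 3, 8), torch.zeros(2, 4, 3, 8))])
cache = temporary_reorder_cache(cache, torch.tensor([1, 0]), model_class="llama")
```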
110 changes: 50 additions & 60 deletions src/transformers/models/llama/modeling_llama.py
@@ -294,7 +294,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -326,14 +325,12 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

if self.past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
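The attention hunks above replace the `past_key_value` forward argument with a cache attached to the module itself, and `update` loses its `layer_idx` parameter. A toy stand-in for that call pattern (the `cache_kwargs` keys mirror the diff; the dynamic append and everything else is assumed):

```python
from typing import Dict, Optional, Tuple

import torch


class ToyLayerCache:
    """Stand-in for the per-layer cache stored on `self.past_key_value`."""

    def __init__(self) -> None:
        self.key: Optional[torch.Tensor] = None
        self.value: Optional[torch.Tensor] = None

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[Dict] = None,  # e.g. {"sin": ..., "cos": ..., "cache_position": ...}
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Dynamic-cache behavior: append along the sequence axis. A static cache
        # would instead write into a preallocated buffer at `cache_position`.
        if self.key is None:
            self.key, self.value = key_states, value_states
        else:
            self.key = torch.cat([self.key, key_states], dim=-2)
            self.value = torch.cat([self.value, value_states], dim=-2)
        return self.key, self.value


class ToySelfAttn:
    def __init__(self) -> None:
        # Attached by the model's cache setup; no layer_idx is needed because
        # each layer owns exactly one cache object.
        self.past_key_value: Optional[ToyLayerCache] = None

    def attend(self, key_states, value_states, sin=None, cos=None, cache_position=None):
        if self.past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs)
        return key_states, value_states
```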
@@ -394,7 +391,6 @@ def forward(
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -416,14 +412,17 @@
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
if self.past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if cache_position is not None: # we slice for static kv cache to be supported in FA2. Not sure it's a must as compile fails
key_states, value_states = key_states[:, :, :cache_position[-1]+1, :], value_states[:, :, :cache_position[-1]+1, :]
key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs)
if (
cache_position is not None
): # we slice for static kv cache to be supported in FA2. Not sure it's a must as compile fails
key_states, value_states = (
key_states[:, :, : cache_position[-1] + 1, :],
value_states[:, :, : cache_position[-1] + 1, :],
)

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
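A tiny self-contained illustration of the slicing added in the FlashAttention-2 branch above, using toy shapes; `cache_position` is assumed to hold the positions written in the current forward pass:

```python
import torch

# Toy shapes: (batch, num_heads, max_cache_len, head_dim) for a statically
# allocated KV cache where only the first positions have been filled so far.
key_states = torch.zeros(1, 8, 128, 64)
value_states = torch.zeros(1, 8, 128, 64)
cache_position = torch.arange(0, 7)  # positions written in this forward pass

# Same idea as the FA2 branch above: keep only the filled part of the static
# cache before handing the tensors to a kernel that cannot mask the unused tail.
valid_len = cache_position[-1] + 1
key_states = key_states[:, :, :valid_len, :]
value_states = value_states[:, :, :valid_len, :]
assert key_states.shape[2] == 7
```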
@@ -583,7 +582,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -598,7 +596,6 @@
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_value=past_key_value,
output_attentions=output_attentions,
)

@@ -615,12 +612,10 @@
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

past_key_value = getattr(self, "past_key_value", past_key_value)

if self.past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = self.past_key_value.update(key_states, value_states, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -676,7 +671,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -705,7 +699,6 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_value=past_key_value,
output_attentions=output_attentions,
**kwargs,
)
@@ -797,14 +790,38 @@ def _setup_cache(
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

def _setup_cache_from_past_key_values(self, past_key_values: List[torch.FloatTensor] | ModelCache):
"""
Sets the model cache from the `past_key_values` argument, if needed.
Cache setting priority, when `use_cache` is `True`:
1. if `past_key_values` is passed, set up the cache using it. Note that if it is not a `ModelCache`, it is
assumed the user expects the legacy API, where the model instance does NOT hold the cache.
2. if `past_key_values` is not passed, use a previously set up cache when it exists
3. otherwise, set up a new dynamic cache
"""
if past_key_values is not None:
if not isinstance(past_key_values, ModelCache):
past_key_values = ModelCache.from_legacy_cache(past_key_values)
self._setup_cache(external_cache=past_key_values)
elif self._get_cache() is None:
self._setup_cache(cache_cls=DynamicCache)

def _reset_cache(self):
model = getattr(self, "model", self)
for layer in model.layers:
layer.self_attn.past_key_value = None

def _has_cache(self):
def _get_cache(self, all_layers: bool = False) -> Cache | ModelCache:
"""
Returns the `Cache` from the first layer or, if `all_layers` is `True`, a `ModelCache` instance containing
all layers' caches
"""
model = getattr(self, "model", self)
return model.layers[0].self_attn.past_key_value is not None
if all_layers:
return ModelCache([layer.self_attn.past_key_value for layer in model.layers])
else:
return model.layers[0].self_attn.past_key_value


LLAMA_INPUTS_DOCSTRING = r"""
@@ -918,7 +935,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor] | ModelCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -943,30 +960,15 @@ def forward(
)
use_cache = False

# Cache setting priority, when `use_cache` is `True`
# 1. if `past_key_values` is passed, setup the cache using it. Note that if it is not a `ModelCache`, it is
# assumed the user expect the legacy API, where the model instance does NOT hold the cache.
# 2. if `past_key_values` is not passed, use a previously set up cache
# 3. otherwise, set up a dynamic cache
legacy_cache = False
legacy_cache = past_key_values is not None and not isinstance(past_key_values, ModelCache)
past_seen_tokens = 0
if use_cache:
if past_key_values is not None:
if not isinstance(past_key_values, ModelCache):
past_key_values = ModelCache.from_legacy_cache(past_key_values)
legacy_cache = True
self._setup_cache(external_cache=past_key_values)
elif not self._has_cache():
self._setup_cache(cache_cls=DynamicCache)
self._setup_cache_from_past_key_values(past_key_values)
past_seen_tokens = self._get_cache().get_seq_length()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, (StaticCache)):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

if cache_position is None:
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
@@ -995,7 +997,6 @@ def forward(
causal_mask,
position_ids,
cache_position,
past_key_values,
output_attentions,
)
else:
@@ -1004,7 +1005,6 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_value=past_key_values,
output_attentions=output_attentions,
)

@@ -1020,10 +1020,9 @@ def forward(
all_hidden_states += (hidden_states,)

next_cache = None
# if use_cache and isinstance(next_decoder_cache, (DynamicCache, SinkCache)):
if use_cache:
next_cache = ModelCache([layer.self_attn.past_key_value for layer in self.layers])
if legacy_cache:
next_cache = self._get_cache(all_layers=True)
if legacy_cache: # Legacy behavior: the model does NOT hold the cache between forward passes
self._reset_cache()

if not return_dict:
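Tying the `LlamaModel.forward` hunks together, here is a simplified, self-contained sketch of the cache lifecycle. The `Toy*` classes and the plain list used for per-layer storage are stand-ins; the parts taken from the diff are the three-step setup priority, the `cache_position` arithmetic, and the rule that legacy callers get the cache returned while the model forgets it:

```python
from typing import List, Optional

import torch


class ToyLayerCache:
    def __init__(self):
        self.seen_tokens = 0

    def get_seq_length(self) -> int:
        return self.seen_tokens


class ToyModelCache:
    """Container holding one cache per layer, mirroring the `caches` attribute used above."""

    def __init__(self, caches: List[ToyLayerCache]):
        self.caches = caches

    @classmethod
    def from_legacy_cache(cls, legacy) -> "ToyModelCache":
        # One per-layer cache per (key, value) pair; contents omitted in this sketch.
        return cls([ToyLayerCache() for _ in legacy])

    def get_seq_length(self) -> int:
        return self.caches[0].get_seq_length()


class ToyDecoder:
    def __init__(self, num_layers: int = 2):
        # Stands in for `layer.self_attn.past_key_value` on each decoder layer.
        self.layer_caches: List[Optional[ToyLayerCache]] = [None] * num_layers

    def _setup_cache_from_past_key_values(self, past_key_values) -> None:
        if past_key_values is not None:                     # 1) an explicitly passed cache wins
            if not isinstance(past_key_values, ToyModelCache):
                past_key_values = ToyModelCache.from_legacy_cache(past_key_values)
            self.layer_caches = list(past_key_values.caches)
        elif self.layer_caches[0] is None:                  # 3) nothing set up yet: new dynamic cache
            self.layer_caches = [ToyLayerCache() for _ in self.layer_caches]
        # 2) otherwise: keep the cache the model already holds

    def _get_cache(self, all_layers: bool = False):
        return ToyModelCache(self.layer_caches) if all_layers else self.layer_caches[0]

    def _reset_cache(self) -> None:
        self.layer_caches = [None] * len(self.layer_caches)

    def forward(self, seq_len: int, past_key_values=None, use_cache: bool = True):
        # Legacy callers pass tuples and expect the cache back, with the model holding no state.
        legacy_cache = past_key_values is not None and not isinstance(past_key_values, ToyModelCache)
        past_seen_tokens = 0
        if use_cache:
            self._setup_cache_from_past_key_values(past_key_values)
            past_seen_tokens = self._get_cache().get_seq_length()

        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
        # ... decoder layers would run here, each writing into its own layer cache ...

        next_cache = self._get_cache(all_layers=True) if use_cache else None
        if legacy_cache:
            self._reset_cache()
        return cache_position, next_cache
```

The design point is that `past_key_values` keeps working as an input and output for old callers, while new callers can let the model own the cache across forward passes.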
@@ -1207,16 +1206,13 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, ModelCache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
self._setup_cache_from_past_key_values(past_key_values)
cache = self._get_cache()
cache_length = cache.get_seq_length() # number of valid tokens in the cache
past_length = cache.seen_tokens  # number of tokens that went through the model (may be greater than `cache_length`)
max_cache_length = cache.get_max_length()  # cache maximum length, if it is a limited-size cache

if past_length > 0:
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
@@ -1245,12 +1241,6 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
# generation with static cache
past_length = past_key_value.get_seq_length()
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
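Finally, a hedged reconstruction of what the `cache_length` / `past_length` / `max_cache_length` bookkeeping in `prepare_inputs_for_generation` is used for. The exact slicing rules sit in the collapsed part of the hunk; this standalone function follows the visible numbered comments and the generation utilities of that era, so treat it as an approximation rather than the hidden lines:

```python
from typing import Optional, Tuple

import torch


def keep_unprocessed_tokens(
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    cache_length: int,                 # valid tokens currently in the cache
    past_length: int,                  # tokens already processed (can exceed cache_length)
    max_cache_length: Optional[int],   # capacity of fixed-size caches, else None
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    # 1) The attention mask is longer than input_ids: part of the input lives only
    #    in the cache (e.g. inputs_embeds were passed), so keep the uncached tail.
    if attention_mask.shape[1] > input_ids.shape[1]:
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
    # 2) The cache covers a prefix of input_ids: drop that prefix.
    elif past_length < input_ids.shape[1]:
        input_ids = input_ids[:, past_length:]
    # 3) Otherwise input_ids already contains only unprocessed tokens.

    # Fixed-size caches: never let the attention mask grow past the cache capacity.
    if (
        max_cache_length is not None
        and cache_length + input_ids.shape[1] > max_cache_length
    ):
        attention_mask = attention_mask[:, -max_cache_length:]
    return input_ids, attention_mask
```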
