
Commit

Merge pull request #38 from jiqing-feng/main
fix llama kv cache
Viol2000 authored Jan 9, 2024
2 parents 9825f98 + a24c048 commit 1afa531
Showing 2 changed files with 22 additions and 24 deletions.
lade/decoding.py (6 additions, 6 deletions)
@@ -417,10 +417,10 @@ def copy_from_last():
             past_key_values = []
             for idx, kv in enumerate(outputs.past_key_values):
                 for hh in range(max_hit):
-                    assert outputs.step_len == kv[0].size(2)
-                    kv[0][:,:,outputs.kvcache_len + hh,:] = kv[0][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
-                    kv[1][:,:,outputs.kvcache_len + hh,:] = kv[1][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
-                past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
+                    assert outputs.step_len == kv[idx][0].size(2)
+                    kv[idx][0][:,:,outputs.kvcache_len + hh,:] = kv[idx][0][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
+                    kv[idx][1][:,:,outputs.kvcache_len + hh,:] = kv[idx][1][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
+                past_key_values.append( (kv[idx][0][:,:,:outputs.kvcache_len + max_hit,:], kv[idx][1][:,:,:outputs.kvcache_len + max_hit,:]) )
             outputs.past_key_values = past_key_values
 
         else:
@@ -435,8 +435,8 @@ def copy_from_last():
             past_key_values = []
             for idx, kv in enumerate(outputs.past_key_values):
                 for hh in range(max_hit):
-                    assert outputs.step_len == kv[0].size(2)
-                past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
+                    assert outputs.step_len == kv[idx][0].size(2)
+                past_key_values.append( (kv[idx][0][:,:,:outputs.kvcache_len + max_hit,:], kv[idx][1][:,:,:outputs.kvcache_len + max_hit,:]) )
             outputs.past_key_values = past_key_values


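For orientation, here is a minimal standalone sketch (not LADE's actual code) of the cache bookkeeping both hunks above perform: after verification accepts max_hit guessed tokens, their key/value rows are copied to the end of the confirmed region and every layer's cache is truncated to kvcache_len + max_hit positions. Sizes and the guess offset below are invented for illustration; the real code derives the offset from guess_tokens, hit_point, and GUESS_SIZE.

import torch

# Invented shapes: batch=1, heads=2, head_dim=4, two layers.
batch, heads, head_dim = 1, 2, 4
kvcache_len, step_len, max_hit = 10, 16, 3
guess_start = step_len - max_hit  # simplified stand-in for the real guess offset

old_cache = [
    (torch.randn(batch, heads, step_len, head_dim),   # key
     torch.randn(batch, heads, step_len, head_dim))   # value
    for _ in range(2)
]

new_cache = []
for key, value in old_cache:
    # Copy the accepted guess positions to the end of the confirmed region ...
    for hh in range(max_hit):
        key[:, :, kvcache_len + hh, :] = key[:, :, guess_start + hh, :]
        value[:, :, kvcache_len + hh, :] = value[:, :, guess_start + hh, :]
    # ... then truncate the cache to the confirmed length plus the accepted hits.
    new_cache.append((key[:, :, :kvcache_len + max_hit, :],
                      value[:, :, :kvcache_len + max_hit, :]))

print(new_cache[0][0].shape)  # torch.Size([1, 2, 13, 4])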
lade/models/llama.py (16 additions, 18 deletions)
@@ -1,5 +1,6 @@
 import torch, math, time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.models.llama.modeling_llama import BaseModelOutputWithPast, CausalLMOutputWithPast, _expand_mask
 
 def j_make_causal_mask_multilevel(
@@ -168,9 +169,11 @@ def LlamaModeljforward(
     seq_length_with_past = seq_length
     past_key_values_length = 0
 
-    if past_key_values is not None:
-        past_key_values_length = past_key_values[0][0].shape[2]
-        seq_length_with_past = seq_length_with_past + past_key_values_length
+    if use_cache:
+        use_legacy_cache = not isinstance(past_key_values, Cache)
+        if use_legacy_cache:
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        past_key_values_length = past_key_values.get_usable_length(seq_length)
 
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
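For readers unfamiliar with the Cache API introduced in transformers 4.36: the hunk above upgrades a legacy tuple-of-tuples cache to a DynamicCache and asks the cache for its usable length instead of reading shape[2] directly. A minimal sketch with dummy tensors follows (names and sizes are illustrative; assumes transformers >= 4.36).

import torch
from transformers.cache_utils import Cache, DynamicCache

# Legacy format: one (key, value) pair per layer, here 3 layers of
# (batch=1, heads=2, seq_len=5, head_dim=4) dummy tensors.
legacy_cache = tuple(
    (torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4))
    for _ in range(3)
)

past_key_values = legacy_cache
seq_length = 7  # length of the new token chunk being fed to the model

if not isinstance(past_key_values, Cache):
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

# How many cached positions can be reused for the incoming tokens.
print(past_key_values.get_usable_length(seq_length))  # 5
# The cache can be converted back to the tuple format when needed.
print(len(past_key_values.to_legacy_cache()))          # 3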
@@ -217,29 +220,24 @@ def LlamaModeljforward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        past_key_value = past_key_values[idx] if past_key_values is not None else None
-
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
-            )
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+            )
         else:
             layer_outputs = decoder_layer.forward(
                 hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
-                past_key_value=past_key_value,
+                past_key_value=past_key_values,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 padding_mask=padding_mask,
             )
 
         hidden_states = layer_outputs[0]
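As context for the hunk above: in recent transformers releases, gradient_checkpointing_enable() binds self._gradient_checkpointing_func to (roughly) a functools.partial of torch.utils.checkpoint.checkpoint, so the closure wrapper from the old code is no longer needed and the callable plus its positional arguments are passed directly. A self-contained sketch of the two calling patterns, using a hypothetical ToyLayer in place of a decoder layer:

import torch
import torch.nn as nn
from functools import partial
from torch.utils.checkpoint import checkpoint

class ToyLayer(nn.Module):
    """Hypothetical stand-in for a decoder layer."""
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, scale=1.0):
        return torch.relu(self.proj(hidden_states)) * scale

layer = ToyLayer()
x = torch.randn(2, 4, 16, requires_grad=True)

# Old pattern: a closure captures the extra arguments before checkpointing.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, scale=1.0)
    return custom_forward

out_old = checkpoint(create_custom_forward(layer), x, use_reentrant=False)

# New pattern: bind checkpoint once, then pass the callable and positional
# arguments directly, mirroring self._gradient_checkpointing_func(...).
gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)
out_new = gradient_checkpointing_func(layer.__call__, x, 1.0)

print(torch.allclose(out_old, out_new))  # True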
