support hf v4.36
Viol2000 committed Jan 9, 2024
1 parent 1afa531 commit 7571393
Showing 3 changed files with 17 additions and 39 deletions.
44 changes: 10 additions & 34 deletions lade/decoding.py
@@ -411,40 +411,19 @@ def copy_from_last():
             attention_mask = model_kwargs["attention_mask"]
             model_kwargs["attention_mask"] = torch.cat((attention_mask, torch.ones(1, max_hit, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)
 
-            #support awq
-            if not USE_AWQ:
-                past_key_values = []
-                for idx, kv in enumerate(outputs.past_key_values):
-                    for hh in range(max_hit):
-                        assert outputs.step_len == kv[idx][0].size(2)
-                        kv[idx][0][:,:,outputs.kvcache_len + hh,:] = kv[idx][0][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
-                        kv[idx][1][:,:,outputs.kvcache_len + hh,:] = kv[idx][1][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
-                    past_key_values.append( (kv[idx][0][:,:,:outputs.kvcache_len + max_hit,:], kv[idx][1][:,:,:outputs.kvcache_len + max_hit,:]) )
-                outputs.past_key_values = past_key_values
-            else:
-                #not support awq
-                #print("kv: ", outputs.past_key_values)
-                assert not USE_AWQ
-                #print("cache: ", outputs.kvcache_len, max_hit, outputs.step_len, window_cache[0].k.size(), window_cache[0].v.size())
-                for idx, kv in enumerate(window_cache):
-                    kv.k[:,:,:,outputs.kvcache_len + hh,:] = kv.k[:,:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
-                    kv.v[:,:,outputs.kvcache_len + hh,:] = kv.v[:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
-
-                past_key_values = []
-                for idx, kv in enumerate(outputs.past_key_values):
-                    for hh in range(max_hit):
-                        assert outputs.step_len == kv[idx][0].size(2)
-                    past_key_values.append( (kv[idx][0][:,:,:outputs.kvcache_len + max_hit,:], kv[idx][1][:,:,:outputs.kvcache_len + max_hit,:]) )
-                outputs.past_key_values = past_key_values
+            past_key_values = []
+            for idx, kv in enumerate(outputs.past_key_values):
+                for hh in range(max_hit):
+                    assert outputs.step_len == kv[0].size(2)
+                    kv[0][:,:,outputs.kvcache_len + hh,:] = kv[0][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
+                    kv[1][:,:,outputs.kvcache_len + hh,:] = kv[1][:,:,outputs.step_len-len(guess_tokens)+hit_point * GUESS_SIZE + hh,:]
+                past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
+            outputs.past_key_values = past_key_values
 
             lst_token = hits[max_hit]
             def sublist(lst1, lst2):
                 ls1 = [element for element in lst1 if element in lst2]
                 ls2 = [element for element in lst2 if element in lst1]
                 return ls1 == ls2
 
             for hh in range(max_hit + 1):
                 if eos_token_id is not None and hits[hh] == eos_token_id[0]:
@@ -455,9 +434,6 @@ def sublist(lst1, lst2):
                     max_hit = hh
                     break
                 else:
-                    #
-                    #
-                    #
                     all_old_tokens.append(hits[hh])
 
             if chat:
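For context, the simplified block above relocates the KV-cache entries of the accepted lookahead guesses so they sit directly after the verified prefix, then truncates each layer's legacy (key, value) tuple to prefix + accepted length. The sketch below is an illustration only, not code from the repository: relocate_accepted_kv, guess_len, and hit_start are hypothetical stand-ins for len(guess_tokens) and hit_point * GUESS_SIZE, and the dummy tensors assume the usual (batch, heads, seq_len, head_dim) layout.

# Minimal sketch, assuming legacy tuple-format caches and made-up sizes.
import torch

def relocate_accepted_kv(past_key_values, kvcache_len, step_len, guess_len, hit_start, max_hit):
    """Copy the KV entries of max_hit accepted tokens next to the verified prefix, then trim."""
    trimmed = []
    for k, v in past_key_values:                 # one (key, value) pair per layer
        assert step_len == k.size(2)             # the cache covers the whole lookahead step
        src = step_len - guess_len + hit_start   # start of the accepted guess branch
        for hh in range(max_hit):
            k[:, :, kvcache_len + hh, :] = k[:, :, src + hh, :]
            v[:, :, kvcache_len + hh, :] = v[:, :, src + hh, :]
        trimmed.append((k[:, :, :kvcache_len + max_hit, :],
                        v[:, :, :kvcache_len + max_hit, :]))
    return trimmed

# Dummy example: 2 layers, 4 heads, head_dim 8, 10 verified positions out of a 20-position step.
pkv = [(torch.randn(1, 4, 20, 8), torch.randn(1, 4, 20, 8)) for _ in range(2)]
pkv = relocate_accepted_kv(pkv, kvcache_len=10, step_len=20, guess_len=8, hit_start=4, max_hit=2)
print(pkv[0][0].shape)  # torch.Size([1, 4, 12, 8])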
10 changes: 6 additions & 4 deletions lade/models/llama.py
@@ -214,9 +214,9 @@ def LlamaModeljforward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None
 
-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -243,7 +243,7 @@ def LlamaModeljforward(
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ -254,7 +254,9 @@ def LlamaModeljforward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = next_decoder_cache if use_cache else None
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
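These edits track the cache refactor introduced in transformers 4.36: each decoder layer now returns a single Cache object (for example DynamicCache) instead of a per-layer tuple accumulated with +=, and the model converts back to the legacy tuple format with to_legacy_cache() when use_legacy_cache is set. Below is a minimal sketch of that 4.36 cache_utils API, separate from the repository's code and assuming transformers==4.36.x with PyTorch installed.

# Sketch of the transformers 4.36 DynamicCache API with dummy tensors.
import torch
from transformers.cache_utils import DynamicCache  # available since transformers 4.36

# Build a Cache object from the legacy tuple-of-tuples format: one (key, value) pair per layer.
legacy = tuple((torch.randn(1, 4, 10, 8), torch.randn(1, 4, 10, 8)) for _ in range(2))
cache = DynamicCache.from_legacy_cache(legacy)

# A decoder layer appends its new key/value states in place and returns the same Cache
# object, which is what next_decoder_cache ends up holding after the layer loop.
new_k = torch.randn(1, 4, 1, 8)
new_v = torch.randn(1, 4, 1, 8)
cache.update(new_k, new_v, layer_idx=0)

print(cache.get_seq_length())        # 11: layer 0 now caches 10 + 1 positions
print(len(cache.to_legacy_cache()))  # 2: converted back to one (key, value) tuple per layer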
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-transformers==4.34.0
+transformers==4.36.2
 accelerate==0.23.0
 fschat==0.2.31
 openai
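If useful, a tiny optional check, not part of the repository, that the installed package matches the new pin before relying on the 4.36 cache behaviour:

# Optional sanity check, assuming dependencies were installed from requirements.txt.
from importlib.metadata import version

installed = version("transformers")
assert installed == "4.36.2", f"expected transformers==4.36.2, found {installed}"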
