feat: add sample_kv for CC method
niushengxiao committed Jan 7, 2025
1 parent a29f65d commit 2c2be5d
Showing 3 changed files with 159 additions and 281 deletions.
67 changes: 31 additions & 36 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -10,10 +10,8 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
-from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import (
-    context_attention_fwd_with_v,
-    context_attention_fwd_no_prompt_cache_with_v,
-)
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
+from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv
 
 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -137,11 +135,24 @@ def _context_attention_kernel_with_CC(
     ) -> torch.Tensor:
         if infer_state.use_dynamic_prompt_cache:
             kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            compressed_kv = self.alloc_tensor(
+                [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
+            )
+            k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
+            sample_kv(
+                kv,
+                compressed_kv,
+                k_rope,
+                infer_state.b_req_idx,
+                infer_state.b_seq_len,
+                infer_state.req_manager.req_to_token_indexs,
+            )
+        else:
+            compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
+                kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+            )
 
-        # CC
-        compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
-            kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
-        )
         k_nope = self.alloc_tensor(
             [compressed_kv.shape[0], q.shape[1], self.qk_nope_head_dim],
             dtype=compressed_kv.dtype,
@@ -158,35 +169,19 @@
 
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        if infer_state.use_dynamic_prompt_cache:
-            context_attention_fwd_with_v(
-                q_nope,
-                q_rope,
-                k_nope,
-                k_rope,
-                v,
-                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-                infer_state.b_req_idx,
-                infer_state.b_start_loc,
-                infer_state.b_seq_len,
-                infer_state.b_ready_cache_len,
-                infer_state.max_len_in_batch,
-                infer_state.req_manager.req_to_token_indexs,
-                self.softmax_scale,
-            )
-        else:
-            context_attention_fwd_no_prompt_cache_with_v(
-                q_nope,
-                q_rope,
-                k_nope,
-                k_rope,
-                v,
-                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-                infer_state.b_start_loc,
-                infer_state.b_seq_len,
-                infer_state.max_len_in_batch,
-                self.softmax_scale,
-            )
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
         return o_tensor
 
     def _context_attention_kernel_origin(
(Diffs for the other two changed files are not expanded in this view.)
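The new sample_kv kernel (lightllm/models/deepseek2/triton_kernel/sample_kv.py) is one of the three changed files but is not shown above. Judging from the call site, it gathers the current batch's latent KV entries out of the paged cache, via req_to_token_indexs, into contiguous compressed_kv and k_rope buffers, so the prefill attention kernel no longer has to index through the page table itself; that is why the updated context_attention_fwd_with_v call drops b_req_idx and req_to_token_indexs. Below is a minimal PyTorch sketch of the gather semantics the call site appears to expect; the function name sample_kv_reference, the shapes in the comments, and the sizes in the smoke test are illustrative assumptions, and the real kernel is a Triton implementation that may differ in layout and details.

# Hedged reference sketch of sample_kv's expected behavior (not the actual Triton kernel).
import torch


def sample_kv_reference(kv, compressed_kv, k_rope, b_req_idx, b_seq_len, req_to_token_indexs):
    # Assumed shapes:
    # kv:                   (num_cache_slots, 1, kv_lora_rank + qk_rope_head_dim) paged KV buffer
    # compressed_kv:        (total_token_num, 1, kv_lora_rank) output, contiguous over the batch
    # k_rope:               (total_token_num, 1, qk_rope_head_dim) output, contiguous over the batch
    # b_req_idx, b_seq_len: (batch,) request ids and prompt lengths
    # req_to_token_indexs:  (max_request_num, max_seq_len) page table mapping tokens to cache slots
    kv_lora_rank = compressed_kv.shape[-1]
    dest = 0
    for req_idx, seq_len in zip(b_req_idx.tolist(), b_seq_len.tolist()):
        token_locs = req_to_token_indexs[req_idx, :seq_len]  # cache slots for this request's tokens
        gathered = kv[token_locs]                             # (seq_len, 1, kv_lora_rank + rope)
        compressed_kv[dest : dest + seq_len] = gathered[:, :, :kv_lora_rank]
        k_rope[dest : dest + seq_len] = gathered[:, :, kv_lora_rank:]
        dest += seq_len
    return compressed_kv, k_rope


if __name__ == "__main__":
    # Tiny smoke test with made-up sizes (kv_lora_rank=512, qk_rope_head_dim=64 assumed).
    torch.manual_seed(0)
    kv = torch.randn(32, 1, 512 + 64)
    req_to_token_indexs = torch.arange(32).reshape(4, 8)
    b_req_idx = torch.tensor([0, 2])
    b_seq_len = torch.tensor([3, 5])
    total = int(b_seq_len.sum())
    compressed_kv = torch.empty(total, 1, 512)
    k_rope = torch.empty(total, 1, 64)
    sample_kv_reference(kv, compressed_kv, k_rope, b_req_idx, b_seq_len, req_to_token_indexs)

Gathering into contiguous buffers once per layer lets the same contiguous-input prefill kernel serve both the dynamic-prompt-cache and no-prompt-cache paths, which is consistent with context_attention_fwd_no_prompt_cache_with_v no longer being imported in this commit.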
