diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 3be9299a0..d2a3140f8 100644
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -138,7 +138,6 @@ def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_l
         infer_state.b_loc = b_loc
         infer_state.b_start_loc = b_start_loc
         infer_state.b_seq_len = b_seq_len
-        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True)
 
         infer_state.mem_manager = self.mem_manager
         infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num)
@@ -146,6 +145,7 @@
         infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
         init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index)
+        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True)
         predict_logics = self._context_forward(input_ids, infer_state)
         return predict_logics
@@ -161,7 +161,6 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_lo
         infer_state.b_seq_len = b_seq_len
         infer_state.mem_manager = self.mem_manager
-        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False)
         alloc_mem = self.mem_manager.alloc_contiguous(batch_size)
         if alloc_mem is not None:
@@ -178,6 +177,7 @@
         infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
         b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index
+        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False)
         predict_logics = self._token_forward(input_ids, infer_state)
         return predict_logics
diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py
index f2d641b61..09a3aa43e 100644
--- a/lightllm/models/llama/infer_struct.py
+++ b/lightllm/models/llama/infer_struct.py
@@ -29,5 +29,5 @@ def init_some_extra_state(self,
         else:
             self.position_cos = torch.index_select(model._cos_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
             self.position_sin = torch.index_select(model._sin_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
-            self.other_kv_index = b_loc[0, -1].item()
+            self.other_kv_index = b_loc[0, max_len_in_batch - 1].item()
         return
diff --git a/lightllm/models/llama2/layer_infer/transformer_layer_infer.py b/lightllm/models/llama2/layer_infer/transformer_layer_infer.py
index 68e7797b4..3b994d76a 100644
--- a/lightllm/models/llama2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama2/layer_infer/transformer_layer_infer.py
@@ -2,6 +2,7 @@
 import torch.functional as F
 import torch.distributed as dist
 import numpy as np
+import triton
 
 from lightllm.models.llama2.layer_weights.transformer_layer_weight import Llama2TransformerLayerWeight
 from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd
@@ -48,19 +49,34 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo):
                       infer_state.b_start_loc,
                       infer_state.b_seq_len,
                       infer_state.max_len_in_batch)
+
+        if triton.__version__ == "2.0.0":
+            prob = torch.empty_like(att_m_tensor)
+            token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch)
+            att_m_tensor = None
-        prob = torch.empty_like(att_m_tensor)
-        token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch)
-        att_m_tensor = None
+            o_tensor = torch.empty_like(q)
-        o_tensor = torch.empty_like(q)
-
-        token_att_fwd2(prob,
-                       infer_state.mem_manager.value_buffer[self.layer_num_],
-                       o_tensor.view(calcu_shape1),
-                       infer_state.b_loc,
-                       infer_state.b_start_loc,
-                       infer_state.b_seq_len,
-                       infer_state.max_len_in_batch)
-        prob = None
-        return o_tensor
+            token_att_fwd2(prob,
+                           infer_state.mem_manager.value_buffer[self.layer_num_],
+                           o_tensor.view(calcu_shape1),
+                           infer_state.b_loc,
+                           infer_state.b_start_loc,
+                           infer_state.b_seq_len,
+                           infer_state.max_len_in_batch)
+            prob = None
+            return o_tensor
+        elif triton.__version__ >= "2.1.0":
+            o_tensor = torch.empty_like(q)
+            from lightllm.models.llama2.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd
+            token_softmax_reducev_fwd(att_m_tensor,
+                                      infer_state.mem_manager.value_buffer[self.layer_num_],
+                                      o_tensor.view(calcu_shape1),
+                                      infer_state.b_loc,
+                                      infer_state.b_start_loc,
+                                      infer_state.b_seq_len,
+                                      infer_state.max_len_in_batch,
+                                      infer_state.other_kv_index)
+            return o_tensor
+        else:
+            raise Exception("unsupported triton version")
diff --git a/lightllm/models/llama2/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/llama2/triton_kernel/token_attention_softmax_and_reducev.py
new file mode 100644
index 000000000..7f4208389
--- /dev/null
+++ b/lightllm/models/llama2/triton_kernel/token_attention_softmax_and_reducev.py
@@ -0,0 +1,84 @@
+import torch
+
+import triton
+import triton.language as tl
+import torch.nn.functional as F
+
+
+@triton.jit
+def _fwd_kernel(
+    Logics, V, Out,
+    B_Loc, B_Start_Loc, B_Seqlen, max_input_len,
+    stride_logic_h, stride_logic_bs,
+    stride_vbs, stride_vh, stride_vd,
+    stride_obs, stride_oh, stride_od,
+    stride_b_loc_b, stride_b_loc_s,
+    other_kv_index,  # avoid reading NaN data
+    kv_group_num,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_head = tl.program_id(1)
+
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
+    off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s
+
+    v_ptrs = V + off_v
+
+    e_max = float("-inf")
+    e_sum = 0.0
+    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=other_kv_index)
+
+        qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
+                     mask=start_n + offs_n < cur_batch_seq_len, other=float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max)
+        e_sum = e_sum * old_scale + tl.sum(p, 0)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
+        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
+        e_max = n_e_max
+
+    acc = acc / e_sum
+    off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc)
+    return
+
+
+@torch.no_grad()
+def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
+    BLOCK = 64
+    batch, head = b_seq_len.shape[0], logics.shape[0]
+    grid = (batch, head)
+    kv_group_num = logics.shape[0] // v.shape[1]
+
+    num_warps = 1
+    _fwd_kernel[grid](
+        logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len,
+        logics.stride(0), logics.stride(1),
+        v.stride(0), v.stride(1), v.stride(2),
+        o.stride(0), o.stride(1), o.stride(2),
+        b_loc.stride(0), b_loc.stride(1),
+        other_kv_index,
+        kv_group_num,
+        BLOCK_DMODEL=v.shape[-1],
+        BLOCK_N=BLOCK,
+        num_warps=num_warps,
+        num_stages=3
+    )
+    return
\ No newline at end of file
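# --- Illustrative usage (not part of the diff) ---------------------------------
# A minimal smoke test for the new fused softmax + reduce-V kernel introduced above.
# It assumes a CUDA device and triton >= 2.1.0; the shapes, values, and variable
# names below are made up for illustration and do not come from lightllm itself.
import torch
from lightllm.models.llama2.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

batch, q_head, kv_head, head_dim = 2, 4, 2, 64   # kv_group_num = q_head // kv_head = 2
b_seq_len = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
max_input_len = int(b_seq_len.max())
total_token_num = int(b_seq_len.sum())
b_start_loc = torch.tensor([0, 3], dtype=torch.int32, device="cuda")

# One attention logit per (query head, cached token), as produced by token_att_fwd.
logics = torch.randn((q_head, total_token_num), dtype=torch.float32, device="cuda")
# Value cache laid out as (cache_slots, kv_head, head_dim), like mem_manager.value_buffer[layer].
v_cache = torch.randn((total_token_num, kv_head, head_dim), dtype=torch.float16, device="cuda")
o = torch.empty((batch, q_head, head_dim), dtype=torch.float16, device="cuda")

# b_loc holds cache-slot indices, right-aligned per row (see off_b_loc in the kernel).
b_loc = torch.zeros((batch, max_input_len), dtype=torch.int64, device="cuda")
b_loc[0, max_input_len - 3:] = torch.arange(0, 3, device="cuda")
b_loc[1, :] = torch.arange(3, 8, device="cuda")
other_kv_index = int(b_loc[0, max_input_len - 1])  # any valid slot; mirrors infer_struct.py

token_softmax_reducev_fwd(logics, v_cache, o, b_loc, b_start_loc, b_seq_len,
                          max_input_len, other_kv_index)
print(o.shape)  # torch.Size([2, 4, 64])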