[Bug Fix] Fix bug for triton 2.1.0 llama decode kernel and add llama2 decode attention kernel. (#113)
hiworldwzj authored Aug 31, 2023
1 parent 718e6d6 commit 2d3cd33
Showing 4 changed files with 117 additions and 17 deletions.
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/basemodel.py
@@ -138,14 +138,14 @@ def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_l
         infer_state.b_loc = b_loc
         infer_state.b_start_loc = b_start_loc
         infer_state.b_seq_len = b_seq_len
-        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True)
 
         infer_state.mem_manager = self.mem_manager
         infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num)
         infer_state.prefill_key_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
         infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
         init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index)
 
+        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True)
         predict_logics = self._context_forward(input_ids, infer_state)
         return predict_logics

@@ -161,7 +161,6 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_lo
         infer_state.b_seq_len = b_seq_len
 
         infer_state.mem_manager = self.mem_manager
-        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False)
 
         alloc_mem = self.mem_manager.alloc_contiguous(batch_size)
         if alloc_mem is not None:
@@ -178,6 +177,7 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_lo
             infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
             b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index
 
+        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False)
         predict_logics = self._token_forward(input_ids, infer_state)
         return predict_logics

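For context on this reordering: init_some_extra_state (see the llama infer_struct.py change below) reads other_kv_index out of b_loc, and the decode path only writes this step's freshly allocated KV-cache slot into b_loc[:, max_len_in_batch - 1] after allocation, so the call has to come after that write. A minimal sketch of the ordering, with made-up sizes and slot numbers:

# Minimal sketch (hypothetical sizes and slot values) of the decode-side ordering:
# the extra-state init must run only after this step's KV slots are written into b_loc.
import torch

batch_size, max_len_in_batch = 2, 8
b_loc = torch.zeros((batch_size, max_len_in_batch), dtype=torch.long)  # per-request KV slot table
decode_mem_index = torch.tensor([40, 41])                              # slots from the allocator (made up)

b_loc[:, max_len_in_batch - 1] = decode_mem_index        # what _decode writes before the moved call
other_kv_index = b_loc[0, max_len_in_batch - 1].item()   # what init_some_extra_state then reads
assert other_kv_index == 40                               # before the write this would still be 0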
2 changes: 1 addition & 1 deletion lightllm/models/llama/infer_struct.py
@@ -29,5 +29,5 @@ def init_some_extra_state(self,
         else:
             self.position_cos = torch.index_select(model._cos_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
             self.position_sin = torch.index_select(model._sin_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
-            self.other_kv_index = b_loc[0, -1].item()
+            self.other_kv_index = b_loc[0, max_len_in_batch - 1].item()
         return
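This index change matters if a b_loc row can be wider than the current batch's max_len_in_batch (an assumption inferred from the fix, not stated in the diff): token slots are laid out right-aligned ending at column max_len_in_batch - 1, so b_loc[0, -1] may read a column that was never written. Illustrative sketch:

# Illustrative only: when the b_loc row has spare columns beyond max_len_in_batch,
# index -1 and index max_len_in_batch - 1 are different cells.
import torch

max_len_in_batch, row_width = 5, 8                  # hypothetical sizes, row_width > max_len_in_batch
b_loc_row = torch.zeros(row_width, dtype=torch.long)
b_loc_row[max_len_in_batch - 1] = 73                # slot of the latest token (made-up value)

print(b_loc_row[-1].item())                         # 0  -> the old read lands on an unset column
print(b_loc_row[max_len_in_batch - 1].item())       # 73 -> the fixed read returns the allocated slot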
44 changes: 30 additions & 14 deletions lightllm/models/llama2/layer_infer/transformer_layer_infer.py
@@ -2,6 +2,7 @@
 import torch.functional as F
 import torch.distributed as dist
 import numpy as np
+import triton
 
 from lightllm.models.llama2.layer_weights.transformer_layer_weight import Llama2TransformerLayerWeight
 from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd
@@ -48,19 +49,34 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo):
                       infer_state.b_start_loc,
                       infer_state.b_seq_len,
                       infer_state.max_len_in_batch)
 
-        prob = torch.empty_like(att_m_tensor)
-        token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch)
-        att_m_tensor = None
+        if triton.__version__ == "2.0.0":
+            prob = torch.empty_like(att_m_tensor)
+            token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch)
+            att_m_tensor = None
 
-        o_tensor = torch.empty_like(q)
-
-        token_att_fwd2(prob,
-                       infer_state.mem_manager.value_buffer[self.layer_num_],
-                       o_tensor.view(calcu_shape1),
-                       infer_state.b_loc,
-                       infer_state.b_start_loc,
-                       infer_state.b_seq_len,
-                       infer_state.max_len_in_batch)
-        prob = None
-        return o_tensor
+            o_tensor = torch.empty_like(q)
+
+            token_att_fwd2(prob,
+                           infer_state.mem_manager.value_buffer[self.layer_num_],
+                           o_tensor.view(calcu_shape1),
+                           infer_state.b_loc,
+                           infer_state.b_start_loc,
+                           infer_state.b_seq_len,
+                           infer_state.max_len_in_batch)
+            prob = None
+            return o_tensor
+        elif triton.__version__ >= "2.1.0":
+            o_tensor = torch.empty_like(q)
+            from lightllm.models.llama2.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd
+            token_softmax_reducev_fwd(att_m_tensor,
+                                      infer_state.mem_manager.value_buffer[self.layer_num_],
+                                      o_tensor.view(calcu_shape1),
+                                      infer_state.b_loc,
+                                      infer_state.b_start_loc,
+                                      infer_state.b_seq_len,
+                                      infer_state.max_len_in_batch,
+                                      infer_state.other_kv_index)
+            return o_tensor
+        else:
+            raise Exception("not support triton version")
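The extra infer_state.other_kv_index argument gives the fused kernel (added below) a valid cache slot to load for masked-out positions. Masked logits are filled with -inf, so their softmax weight is zero, but a zero weight does not neutralize a garbage value read from an unallocated slot; a tiny illustration:

# Why a valid dummy index is needed even for fully masked positions:
# a zero softmax weight times a nan value is still nan and would poison the accumulator.
import torch

p = torch.exp(torch.tensor(float("-inf")))   # weight of a masked position: 0.0
v_garbage = torch.tensor(float("nan"))       # value loaded from an unallocated KV slot
print(p * v_garbage)                         # tensor(nan)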
84 changes: 84 additions & 0 deletions lightllm/models/llama2/triton_kernel/token_attention_softmax_and_reducev.py
@@ -0,0 +1,84 @@
import torch

import triton
import triton.language as tl
import torch.nn.functional as F


@triton.jit
def _fwd_kernel(
    Logics, V, Out,
    B_Loc, B_Start_Loc, B_Seqlen, max_input_len,
    stride_logic_h, stride_logic_bs,
    stride_vbs, stride_vh, stride_vd,
    stride_obs, stride_oh, stride_od,
    stride_b_loc_b, stride_b_loc_s,
    other_kv_index,  # avoid reading nan data
    kv_group_num,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
    off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s

    v_ptrs = V + off_v

    e_max = float("-inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        v_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=other_kv_index)

        qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
                     mask=start_n + offs_n < cur_batch_seq_len, other=float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max)
        e_sum = e_sum * old_scale + tl.sum(p, 0)
        v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
        e_max = n_e_max

    acc = acc / e_sum
    off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
    return


@torch.no_grad()
def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
    BLOCK = 64
    batch, head = b_seq_len.shape[0], logics.shape[0]
    grid = (batch, head)
    kv_group_num = logics.shape[0] // v.shape[1]

    num_warps = 1
    _fwd_kernel[grid](
        logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len,
        logics.stride(0), logics.stride(1),
        v.stride(0), v.stride(1), v.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        b_loc.stride(0), b_loc.stride(1),
        other_kv_index,
        kv_group_num,
        BLOCK_DMODEL=v.shape[-1],
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=3
    )
    return
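A rough way to sanity-check the fused kernel is to compare it against a plain PyTorch reference that performs the same per-(batch, head) softmax over the attention logits followed by a gather-and-reduce over V. The sketch below is not part of the commit; it assumes the shapes implied by the launcher above (logics is [head, total_token_num], v is [num_slots, kv_head, dim], b_loc is [batch, max_input_len] with sequences right-aligned), and the helper name is hypothetical:

# Hypothetical reference implementation for testing; not part of the committed file.
import torch

def token_softmax_reducev_ref(logics, v, b_loc, b_start_loc, b_seq_len, max_input_len):
    batch, head, dim = b_seq_len.shape[0], logics.shape[0], v.shape[-1]
    kv_group_num = head // v.shape[1]
    out = torch.empty((batch, head, dim), dtype=v.dtype, device=v.device)
    for b in range(batch):
        seq_len, start = int(b_seq_len[b]), int(b_start_loc[b])
        loc = b_loc[b, max_input_len - seq_len:max_input_len]             # right-aligned KV slots
        for h in range(head):
            p = torch.softmax(logics[h, start:start + seq_len].float(), dim=0)
            out[b, h] = (p[:, None] * v[loc, h // kv_group_num].float()).sum(0).to(v.dtype)
    return out

# Usage idea: build small random inputs on the GPU, run token_softmax_reducev_fwd and this
# reference, then check torch.allclose(o, ref, atol=1e-2, rtol=1e-2).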
