[Bug Fix] Fix bug for triton 2.1.0 llama decode kernel and add llama2 decode attention kernel. #113

Merged · 40 commits · Aug 31, 2023

Commits (40):
2fd173c · support Qwen-7b model · Aug 10, 2023
b0e3701 · modify README.md to support Qwen-7b · Aug 10, 2023
ad16efd · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 10, 2023
b954e03 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 11, 2023
fa1afb7 · fix bug for start up llama2 · Aug 11, 2023
f109314 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 11, 2023
3887299 · update context_flashattention triton kernel to support V100 · Aug 11, 2023
e3d82e3 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 11, 2023
22d5a1e · fix Qwen-7b start error · Aug 11, 2023
94286b9 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 11, 2023
fcfff28 · Refactoring the implementation of throughput analysis logs · Aug 11, 2023
e710133 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 11, 2023
62c34d5 · fix Qwen-7b --tp 2 load weights bug · hiworldwzj · Aug 11, 2023
a5cbcbb · add support for qwen-7b 8k · hiworldwzj · Aug 11, 2023
af0bda7 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 12, 2023
2112539 · Merge branch 'ModelTC:main' into main · hiworldwzj · Aug 14, 2023
7a9992d · Merge branch 'main' of https://github.com/hiworldwzj/lightllm into main · Aug 17, 2023
914b41c · Refactor the code to make it easier to add models. · Aug 17, 2023
4fbba83 · use template to build · Aug 18, 2023
0dd72ac · support llama · Aug 18, 2023
a38b1a9 · support llama2 · Aug 18, 2023
3b07460 · support qwen · Aug 18, 2023
f05e009 · modify weight name · Aug 18, 2023
b4dba44 · update · Aug 18, 2023
b9a16c2 · Refactoring Chatglm · Aug 17, 2023
84b4c4f · refactor starcoder · Aug 18, 2023
e87af60 · solve conflict · Aug 18, 2023
a81ebab · refactor starcoder · Aug 18, 2023
f659b0a · refactor chatglm · Aug 18, 2023
19a5112 · rename the infer file · Aug 18, 2023
a94224d · add test case · Aug 18, 2023
d1b8b3b · update · Aug 18, 2023
27e559e · Merge branch 'main' of https://github.com/hiworldwzj/lightllm into main · Aug 21, 2023
61b4917 · fix bug for aborted req · Aug 21, 2023
7cedd31 · Merge branch 'main' of https://github.com/hiworldwzj/lightllm into main · Aug 22, 2023
7d8ebc7 · Merge branch 'main' of https://github.com/hiworldwzj/lightllm into main · Aug 24, 2023
60e1666 · Merge branch 'main' of https://github.com/hiworldwzj/lightllm into main · Aug 28, 2023
6e46502 · update llama kernel support for triton2.1.0, better performance · Aug 28, 2023
0292b1b · Merge branch 'main' of https://github.com/hiworldwzj/lightllm into main · Aug 31, 2023
6563e4f · fix bug for triton 2.1.0 llama decode kernel and add llama2 decode at… · Aug 31, 2023
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/basemodel.py
@@ -138,14 +138,14 @@ def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_l
infer_state.b_loc = b_loc
infer_state.b_start_loc = b_start_loc
infer_state.b_seq_len = b_seq_len
-        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True)

infer_state.mem_manager = self.mem_manager
infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num)
infer_state.prefill_key_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index)

+        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True)
predict_logics = self._context_forward(input_ids, infer_state)
return predict_logics

@@ -161,7 +161,6 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_lo
infer_state.b_seq_len = b_seq_len

infer_state.mem_manager = self.mem_manager
-        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False)

alloc_mem = self.mem_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
@@ -178,6 +177,7 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_lo
infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index

+        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False)
predict_logics = self._token_forward(input_ids, infer_state)
return predict_logics
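
Note: the reordering above appears to be the core of the fix. For llama and llama2, init_some_extra_state derives other_kv_index from b_loc[0, max_len_in_batch - 1] (see the infer_struct.py diff below), and in the decode path that slot is only written once the newly allocated decode_mem_index has been stored into b_loc. Below is a minimal sketch of the ordering constraint, with toy sizes and a made-up alloc() helper standing in for the real mem_manager:

import torch

# Toy decode-step setup; sizes and the alloc() helper are illustrative only.
batch_size, max_len_in_batch = 2, 8
b_loc = torch.zeros((batch_size, max_len_in_batch), dtype=torch.long)  # per-request cache-slot table

def alloc(n):
    # stand-in for mem_manager.alloc(batch_size): hand out n fresh cache slots
    return torch.arange(100, 100 + n, dtype=torch.long)

decode_mem_index = alloc(batch_size)
b_loc[:, max_len_in_batch - 1] = decode_mem_index  # slot of the token being decoded

# Only after that assignment does the read performed by init_some_extra_state see a real slot:
other_kv_index = b_loc[0, max_len_in_batch - 1].item()
assert other_kv_index == 100  # with the old ordering this would read an unwritten 0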

2 changes: 1 addition & 1 deletion lightllm/models/llama/infer_struct.py
@@ -29,5 +29,5 @@ def init_some_extra_state(self,
else:
self.position_cos = torch.index_select(model._cos_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
-            self.other_kv_index = b_loc[0, -1].item()
+            self.other_kv_index = b_loc[0, max_len_in_batch - 1].item()
return
44 changes: 30 additions & 14 deletions lightllm/models/llama2/layer_infer/transformer_layer_infer.py
@@ -2,6 +2,7 @@
import torch.functional as F
import torch.distributed as dist
import numpy as np
+import triton

from lightllm.models.llama2.layer_weights.transformer_layer_weight import Llama2TransformerLayerWeight
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd
@@ -48,19 +49,34 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo):
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch)

-        prob = torch.empty_like(att_m_tensor)
-        token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch)
-        att_m_tensor = None
-        o_tensor = torch.empty_like(q)
-        token_att_fwd2(prob,
-                       infer_state.mem_manager.value_buffer[self.layer_num_],
-                       o_tensor.view(calcu_shape1),
-                       infer_state.b_loc,
-                       infer_state.b_start_loc,
-                       infer_state.b_seq_len,
-                       infer_state.max_len_in_batch)
-        prob = None
-        return o_tensor
+        if triton.__version__ == "2.0.0":
+            prob = torch.empty_like(att_m_tensor)
+            token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch)
+            att_m_tensor = None
+            o_tensor = torch.empty_like(q)
+            token_att_fwd2(prob,
+                           infer_state.mem_manager.value_buffer[self.layer_num_],
+                           o_tensor.view(calcu_shape1),
+                           infer_state.b_loc,
+                           infer_state.b_start_loc,
+                           infer_state.b_seq_len,
+                           infer_state.max_len_in_batch)
+            prob = None
+            return o_tensor
+        elif triton.__version__ >= "2.1.0":
+            o_tensor = torch.empty_like(q)
+            from lightllm.models.llama2.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd
+            token_softmax_reducev_fwd(att_m_tensor,
+                                      infer_state.mem_manager.value_buffer[self.layer_num_],
+                                      o_tensor.view(calcu_shape1),
+                                      infer_state.b_loc,
+                                      infer_state.b_start_loc,
+                                      infer_state.b_seq_len,
+                                      infer_state.max_len_in_batch,
+                                      infer_state.other_kv_index)
+            return o_tensor
+        else:
+            raise Exception("not support triton version")
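
For Triton >= 2.1.0 the two-kernel decode path (token_softmax_fwd followed by token_att_fwd2) is replaced by the single fused token_softmax_reducev_fwd kernel added below, which avoids materializing the intermediate prob tensor that the 2.0.0 branch still allocates. As a reading aid only, here is a rough plain-PyTorch sketch of the math the fused step computes; the function name ref_softmax_reducev is made up, and the tensor layouts are inferred from the kernel's stride arguments rather than taken from the project's API:

import torch

def ref_softmax_reducev(logics, v, b_loc, b_start_loc, b_seq_len, max_input_len):
    # Assumed layouts (inferred from the kernel strides):
    #   logics: [num_heads, total_tokens_in_batch]  attention logits, one segment per request
    #   v:      [cache_size, num_kv_heads, head_dim] value cache
    #   b_loc:  [batch, max_input_len] cache-slot table, right-aligned per request
    batch = b_seq_len.shape[0]
    num_heads, head_dim = logics.shape[0], v.shape[-1]
    kv_group_num = num_heads // v.shape[1]
    out = torch.empty((batch, num_heads, head_dim), dtype=v.dtype, device=v.device)
    for b in range(batch):
        seq_len = int(b_seq_len[b])
        start = int(b_start_loc[b])
        # cache slots holding this request's tokens (right-aligned inside b_loc)
        slots = b_loc[b, max_input_len - seq_len:].long()
        for h in range(num_heads):
            p = torch.softmax(logics[h, start:start + seq_len].float(), dim=0)
            vals = v[slots, h // kv_group_num, :].float()   # [seq_len, head_dim]
            out[b, h] = (p[:, None] * vals).sum(dim=0).to(v.dtype)
    return out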
84 changes: 84 additions & 0 deletions lightllm/models/llama2/triton_kernel/token_attention_softmax_and_reducev.py
@@ -0,0 +1,84 @@
import torch

import triton
import triton.language as tl
import torch.nn.functional as F


@triton.jit
def _fwd_kernel(
Logics, V, Out,
B_Loc, B_Start_Loc, B_Seqlen, max_input_len,
stride_logic_h, stride_logic_bs,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
stride_b_loc_b, stride_b_loc_s,
other_kv_index, # to avoid reading NaN data
kv_group_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)

cur_kv_head = cur_head // kv_group_num

cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)

offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)

off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s

v_ptrs = V + off_v

e_max = float("-inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=other_kv_index)

qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
mask=start_n + offs_n < cur_batch_seq_len, other=float("-inf"))

n_e_max = tl.maximum(tl.max(qk, 0), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
e_max = n_e_max

acc = acc / e_sum
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
return


@torch.no_grad()
def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
grid = (batch, head)
kv_group_num = logics.shape[0] // v.shape[1]

num_warps = 1
_fwd_kernel[grid](
logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len,
logics.stride(0), logics.stride(1),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
b_loc.stride(0), b_loc.stride(1),
other_kv_index,
kv_group_num,
BLOCK_DMODEL=v.shape[-1],
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=3
)
return
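
A minimal way to exercise the new launcher on toy data, assuming a CUDA device, Triton >= 2.1.0, and the tensor layouts sketched above; every size below is arbitrary, and the right-aligned filling of b_loc mirrors what the kernel expects:

import torch
from lightllm.models.llama2.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

batch, num_heads, num_kv_heads, head_dim = 2, 8, 4, 64
max_input_len = 16
b_seq_len = torch.tensor([5, 9], dtype=torch.int32, device="cuda")
b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")
total_tokens = int(b_seq_len.sum())

cache_size = 128
v = torch.randn(cache_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
logics = torch.randn(num_heads, total_tokens, dtype=torch.float16, device="cuda")

# right-aligned cache-slot table; request 0 uses slots 0..4, request 1 uses slots 5..13
b_loc = torch.zeros(batch, max_input_len, dtype=torch.int64, device="cuda")
b_loc[0, max_input_len - 5:] = torch.arange(0, 5, device="cuda")
b_loc[1, max_input_len - 9:] = torch.arange(5, 14, device="cuda")
other_kv_index = int(b_loc[0, max_input_len - 1])  # any slot that is guaranteed valid

o = torch.empty(batch, num_heads, head_dim, dtype=torch.float16, device="cuda")
token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index)
print(o.shape)  # torch.Size([2, 8, 64])

Comparing o against the plain-PyTorch reference sketched after the transformer_layer_infer.py diff (within float16 tolerance) is a simple sanity check for the kernel.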