forked from InternLM/lmdeploy
refactor ascend kernels (InternLM#2355)
* support ascend using infer_ext
* fix(ascend): make infer_ext use TND-format q, k, v in paged_token_attention
* support ascend using infer_ext
* feat: support ascend moe_gating_topk_softmax
* feat: change infer_ext ops function param order (#2)
* ascend: align attention mask to 32 bytes (#7)
* fix attn args (#9)
* fix: expand shape of attn_mask (#10)
* feat: update infer_ext ops interface (#13)
* rename infer_ext to dlinfer
* format code
* Support internlm 2.5 (#14)
* refactor ascend pagedattention
* fix ascend apply_rotary_pos_emb
* fix import dlinfer (#16)
* fix: fix rms_norm params (#18)
* fix sync on ascend

---------

Co-authored-by: chenchiyu <chenchiyu@pjlab.org.cn>
Co-authored-by: CyCle1024 <ccy_justin@163.com>
Co-authored-by: Wei Tao <1136862851@qq.com>
Co-authored-by: jinminxi104 <jinminxi104@hotmail.com>
Co-authored-by: pdx1989 <pdx1989@gmail.com>
1 parent 47c7379 · commit 32c8580
Showing 17 changed files with 285 additions and 293 deletions.
This file was deleted.
lmdeploy/pytorch/kernels/ascend/__init__.py
@@ -1,12 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from ..dipu import (apply_rotary_pos_emb, fill_kv_cache, fused_rotary_emb,
-                    multinomial_sampling, paged_attention_fwd, rms_norm)
+from ..default import multinomial_sampling
+from .apply_rotary_pos_emb import apply_rotary_pos_emb
+from .fill_kv_cache import fill_kv_cache
+from .fused_rotary_emb import fused_rotary_emb
+from .moe_gating_topk_softmax import moe_gating_topk_softmax
+from .pagedattention import paged_attention_fwd
+from .rms_norm import rms_norm

 __all__ = [
     'rms_norm',
     'apply_rotary_pos_emb',
     'fused_rotary_emb',
     'fill_kv_cache',
     'paged_attention_fwd',
+    'moe_gating_topk_softmax',
     'multinomial_sampling',
 ]
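As a hedged illustration only (assuming lmdeploy with the dlinfer backend is installed on an Ascend environment), the refactor means each kernel now lives in its own module and is re-exported from the package __init__:

```python
# Hypothetical usage sketch, not part of this commit: after the refactor the
# ascend kernels are imported from their own modules via the package __init__.
from lmdeploy.pytorch.kernels.ascend import paged_attention_fwd, rms_norm
```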
lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py (new file)
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def apply_rotary_pos_emb(
    query_states: Tensor,
    key_states: Tensor,
    cos: Tensor,
    sin: Tensor,
    position_ids: Tensor,
    position_ids_1d: Tensor,
    q_embed=None,
    k_embed=None,
    context=None,
):
    bs, head, dim = query_states.shape
    num_kv_heads = key_states.shape[1]
    query_states_reshaped = query_states.reshape(1, bs, head, dim)
    key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
        if len(cos.shape) == 3 and len(sin.shape) == 3:
            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
        elif len(cos.shape) == 2 and len(sin.shape) == 2:
            cos = cos[position_ids_1d].view(1, bs, 1, -1)
            sin = sin[position_ids_1d].view(1, bs, 1, -1)
        else:
            raise RuntimeError('Cannot handle cos/sin shape dims!')

        if context:
            setattr(context, 'cos', cos)
            setattr(context, 'sin', sin)
    cached_cos = context.cos if context else cos
    cached_sin = context.sin if context else sin
    query_states, key_states = ext_ops.apply_rotary_pos_emb(
        query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
        None, None)
    query_states = query_states.view(bs, head, dim)
    key_states = key_states.view(bs, num_kv_heads, dim)
    if q_embed is None:
        q_embed = query_states
    else:
        q_embed.copy_(query_states)
    if k_embed is None:
        k_embed = key_states
    else:
        k_embed.copy_(key_states)
    return q_embed, k_embed
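A minimal shape sketch of calling this wrapper on the context-free path; an illustration under assumptions (dlinfer and an Ascend device are available, device placement is omitted, a 2-D cos/sin table is used, and the tensor sizes are made up):

```python
import torch

from lmdeploy.pytorch.kernels.ascend import apply_rotary_pos_emb

bs, heads, kv_heads, dim = 4, 32, 8, 128
q = torch.rand(bs, heads, dim, dtype=torch.float16)
k = torch.rand(bs, kv_heads, dim, dtype=torch.float16)
# 2-D cos/sin tables indexed by the flattened position ids (shape assumption).
cos = torch.rand(1024, dim, dtype=torch.float16)
sin = torch.rand(1024, dim, dtype=torch.float16)
position_ids_1d = torch.arange(bs)

# With context=None the wrapper gathers cos/sin by position_ids_1d itself.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, None, position_ids_1d)
```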
lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py (new file)
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def fill_kv_cache(
    key_states: Tensor,
    value_states: Tensor,
    key_caches: Tensor,
    value_caches: Tensor,
    q_start_loc: Tensor,
    q_seq_length: Tensor,
    kv_seq_length: Tensor,
    max_q_seq_length: int,
    block_offsets: Tensor,
    context: None,
):
    """fill key/value state to cache for paged attention."""
    ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches,
                          context.kv_start_indices)
lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py (new file)
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def fused_rotary_emb(
    query_states: Tensor,
    key_states: Tensor,
    position_ids: torch.LongTensor,
    inv_freq: Tensor,
    scaling_factor: float,
    out_q: Tensor = None,
    out_k: Tensor = None,
    context=None,
):
    batch, seqlen, head, dim = query_states.shape
    num_kv_heads = key_states.shape[-2]
    query_states_reshaped = query_states.view(batch, seqlen, head, dim)
    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)
    position_ids = position_ids.squeeze(0).unsqueeze(-1)
    pos_freq = position_ids / scaling_factor * inv_freq
    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
        cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1).repeat(
            1, 1, 1, 2).to(query_states.dtype))
        sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1).repeat(
            1, 1, 1, 2).to(query_states.dtype))
        if context:
            setattr(context, 'cos', cos)
            setattr(context, 'sin', sin)
    cached_cos = context.cos if context else cos
    cached_sin = context.sin if context else sin
    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
                                 cached_cos, cached_sin, None, None)
    if out_q is None:
        out_q = query_states
    else:
        out_q.copy_(query_states)
    if out_k is None:
        out_k = key_states
    else:
        out_k.copy_(key_states)
    return out_q, out_k
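The wrapper only consumes `inv_freq`, `position_ids` and `scaling_factor`; building the table is up to the caller. A hedged pure-PyTorch sketch of the usual base-10000 construction (an assumption for illustration, not code from this commit):

```python
import torch

dim, base, seqlen, scaling_factor = 128, 10000.0, 16, 1.0
# Inverse frequencies for each pair of rotary dimensions.
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim))
position_ids = torch.arange(seqlen).unsqueeze(0)                   # (1, seqlen)
pos_freq = position_ids.squeeze(0).unsqueeze(-1) / scaling_factor * inv_freq
cos = torch.cos(pos_freq).repeat(1, 2)   # mirrors the repeat(..., 2) above
sin = torch.sin(pos_freq).repeat(1, 2)
```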
lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py (new file, 10 additions, 0 deletions)
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def moe_gating_topk_softmax(router_logits: Tensor, topk: int):
    routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(
        router_logits, topk)
    return routing_weights.to(torch.float32), selected_experts.to(torch.int64)
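For intuition, a pure-PyTorch reference of the same gating step; this is an illustrative assumption about the op's intent (softmax over experts followed by top-k selection), not the dlinfer kernel itself:

```python
import torch


def moe_gating_topk_softmax_ref(router_logits: torch.Tensor, topk: int):
    # Softmax over the expert dimension, then pick the top-k experts per token.
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, selected_experts = routing_weights.topk(topk, dim=-1)
    return topk_weights, selected_experts.to(torch.int64)


# Example: 4 tokens routed over 8 experts, top-2 selection.
weights, experts = moe_gating_topk_softmax_ref(torch.randn(4, 8), topk=2)
```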
lmdeploy/pytorch/kernels/ascend/pagedattention.py (new file)
@@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def prefill_attention(
    query_states: Tensor,
    key_states: Tensor,
    value_states: Tensor,
    attn_output: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    block_size: int,
    kv_cache_len: int,
    context=None,
):
    num_q_heads, dim = query_states.shape[1:3]
    num_kv_heads = value_states.shape[1]

    if context.is_unpaged_prefill:
        ext_ops.prefill_attention(
            query_states,
            key_states,
            value_states,
            q_start_loc,
            q_seq_len,
            context.max_q_seq_length,
            num_q_heads,
            num_kv_heads,
            attn_mask=context.attention_mask,
            attn_output=attn_output,
        )
    else:
        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
        ext_ops.paged_prefill_attention(
            query_states,
            key_cache,
            value_cache,
            block_offsets,
            block_size,
            q_start_loc,
            q_seq_len,
            kv_seq_len,
            num_q_heads,
            num_kv_heads,
            attn_mask=context.attention_mask,
            attn_output=attn_output,
        )


def paged_decode_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           max_kv_seq_len, block_offsets, block_size):
    num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
    ext_ops.paged_decode_attention(
        q,
        k_cache,
        v_cache,
        block_offsets,
        block_size,
        kv_seq_len,
        max_kv_seq_len,
        num_q_heads,
        num_kv_heads,
        attn_output=attn_output.view(q.shape),
    )


def paged_attention_fwd(
    query_states: Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    attn_output: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seqlens: Tensor,
    kv_seqlens: Tensor,
    max_seqlen: int,
    window_size: int = 1,
    context=None,
):
    is_decoding = query_states.shape[-3] == q_seqlens.size(0)
    block_num, block_size, head, dim = key_cache.size()
    kv_cache_len = block_num * block_size
    k = key_cache.reshape(block_num * block_size, head, dim)
    v = value_cache.reshape(block_num * block_size, head, dim)
    if not is_decoding:
        prefill_attention(
            query_states,
            key_states,
            value_states,
            attn_output,
            k,
            v,
            block_offsets,
            q_start_loc,
            q_seqlens,
            kv_seqlens,
            block_size,
            kv_cache_len,
            context=context,
        )
    else:
        paged_decode_attention(
            query_states,
            k,
            v,
            attn_output,
            kv_seqlens,
            context.max_kv_seq_length,
            block_offsets,
            block_size,
        )
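The decode-vs-prefill dispatch rests on a shape check: during decoding every sequence contributes exactly one query token, so the token dimension of `query_states` equals the length of `q_seqlens`. A small self-contained illustration of that check (made-up sizes):

```python
import torch

q_seqlens = torch.tensor([1, 1, 1])      # three sequences, one new token each
query_states = torch.empty(3, 32, 128)   # (num_tokens, num_q_heads, head_dim)
is_decoding = query_states.shape[-3] == q_seqlens.size(0)
assert is_decoding                       # a prefill batch has more tokens than sequences
```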
lmdeploy/pytorch/kernels/ascend/rms_norm.py (new file)
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def rms_norm(hidden_states: Tensor,
             weight: Tensor,
             eps: float = 1e-6,
             out: Tensor = None):
    rms_norm_out = ext_ops.rms_norm(hidden_states, weight, eps)
    if out is None:
        out = rms_norm_out
    else:
        out.copy_(rms_norm_out)
    return out
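For comparison, a standard pure-PyTorch RMSNorm reference (the usual root-mean-square formulation; the wrapper above delegates to the dlinfer Ascend kernel instead):

```python
import torch


def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)
```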
This file was deleted.