refactor ascend kernels (InternLM#2355)
* support ascend using infer_ext

* fix(ascend): make infer_ext use TND format q, k, v in paged_token_attention

* support ascend using infer_ext

* feat: support ascend moe_gating_topk_softmax

* feat: change infer_ext ops function param order (#2)

* ascend: align attention mask to 32bytes (#7)

* fix attn args (#9)

* fix: expand shape of attn_mask (#10)

* feat: update infer_ext ops interface (#13)

* rename infer_ext to dlinfer

* format code

* Support internlm 2.5 (#14)

* refactor ascend pagedattention

* fix ascend apply_rotary_pos_emb

* fix import dlinfer (#16)

* fix: fix rms_norm params (#18)

* fix sync on ascend

---------

Co-authored-by: chenchiyu <chenchiyu@pjlab.org.cn>
Co-authored-by: CyCle1024 <ccy_justin@163.com>
Co-authored-by: Wei Tao <1136862851@qq.com>
Co-authored-by: jinminxi104 <jinminxi104@hotmail.com>
Co-authored-by: pdx1989 <pdx1989@gmail.com>
6 people authored Aug 30, 2024
1 parent 47c7379 commit 32c8580
Showing 17 changed files with 285 additions and 293 deletions.
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/check_env/__init__.py
@@ -30,7 +30,7 @@ def check_env_deeplink(device_type: str):
    if device_type in deeplink_device_type_list:
        logger = get_logger('lmdeploy')
        try:
-            import deeplink_ext  # noqa: F401
+            import dlinfer.framework.lmdeploy_ext  # noqa: F401
        except Exception as e:
            _handle_exception(e, 'PyTorch', logger)

16 changes: 12 additions & 4 deletions lmdeploy/pytorch/engine/devices/ascend.py
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

-from .dipu import DIPUDeviceUtils
+from .base_device_utils import BaseDeviceUtils


-class ASCENDDeviceUtils(DIPUDeviceUtils):
+class ASCENDDeviceUtils(BaseDeviceUtils):

    device = 'ascend'

@@ -17,7 +17,8 @@ def update_step_context(cls, step_context):
            single_attention_mask = torch.logical_not(
                torch.tril(
                    torch.ones(step_context.q_seq_length[i],
-                               step_context.kv_seq_length[i],
+                               step_context.block_offsets.shape[1] *
+                               block_size,
                               dtype=torch.bool).cuda(),
                    diagonal=step_context.kv_seq_length[i] -
                    step_context.q_seq_length[i],
@@ -28,7 +29,7 @@ def update_step_context(cls, step_context):
            block_loc = step_context.block_offsets[i][block_idx]
            token_loc = history_length % block_size
            for _ in range(step_context.q_seq_length[i]):
-                kv_start_indices.append(block_loc * block_size + token_loc)
+                kv_start_indices.append([block_loc * block_size + token_loc])
                if _ == step_context.q_seq_length[i] - 1:
                    break
                token_loc = (token_loc + 1) % block_size
@@ -38,4 +39,11 @@ def update_step_context(cls, step_context):
            kv_start_indices, device=step_context.block_offsets.device)
        setattr(step_context, 'kv_start_indices', kv_start_indices)
        setattr(step_context, 'attention_mask', attention_mask)
+        setattr(step_context, 'q_start_loc', step_context.q_start_loc.cpu())
+        setattr(step_context, 'q_seq_length', step_context.q_seq_length.cpu())
+        setattr(step_context, 'kv_seq_length',
+                step_context.kv_seq_length.cpu())
+        is_unpaged_prefill = (not step_context.is_decoding) and all(
+            (step_context.q_seq_length == step_context.kv_seq_length).tolist())
+        setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
        return step_context
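Note: the kv_start_indices built above map each new token to one flat slot of the paged KV cache. The snippet below reproduces that slot arithmetic stand-alone in plain PyTorch; it is an illustrative sketch with made-up toy sizes, not part of this commit.

import torch

# Toy paged cache: 4 blocks of 16 slots, 2 KV heads, head dim 8 (made-up sizes).
block_num, block_size, num_kv_heads, dim = 4, 16, 2, 8
key_cache = torch.zeros(block_num, block_size, num_kv_heads, dim)

# A token that logically sits in block `block_loc` at offset `token_loc` lands in
# flat slot `block_loc * block_size + token_loc` once the cache is viewed as
# (block_num * block_size, num_kv_heads, dim), which is how the kernels below use it.
block_loc, token_loc = 2, 5
flat_cache = key_cache.view(block_num * block_size, num_kv_heads, dim)
flat_cache[block_loc * block_size + token_loc] = torch.randn(num_kv_heads, dim)

# `view` aliases the same storage, so the write is visible in the blocked layout too.
assert key_cache[block_loc, token_loc].abs().sum() > 0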
13 changes: 0 additions & 13 deletions lmdeploy/pytorch/engine/devices/dipu.py

This file was deleted.

10 changes: 8 additions & 2 deletions lmdeploy/pytorch/kernels/ascend/__init__.py
@@ -1,12 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from ..dipu import (apply_rotary_pos_emb, fill_kv_cache, fused_rotary_emb,
-                    multinomial_sampling, paged_attention_fwd, rms_norm)
+from ..default import multinomial_sampling
+from .apply_rotary_pos_emb import apply_rotary_pos_emb
+from .fill_kv_cache import fill_kv_cache
+from .fused_rotary_emb import fused_rotary_emb
+from .moe_gating_topk_softmax import moe_gating_topk_softmax
+from .pagedattention import paged_attention_fwd
+from .rms_norm import rms_norm

__all__ = [
    'rms_norm',
    'apply_rotary_pos_emb',
    'fused_rotary_emb',
    'fill_kv_cache',
    'paged_attention_fwd',
+    'moe_gating_topk_softmax',
    'multinomial_sampling',
]
49 changes: 49 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def apply_rotary_pos_emb(
    query_states: Tensor,
    key_states: Tensor,
    cos: Tensor,
    sin: Tensor,
    position_ids: Tensor,
    position_ids_1d: Tensor,
    q_embed=None,
    k_embed=None,
    context=None,
):
    bs, head, dim = query_states.shape
    num_kv_heads = key_states.shape[1]
    query_states_reshaped = query_states.reshape(1, bs, head, dim)
    key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
        if len(cos.shape) == 3 and len(sin.shape) == 3:
            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
        elif len(cos.shape) == 2 and len(sin.shape) == 2:
            cos = cos[position_ids_1d].view(1, bs, 1, -1)
            sin = sin[position_ids_1d].view(1, bs, 1, -1)
        else:
            raise RuntimeError('Cannot handle cos/sin shape dims!')

        if context:
            setattr(context, 'cos', cos)
            setattr(context, 'sin', sin)
    cached_cos = context.cos if context else cos
    cached_sin = context.sin if context else sin
    query_states, key_states = ext_ops.apply_rotary_pos_emb(
        query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
        None, None)
    query_states = query_states.view(bs, head, dim)
    key_states = key_states.view(bs, num_kv_heads, dim)
    if q_embed is None:
        q_embed = query_states
    else:
        q_embed.copy_(query_states)
    if k_embed is None:
        k_embed = key_states
    else:
        k_embed.copy_(key_states)
    return q_embed, k_embed
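Note: for intuition about what ext_ops.apply_rotary_pos_emb computes, the snippet below sketches the common rotate-half formulation of rotary embedding in plain PyTorch. It is an assumption for illustration only (the Ascend op may use a different internal layout), and the names rotate_half and rope_reference are made up here.

import torch

def rotate_half(x):
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(q, k, cos, sin):
    # q, k: (..., dim); cos and sin are broadcastable to the same shape.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# Toy call with shapes matching the wrapper above: (1, bs, heads, dim).
bs, heads, dim = 3, 4, 8
q = torch.randn(1, bs, heads, dim)
k = torch.randn(1, bs, 1, dim)
theta = torch.rand(1, bs, 1, dim)
q_embed, k_embed = rope_reference(q, k, theta.cos(), theta.sin())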
20 changes: 20 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def fill_kv_cache(
    key_states: Tensor,
    value_states: Tensor,
    key_caches: Tensor,
    value_caches: Tensor,
    q_start_loc: Tensor,
    q_seq_length: Tensor,
    kv_seq_length: Tensor,
    max_q_seq_length: int,
    block_offsets: Tensor,
    context: None,
):
    """fill key/value state to cache for paged attention."""
    ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches,
                          context.kv_start_indices)
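Note: conceptually the op scatters each incoming key/value token into the cache slot named by context.kv_start_indices. The plain-PyTorch sketch below illustrates that behaviour with toy shapes; it is not the dlinfer implementation.

import torch

num_slots, num_kv_heads, dim = 8, 2, 4
key_cache = torch.zeros(num_slots, num_kv_heads, dim)
value_cache = torch.zeros(num_slots, num_kv_heads, dim)

# Three new tokens and their destination slots, shaped (num_tokens, 1)
# to match the indices built in engine/devices/ascend.py above.
key_states = torch.randn(3, num_kv_heads, dim)
value_states = torch.randn(3, num_kv_heads, dim)
kv_start_indices = torch.tensor([[5], [6], [7]])

key_cache[kv_start_indices[:, 0]] = key_states
value_cache[kv_start_indices[:, 0]] = value_states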
45 changes: 45 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def fused_rotary_emb(
    query_states: Tensor,
    key_states: Tensor,
    position_ids: torch.LongTensor,
    inv_freq: Tensor,
    scaling_factor: float,
    out_q: Tensor = None,
    out_k: Tensor = None,
    context=None,
):
    batch, seqlen, head, dim = query_states.shape
    num_kv_heads = key_states.shape[-2]
    query_states_reshaped = query_states.view(batch, seqlen, head, dim)
    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)
    position_ids = position_ids.squeeze(0).unsqueeze(-1)
    pos_freq = position_ids / scaling_factor * inv_freq
    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
        cos = (torch.cos(pos_freq).view(batch, seqlen, 1,
                                        -1).repeat(1, 1, 1,
                                                   2).to(query_states.dtype))
        sin = (torch.sin(pos_freq).view(batch, seqlen, 1,
                                        -1).repeat(1, 1, 1,
                                                   2).to(query_states.dtype))
        if context:
            setattr(context, 'cos', cos)
            setattr(context, 'sin', sin)
    cached_cos = context.cos if context else cos
    cached_sin = context.sin if context else sin
    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
                                 cached_cos, cached_sin, None, None)
    if out_q is None:
        out_q = query_states
    else:
        out_q.copy_(query_states)
    if out_k is None:
        out_k = key_states
    else:
        out_k.copy_(key_states)
    return out_q, out_k
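Note: the cos/sin tables above are derived from inv_freq in the standard way. The stand-alone sketch below reproduces that computation with toy shapes (batch 1, seqlen 4, rotary dim 8); the base of 10000 in inv_freq is the conventional RoPE choice and is assumed here, not taken from this diff.

import torch

batch, seqlen, dim = 1, 4, 8
position_ids = torch.arange(seqlen).unsqueeze(0)                     # (batch, seqlen)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)
scaling_factor = 1.0

pos_freq = position_ids.squeeze(0).unsqueeze(-1) / scaling_factor * inv_freq
cos = torch.cos(pos_freq).view(batch, seqlen, 1, -1).repeat(1, 1, 1, 2)
sin = torch.sin(pos_freq).view(batch, seqlen, 1, -1).repeat(1, 1, 1, 2)
assert cos.shape == (batch, seqlen, 1, dim)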
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def moe_gating_topk_softmax(router_logits: Tensor, topk: int):
    routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(
        router_logits, topk)
    return routing_weights.to(torch.float32), selected_experts.to(torch.int64)
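Note: as a rough reference for the gating semantics, the sketch below shows one common formulation (softmax over all experts, then top-k) in plain PyTorch. Whether the Ascend op fuses or reorders these steps, or renormalizes the selected weights, is not stated in this diff, so treat the snippet as an assumption; the name moe_gating_topk_softmax_reference is made up.

import torch

def moe_gating_topk_softmax_reference(router_logits, topk):
    # Softmax over the expert dimension, then keep the top-k experts per token.
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    return routing_weights, selected_experts.to(torch.int64)

# Toy call: 5 tokens routed over 8 experts, top-2.
weights, experts = moe_gating_topk_softmax_reference(torch.randn(5, 8), 2)
assert weights.shape == (5, 2) and experts.shape == (5, 2)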
120 changes: 120 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/pagedattention.py
@@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def prefill_attention(
    query_states: Tensor,
    key_states: Tensor,
    value_states: Tensor,
    attn_output: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    block_size: int,
    kv_cache_len: int,
    context=None,
):
    num_q_heads, dim = query_states.shape[1:3]
    num_kv_heads = value_states.shape[1]

    if context.is_unpaged_prefill:
        ext_ops.prefill_attention(
            query_states,
            key_states,
            value_states,
            q_start_loc,
            q_seq_len,
            context.max_q_seq_length,
            num_q_heads,
            num_kv_heads,
            attn_mask=context.attention_mask,
            attn_output=attn_output,
        )
    else:
        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
        ext_ops.paged_prefill_attention(
            query_states,
            key_cache,
            value_cache,
            block_offsets,
            block_size,
            q_start_loc,
            q_seq_len,
            kv_seq_len,
            num_q_heads,
            num_kv_heads,
            attn_mask=context.attention_mask,
            attn_output=attn_output,
        )


def paged_decode_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           max_kv_seq_len, block_offsets, block_size):
    num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
    ext_ops.paged_decode_attention(
        q,
        k_cache,
        v_cache,
        block_offsets,
        block_size,
        kv_seq_len,
        max_kv_seq_len,
        num_q_heads,
        num_kv_heads,
        attn_output=attn_output.view(q.shape),
    )


def paged_attention_fwd(
    query_states: Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    attn_output: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seqlens: Tensor,
    kv_seqlens: Tensor,
    max_seqlen: int,
    window_size: int = 1,
    context=None,
):
    is_decoding = query_states.shape[-3] == q_seqlens.size(0)
    block_num, block_size, head, dim = key_cache.size()
    kv_cache_len = block_num * block_size
    k = key_cache.reshape(block_num * block_size, head, dim)
    v = value_cache.reshape(block_num * block_size, head, dim)
    if not is_decoding:
        prefill_attention(
            query_states,
            key_states,
            value_states,
            attn_output,
            k,
            v,
            block_offsets,
            q_start_loc,
            q_seqlens,
            kv_seqlens,
            block_size,
            kv_cache_len,
            context=context,
        )
    else:
        paged_decode_attention(
            query_states,
            k,
            v,
            attn_output,
            kv_seqlens,
            context.max_kv_seq_length,
            block_offsets,
            block_size,
        )
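Note: the is_decoding check works because during decode every running sequence contributes exactly one new query token, so the number of query rows equals the number of sequences. A toy illustration with made-up shapes:

import torch

# Prefill: three sequences of lengths 4, 3 and 3 packed into 10 query tokens.
prefill_q_seqlens = torch.tensor([4, 3, 3])
prefill_query = torch.randn(10, 32, 128)   # (num_tokens, num_heads, head_dim)
assert prefill_query.shape[-3] != prefill_q_seqlens.size(0)   # prefill path

# Decode: the same three sequences each produce one new query token.
decode_q_seqlens = torch.tensor([1, 1, 1])
decode_query = torch.randn(3, 32, 128)
assert decode_query.shape[-3] == decode_q_seqlens.size(0)     # decode path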
15 changes: 15 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/rms_norm.py
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def rms_norm(hidden_states: Tensor,
             weight: Tensor,
             eps: float = 1e-6,
             out: Tensor = None):
    rms_norm_out = ext_ops.rms_norm(hidden_states, weight, eps)
    if out is None:
        out = rms_norm_out
    else:
        out.copy_(rms_norm_out)
    return out
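Note: as a reminder of the math behind the op, a standard RMSNorm in plain PyTorch looks roughly like the sketch below. The float32 accumulation is a common choice assumed here; the diff does not say how the Ascend op accumulates internally, and rms_norm_reference is an illustrative name.

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(hidden_states.dtype) * weight

hidden = torch.randn(2, 5, 16, dtype=torch.float16)
out = rms_norm_reference(hidden, torch.ones(16, dtype=torch.float16))
assert out.shape == hidden.shape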
16 changes: 0 additions & 16 deletions lmdeploy/pytorch/kernels/dipu/__init__.py

This file was deleted.
