37 changes: 37 additions & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -683,3 +683,40 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once()

@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False

num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)

hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)

metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)
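
Note: the new test drives the prefill-only branch of `AscendMLAImpl.forward` with the `torch_npu` kernels patched out, so it runs without NPU hardware. A small runner sketch, assuming the repo root as the working directory (not part of this PR):

```python
# Hypothetical runner sketch: execute the MLA unit tests in-process via pytest.
# Assumes the vllm-ascend repo root is the current working directory.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["tests/ut/attention/test_mla_v1.py", "-q"]))
```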
70 changes: 70 additions & 0 deletions tests/ut/quantization/test_w8a8_dynamic.py
@@ -0,0 +1,70 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all


class TestAscendW8A8FusedMoEMethod(TestBase):

def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens, self.hidden_size)

@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu,
mock_grouped_matmul,
mock_moe_re_routing,
mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder

fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
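
Note: the test keeps `fused_experts_with_all2all` entirely on CPU by stubbing every `torch_npu` kernel and emulating `torch.distributed.all_to_all_single` as a plain buffer copy. A standalone sketch of that stubbing trick (single process, illustrative only):

```python
# Sketch: on one process, all_to_all_single degenerates to copying the input
# buffer into the output buffer, which is what the lambda side_effect above emulates.
from unittest.mock import patch

import torch
import torch.distributed  # make the submodule importable before patching


def _fake_all_to_all_single(output, input, *args, **kwargs):
    output.copy_(input)


with patch("torch.distributed.all_to_all_single",
           side_effect=_fake_all_to_all_single) as mocked:
    src = torch.arange(8, dtype=torch.float32)
    dst = torch.empty_like(src)
    torch.distributed.all_to_all_single(dst, src)
    assert torch.equal(dst, src)
    mocked.assert_called_once()
```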
2 changes: 2 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -47,6 +47,8 @@ def __init__(self, vllm_config):
self.expert_map_path = additional_config.get("expert_map_path", None)
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
self.enable_prefill_optimizations = additional_config.get(
"enable_prefill_optimizations", False)


class TorchairGraphConfig:
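
Note: the new flag rides on `additional_config` like the existing ones. A minimal usage sketch, assuming the standard vLLM `additional_config` plumbing (the engine-construction code below is not part of this PR):

```python
# Hypothetical usage sketch: pass the new flag through additional_config.
# mla_v1.py below silently disables it whenever torchair graph mode is enabled.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model name
    additional_config={
        "enable_prefill_optimizations": True,
    },
)
```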
55 changes: 42 additions & 13 deletions vllm_ascend/attention/mla_v1.py
@@ -587,6 +587,8 @@ def __init__(
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_prefill_optimizations = \
ascend_config.enable_prefill_optimizations and not self.torchair_graph_enabled

# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
@@ -601,6 +603,8 @@ def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
if hasattr(self, "running_in_graph") and not self.running_in_graph:
return x
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
x,
@@ -871,14 +875,7 @@ def _forward_prefill(
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch

current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output, is_prefill=True)[0]
return attn_output

def exec_kv(
self,
@@ -1208,6 +1205,12 @@ def forward(
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=attn_metadata.slot_mapping)
if not self.running_in_graph:
o_proj_input_shape = (num_actual_toks,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
@@ -1218,11 +1221,12 @@
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
output[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
output[num_decode_tokens:] = output_prefill
o_proj_input[num_decode_tokens:] = output_prefill

if has_decode:
if self.running_in_graph:
@@ -1239,9 +1243,34 @@
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
output[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
else:
output[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode

current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)

_o = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_prefill_optimizations)[0]
output[...] = _o
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_prefill_optimizations)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
return output_padded
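
Note: with these changes `forward()` collects both decode and prefill attention into a single `o_proj_input` buffer and applies `o_proj` once at the end; when `enable_prefill_optimizations` forces the scatter path, each TP rank keeps only its slice of the padded token dimension. A pure-Python sketch of that shape math (no NPU calls; the padding mirrors `nn.functional.pad` in deepseek_v2.py below):

```python
# Shape-math sketch: rows each TP rank holds after o_proj when the force-scatter
# path pads the token dimension to a multiple of tp_size and reduce-scatters.
def o_proj_rows_per_rank(num_tokens: int, tp_size: int, force_scatter: bool) -> int:
    if not force_scatter or tp_size == 1:
        return num_tokens  # plain all-reduce keeps the full token dimension per rank
    padded = num_tokens + (-num_tokens % tp_size)  # pad up to a multiple of tp_size
    return padded // tp_size


assert o_proj_rows_per_rank(100, 4, force_scatter=True) == 25
assert o_proj_rows_per_rank(101, 4, force_scatter=True) == 26
assert o_proj_rows_per_rank(101, 4, force_scatter=False) == 101
```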
85 changes: 74 additions & 11 deletions vllm_ascend/models/deepseek_v2.py
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -160,7 +161,13 @@ def forward(
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
num_tokens = output_parallel.shape[0]
if is_force_scatter and num_tokens % self.tp_size:
output_parallel = nn.functional.pad(
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
if is_force_scatter or (not is_prefill
and output_parallel.shape[0] % self.tp_size
== 0):
output = tensor_model_parallel_reduce_scatter(output_parallel,
dim=0)
else:
Expand All @@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -347,13 +355,16 @@ def __init__(
reduce_results = not self.all_reduce_merge
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
enable_prefill_optimizations = \
ascend_config.enable_prefill_optimizations and not ascend_config.torchair_graph_config.enabled
self.shared_experts = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=reduce_results,
force_replicate=self.enable_multistream_moe,
force_replicate=self.enable_multistream_moe
or enable_prefill_optimizations,
prefix=f"{prefix}.shared_experts",
)
else:
@@ -447,9 +458,9 @@ def __init__(
self.kv_lora_rank = kv_lora_rank

self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size

self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
@@ -462,6 +473,8 @@ def __init__(
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
self.enable_prefill_optimizations = \
ascend_config.enable_prefill_optimizations and not self.torchair_graph_enabled

if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@ def __init__(
prefix=f"{prefix}.kv_b_proj")
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and
ascend_config.torchair_graph_config.enable_multistream_moe):
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe
or self.enable_prefill_optimizations)):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
@@ -596,13 +610,33 @@ def forward(
output = output.view(-1, output_shape[-1])
return output
else:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.enable_prefill_optimizations and self.debug_layer_idx > 3 and self.debug_layer_idx < 61:
hidden_states_or_q_c = get_tp_group().all_gather(
hidden_states_or_q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:
num_tokens = hidden_states_or_q_c.shape[0]

kv_c, k_pe = kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.enable_prefill_optimizations or self.debug_layer_idx < 3:
output_shape = hidden_states.shape
else:
num_tokens = hidden_states_or_q_c.shape[0]
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
output_shape=output_shape)


class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -677,6 +711,9 @@ def __init__(
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
self.enable_prefill_optimizations = \
ascend_config.enable_prefill_optimizations and not ascend_config.torchair_graph_config.enabled

def forward(
self,
@@ -731,6 +768,17 @@ def forward(
# first layer.
residual *= 1. / self.routed_scaling_factor

tp_size = get_tensor_model_parallel_world_size()
if self.enable_prefill_optimizations and (
self.layer_idx == 3 or self.layer_idx == 61) and tp_size > 1:
num_tokens, _ = residual.shape
if num_tokens % tp_size:
residual = nn.functional.pad(residual,
(0, 0, 0, -num_tokens % tp_size))
chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
residual = chunk_residual[tp_rank]

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
@@ -756,6 +804,21 @@
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)

# for last layer of main model and mtp layer.
if self.enable_prefill_optimizations and self.layer_idx >= 60 and tp_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
residual = get_tp_group().all_gather(residual, 0)

attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:
num_tokens = hidden_states.shape[0]

if num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:num_tokens]
residual = residual[:num_tokens]

return hidden_states, residual


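
Note: between layer 3 and layer 61 the sequence dimension stays sharded across TP ranks: the residual is padded up to a multiple of `tp_size`, split with `torch.tensor_split`, and restored later by `all_gather` followed by a slice back to `num_actual_tokens`. A CPU-only round-trip sketch (the collective is emulated by concatenating per-rank chunks; the helper names are made up for illustration):

```python
# Round-trip sketch of the pad -> tensor_split -> all_gather -> slice pattern above.
# get_tp_group().all_gather(x, 0) is emulated by torch.cat over the per-rank chunks.
import torch
import torch.nn.functional as F


def shard(residual: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    num_tokens = residual.shape[0]
    if num_tokens % tp_size:
        residual = F.pad(residual, (0, 0, 0, -num_tokens % tp_size))
    return torch.tensor_split(residual, tp_size, dim=0)[tp_rank]


def unshard(chunks: list, num_tokens: int) -> torch.Tensor:
    return torch.cat(chunks, dim=0)[:num_tokens]


residual = torch.randn(101, 16)
chunks = [shard(residual, 4, rank) for rank in range(4)]
assert all(chunk.shape == (26, 16) for chunk in chunks)
assert torch.equal(unshard(chunks, 101), residual)
```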