Skip to content
1 change: 1 addition & 0 deletions docs/source/user_guide/configuration/additional_config.md
@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `True` | Whether to run the shared experts in the data-parallel (DP) domain. This gives better performance but consumes more memory; when memory is tight, turn this switch off manually. |
Contributor Author: Added the note, please take a look. Thanks. @wangxiyuan

The details of each config option are as follows:

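For readers coming from the `enable_shared_expert_dp` row above, a minimal usage sketch is shown below. It assumes the `additional_config` plumbing that vllm-ascend documents on this page; the model name, parallel sizes, and the exact `LLM` constructor kwargs are illustrative assumptions, not part of this PR.

```python
# Hedged sketch: toggling shared-expert DP through additional_config.
# Model name and parallel sizes are placeholders only.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    tensor_parallel_size=4,
    enable_expert_parallel=True,            # shared-expert DP only applies with EP
    additional_config={
        # Default is True; set to False when device memory is tight.
        "enable_shared_expert_dp": False,
    },
)
```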
37 changes: 37 additions & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -690,3 +690,40 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once()

@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False

num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)

hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)

metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)
7 changes: 7 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -47,6 +47,9 @@ def __init__(self, vllm_config):
self.expert_map_path = additional_config.get("expert_map_path", None)
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", True
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel


class TorchairGraphConfig:
@@ -166,6 +169,10 @@ def check_ascend_config(vllm_config, enforce_eager):
raise NotImplementedError(
"Torchair graph mode only works with following model types:"
f"{TORCHAIR_MODEL_LIST}.")
if ascend_config.enable_shared_expert_dp:
logger.warning(
"enable_shared_expert_dp is not supported for torchair graph mode currently, "
"it has been disabled automatically.")
# aclgraph case
else:
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
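To make the interaction between the new flag, torchair graph mode, and expert parallelism explicit, here is a small self-contained sketch of the resolution logic in the diff above. The helper function is hypothetical; only the boolean logic mirrors `AscendConfig.__init__` and the warning in `check_ascend_config`.

```python
# Hedged sketch: how the effective enable_shared_expert_dp value is derived.
def resolve_enable_shared_expert_dp(additional_config: dict,
                                    torchair_graph_enabled: bool,
                                    enable_expert_parallel: bool) -> bool:
    requested = additional_config.get("enable_shared_expert_dp", True)
    return (requested
            and not torchair_graph_enabled
            and enable_expert_parallel)

# The user-facing default is True, but the flag is forced off (with a warning)
# when torchair graph mode is on, and it requires expert parallelism.
assert resolve_enable_shared_expert_dp({}, False, True) is True
assert resolve_enable_shared_expert_dp({}, True, True) is False
assert resolve_enable_shared_expert_dp({}, False, False) is False
```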
54 changes: 40 additions & 14 deletions vllm_ascend/attention/mla_v1.py
@@ -619,6 +619,7 @@ def __init__(
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
@@ -633,6 +634,8 @@ def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
if hasattr(self, "running_in_graph") and not self.running_in_graph:
return x
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
x,
@@ -903,14 +906,7 @@ def _forward_prefill(
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch

current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output, is_prefill=True)[0]
return attn_output

def exec_kv(
self,
@@ -1238,6 +1234,12 @@ def forward(
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=attn_metadata.slot_mapping)
if not self.running_in_graph:
o_proj_input_shape = (num_actual_toks,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
@@ -1248,11 +1250,12 @@
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
output[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
output[num_decode_tokens:] = output_prefill
o_proj_input[num_decode_tokens:] = output_prefill

if has_decode:
if self.running_in_graph:
@@ -1269,9 +1272,32 @@
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
output[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
o_proj_input[:num_decode_tokens] = output_decode
else:
output[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode

current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)

output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
return output_padded
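The net effect of this hunk is that the projection previously issued from `_forward_prefill` and `_v_up_proj_and_o_proj` is deferred: both branches now stage their attention output in `o_proj_input`, and a single `o_proj` call (with weight prefetch) produces the final output. A simplified, framework-free sketch of that staging follows, assuming `o_proj` returns an `(output, bias)` tuple like the row-parallel linears in this PR; the NPU prefetch and multistream-stream handling are omitted.

```python
import torch

def staged_o_proj(output_decode: torch.Tensor,
                  output_prefill: torch.Tensor,
                  o_proj,                          # stand-in for self.o_proj
                  enable_shared_expert_dp: bool) -> torch.Tensor:
    """Gather decode and prefill attention outputs, then project once."""
    num_decode_tokens = output_decode.shape[0]
    num_tokens = num_decode_tokens + output_prefill.shape[0]
    # Corresponds to the o_proj_input buffer allocated when not running in graph mode.
    o_proj_input = torch.empty(num_tokens, output_prefill.shape[-1],
                               dtype=output_prefill.dtype)
    o_proj_input[:num_decode_tokens] = output_decode
    o_proj_input[num_decode_tokens:] = output_prefill
    # One projection; is_force_scatter asks the linear layer to reduce-scatter
    # its result when shared-expert DP is enabled.
    return o_proj(o_proj_input,
                  is_prefill=True,
                  is_force_scatter=enable_shared_expert_dp)[0]
```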
80 changes: 69 additions & 11 deletions vllm_ascend/models/deepseek_v2.py
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -160,7 +161,13 @@ def forward(
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
num_tokens = output_parallel.shape[0]
if is_force_scatter and num_tokens % self.tp_size:
output_parallel = nn.functional.pad(
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
if is_force_scatter or (not is_prefill
and output_parallel.shape[0] % self.tp_size
== 0):
output = tensor_model_parallel_reduce_scatter(output_parallel,
dim=0)
else:
Expand All @@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -347,13 +355,15 @@ def __init__(
reduce_results = not self.all_reduce_merge
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.shared_experts = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=reduce_results,
force_replicate=self.enable_multistream_moe,
force_replicate=self.enable_multistream_moe
or enable_shared_expert_dp,
prefix=f"{prefix}.shared_experts",
)
else:
@@ -447,9 +457,11 @@ def __init__(
self.kv_lora_rank = kv_lora_rank

self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace

self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
@@ -462,6 +474,7 @@ def __init__(
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@ def __init__(
prefix=f"{prefix}.kv_b_proj")
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and
ascend_config.torchair_graph_config.enable_multistream_moe):
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe
or self.enable_shared_expert_dp)):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
@@ -596,13 +610,27 @@ def forward(
output = output.view(-1, output_shape[-1])
return output
else:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
hidden_states_or_q_c = get_tp_group().all_gather(
Collaborator: In the case of severe load imbalance between DPs, could the DP domain become stuck?
Contributor Author: No, the DP domain will not block here. This is an all_gather across all TP ranks within a single DP domain, so different DP domains do not affect each other.
hidden_states_or_q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)

kv_c, k_pe = kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
num_tokens = hidden_states_or_q_c.shape[0]
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
output_shape=output_shape)


class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -677,6 +705,8 @@ def __init__(
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

def forward(
self,
@@ -731,6 +761,18 @@ def forward(
# first layer.
residual *= 1. / self.routed_scaling_factor

tp_size = get_tensor_model_parallel_world_size()
if self.enable_shared_expert_dp and (
self.layer_idx == self.first_k_dense_replace
or self.layer_idx == self.layers) and tp_size > 1:
num_tokens, _ = residual.shape
if num_tokens % tp_size:
residual = nn.functional.pad(residual,
(0, 0, 0, -num_tokens % tp_size))
chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
residual = chunk_residual[tp_rank]

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
Expand All @@ -756,6 +798,22 @@ def forward(
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)

# for last layer of main model and mtp layer.
if self.enable_shared_expert_dp and self.layer_idx >= (
self.layers - 1) and tp_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
residual = get_tp_group().all_gather(residual, 0)

attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:
num_tokens = hidden_states.shape[0]

if num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:num_tokens]
residual = residual[:num_tokens]

return hidden_states, residual


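Both the `is_force_scatter` branch of the row-parallel linear and the residual handling in `CustomDeepseekV2DecoderLayer` rely on the same arithmetic: `-num_tokens % tp_size` is the number of rows needed to pad up to the next multiple of `tp_size`, after which each TP rank takes an equally sized chunk, matching the `ceil(num_tokens / tp_size)` row count used for the attention `output_shape`. A small standalone sketch, with the collectives left out:

```python
import torch
import torch.nn.functional as F

def pad_and_take_tp_chunk(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    """Pad rows to a multiple of tp_size, then keep only this rank's chunk.

    In the real model the padded tensor feeds a reduce-scatter (o_proj) or is
    sliced directly (residual); only the shape bookkeeping is shown here.
    """
    num_tokens = x.shape[0]
    pad = -num_tokens % tp_size          # 0 when already aligned
    if pad:
        x = F.pad(x, (0, 0, 0, pad))     # append `pad` zero rows
    return torch.tensor_split(x, tp_size, dim=0)[tp_rank]

# 10 tokens with tp_size=4 are padded to 12, so every rank sees 3 rows,
# which matches the ceil(10 / 4) == 3 rows used for output_shape above.
chunk = pad_and_take_tp_chunk(torch.randn(10, 8), tp_size=4, tp_rank=1)
assert chunk.shape == (3, 8)
```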
27 changes: 15 additions & 12 deletions vllm_ascend/ops/fused_moe.py
@@ -1262,6 +1262,7 @@ def __init__(
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe and \
self.torchair_graph_enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -1391,22 +1392,24 @@ def forward(self,
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size:
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]

chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]

if self.dp_size > 1:
@@ -1473,7 +1476,7 @@ def forward(self,
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
] and not replace_allreduce and not self.enable_shared_expert_dp):
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
@@ -1483,7 +1486,7 @@ def forward(self,
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
elif self.dp_size > 1:
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
if fused_moe_state == FusedMoEState.NaiveMulticast:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
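The MoE-side change is mainly about which collectives still run: with shared-expert DP enabled, hidden states and router logits are no longer padded, chunked per TP rank, or all-gathered afterwards, while the `mc2_mask` is still chunked. A hedged sketch of just that branching; the distributed calls are stubbed out, and only the conditions mirror `AscendFusedMoE.forward`.

```python
import torch

def split_moe_inputs(hidden_states: torch.Tensor,
                     router_logits: torch.Tensor,
                     mc2_mask: torch.Tensor,
                     tp_size: int, tp_rank: int,
                     enable_shared_expert_dp: bool):
    """Per-rank input slicing as gated in the diff above (sketch only)."""
    if tp_size > 1:
        if not enable_shared_expert_dp:
            # Original path: each TP rank computes on its own token chunk and
            # the results are all-gathered after the experts run.
            hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)[tp_rank]
            router_logits = torch.tensor_split(router_logits, tp_size, dim=0)[tp_rank]
        # The mc2 mask is chunked in both cases.
        mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)[tp_rank]
    return hidden_states, router_logits, mc2_mask
```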