
Commit 76c143f

feat: support data parallel for deepseek
Signed-off-by: boying <897013703@qq.com>
1 parent 52e0e99 commit 76c143f

6 files changed: +205 −68 lines


vllm_ascend/attention/mla_v1.py

Lines changed: 15 additions & 9 deletions
@@ -117,6 +117,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.

+    with_prefill_across_dp: bool = False
+
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -280,13 +282,16 @@ def build_dummy(self, num_reqs: int,
             decode=decode_metadata,
         )

-    def build(self,
-              num_reqs: int,
-              num_actual_tokens: int,
-              max_query_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              common_prefix_len: Optional[int] = None,
-              graph_pad_size: int = -1) -> AscendMLAMetadata:
+    def build(
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        max_query_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        common_prefix_len: Optional[int] = None,
+        graph_pad_size: int = -1,
+        with_prefill_across_dp: bool = False,
+    ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -388,6 +393,7 @@ def build(self,
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            with_prefill_across_dp=with_prefill_across_dp,
         )


@@ -621,7 +627,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
             cos,
@@ -643,7 +649,7 @@ def rope_single(
         B, N, D = x.shape
         S = 1
         x = x.view(B, N, S, D)
-        x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
+        x = torch_npu.npu_interleave_rope(x, cos, sin)
         return x.view(B, N, D)

     def _forward_decode(
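Note: the new with_prefill_across_dp field lets the attention metadata record whether any data-parallel rank still has prefill requests in the current step, so every DP rank can take the same MoE communication path (see deepseek_v2.py below). The runner-side code that computes the flag is not among the hunks shown on this page; the following is only a sketch of one way such a flag could be derived, assuming a CPU (gloo) process group spanning the DP ranks. any_rank_has_prefill is a hypothetical helper, not part of this commit.

import torch
import torch.distributed as dist


def any_rank_has_prefill(local_has_prefill: bool, dp_cpu_group) -> bool:
    # Illustration only: MAX-reduce a 0/1 flag so every DP rank sees 1
    # if at least one rank is prefilling in this step.
    flag = torch.tensor([int(local_has_prefill)], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=dp_cpu_group)
    return bool(flag.item())

The metadata builder would then receive the result through the new keyword argument, build(..., with_prefill_across_dp=...).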

vllm_ascend/models/deepseek_v2.py

Lines changed: 49 additions & 28 deletions
@@ -212,6 +212,14 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group

+        self.params_dtype = torch.get_default_dtype()
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -228,52 +236,65 @@ def forward(
         else:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
-        num_tokens, hidden_dim = hidden_states.shape
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
+
+        num_tokens, hidden_size = hidden_states.shape

         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)

         if self.tp_size > 1:
-            # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     self.tp_size,
-                                                     dim=0)
-            local_hidden_states = chunk_hidden_states[self.tp_rank]
-        else:
-            local_hidden_states = hidden_states
+            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+                hidden_states = chunks[self.tp_rank]
+            elif not self.enable_graph_mode:
+                num_padding_tokens = (self.tp_size -
                                      num_tokens % self.tp_size) % self.tp_size
+                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
+                if num_padding_tokens > 0:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, num_padding_tokens))
+                chunk_hidden_states = torch.tensor_split(hidden_states,
                                                         self.tp_size,
                                                         dim=0)
+                hidden_states = chunk_hidden_states[self.tp_rank]

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(local_hidden_states)
+        router_logits, _ = self.gate(hidden_states)

-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor

         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
-        else:
-            final_hidden_states = router_hidden_states
+            if self.enable_graph_mode:
+                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                    final_hidden_states = torch.zeros(
+                        [num_tokens, hidden_size],
+                        dtype=self.params_dtype,
+                        device="npu")
+                    dist.all_gather_into_tensor(final_hidden_states,
+                                                hidden_states, self.tp_group)
+                    hidden_states = final_hidden_states
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(
+                        hidden_states)
+            else:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                if num_padding_tokens > 0:
+                    hidden_states = hidden_states[:-num_padding_tokens]

         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_size)


 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
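Note: in the non-graph, non-MC2 TP path above, the token dimension is padded so it divides evenly across TP ranks (avoiding the cross-ring AllGatherV on 910B2C noted in the diff), split per rank, and un-padded after the gather. Below is a minimal, self-contained sketch of that pad/split/un-pad round trip; the distributed all_gather is replaced by a local concatenation purely for illustration.

import torch
import torch.nn.functional as F


def pad_split_roundtrip(hidden_states: torch.Tensor, tp_size: int) -> torch.Tensor:
    num_tokens = hidden_states.shape[0]
    # Pad the token dimension so it is divisible by tp_size.
    num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size
    if num_padding_tokens > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding_tokens))
    # Each rank would keep chunks[tp_rank]; the all_gather after the expert
    # computation is modeled here by concatenating the chunks back together.
    chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
    gathered = torch.cat(chunks, dim=0)
    if num_padding_tokens > 0:
        gathered = gathered[:-num_padding_tokens]
    return gathered


x = torch.randn(5, 16)  # 5 tokens is not divisible by tp_size=4
assert torch.equal(pad_split_roundtrip(x, tp_size=4), x)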

vllm_ascend/ops/fused_moe.py

Lines changed: 71 additions & 9 deletions
@@ -25,6 +25,7 @@
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
     determine_expert_map)
@@ -588,6 +589,12 @@ def __init__(self, moe: MoEConfig = None):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
@@ -665,7 +672,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif get_ep_group().world_size == 1:
+        elif self.enable_graph_mode or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
@@ -771,6 +778,13 @@ def __init__(

         self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                    None)
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
@@ -808,8 +822,15 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

+        self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
@@ -823,11 +844,32 @@ def forward(self,
         else:
             real_top_k = self.top_k

-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        #               MC2    ag/rs   broadcast/all_reduce
+        # prefill_req    x       x             √
+        # decode_req     √       x             √
+        # graph_mode     √       √             x
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = get_dp_group().all_gather(
+                        hidden_states, 0, False)
+                    router_logits = get_dp_group().all_gather(
+                        router_logits, 0, False)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+            else:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+                hidden_states = self.naive_multicast(
+                    hidden_states, cu_tokens_across_dp_cpu)
+                router_logits = self.naive_multicast(
+                    router_logits, cu_tokens_across_dp_cpu)

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -844,11 +886,31 @@ def forward(self,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance)

-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+            else:
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                    self.dp_rank - 1]
+                #print(cu_tokens_across_dp_cpu)
+                end = cu_tokens_across_dp_cpu[self.dp_rank]
+                hidden_states = get_dp_group().all_reduce(hidden_states)
+                hidden_states = hidden_states[start:end, :]

         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

-        return final_hidden_states
+        return hidden_states
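Note: the comment table added in forward() summarizes which data-parallel communication pattern each kind of step uses (MC2, all-gather/reduce-scatter, or broadcast/all-reduce). The sketch below restates that dispatch as a standalone helper; select_dp_comm is hypothetical and used only to make the three branches explicit, it is not part of the commit.

def select_dp_comm(dp_size: int, vllm_enable_mc2: bool,
                   enable_graph_mode: bool, is_prefill: bool) -> str:
    # Mirrors the branch order in forward(): MC2 first, then graph mode,
    # then the eager fallback via naive_multicast + all_reduce.
    if dp_size <= 1:
        return "none"
    if vllm_enable_mc2 and not is_prefill:
        return "mc2"
    if enable_graph_mode:
        return "all_gather/reduce_scatter"
    return "broadcast/all_reduce"


# Prefill requests never take the MC2 path; decode requests with MC2 do.
assert select_dp_comm(2, True, False, is_prefill=True) == "broadcast/all_reduce"
assert select_dp_comm(2, True, False, is_prefill=False) == "mc2"
assert select_dp_comm(2, False, True, is_prefill=True) == "all_gather/reduce_scatter"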

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 8 additions & 1 deletion
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+from vllm.config import get_current_vllm_config
 from vllm.distributed import GroupCoordinator

 import vllm_ascend.envs as envs_ascend
@@ -508,6 +509,12 @@ def __init__(self):

         self.ep_group = get_ep_group()

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = self.ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
@@ -629,7 +636,7 @@ def apply(
             top_k=top_k,
             expert_map=expert_map,
             moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif self.ep_group.world_size == 1:
+        elif self.enable_graph_mode or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
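Note: the additional_config lookup for enable_graph_mode is inlined at several __init__ sites in this commit (deepseek_v2.py, fused_moe.py, and here). As a sketch only, the pattern could be factored into a helper; read_enable_graph_mode is illustrative and not part of the commit.

from vllm.config import get_current_vllm_config


def read_enable_graph_mode() -> bool:
    # additional_config may be None or empty, in which case graph mode is off.
    additional_config = get_current_vllm_config().additional_config
    if additional_config:
        return additional_config.get("enable_graph_mode", False)
    return False

Each __init__ would then reduce to self.enable_graph_mode = read_enable_graph_mode().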
