1 change: 1 addition & 0 deletions vllm_ascend/core/scheduler.py
@@ -50,6 +50,7 @@ def __init__(
include_finished_set, log_stats)
self.scheduled_req_ids: set[str] = set()
self.running: list[Request] = []
self.include_finished_set = include_finished_set

if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer:
43 changes: 43 additions & 0 deletions vllm_ascend/models/deepseek_v2.py
@@ -216,6 +216,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # During a profile run, force the experts to load balanced tokens
        # to avoid high memory consumption on a single rank.
        # TODO: need a better flag to indicate whether we are in a profile run.

        # ep=1, etp=tp: expert parallelism is disabled, so take the
        # expert-tensor-parallel (ETP) path.
if self.experts.ep_group.world_size == 1:
return self.forward_etp(hidden_states)

if attn_metadata is None:
# for profile run
is_prefill = True
@@ -270,6 +275,44 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

return final_hidden_states.view(num_tokens, hidden_dim)

def forward_etp(self,
hidden_states: torch.Tensor,
is_prefill: bool = False) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill and self.tp_size > 1:
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
hidden_states = chunks[self.tp_rank_in_group]

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor

if self.tp_size > 1:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
final_hidden_states = torch.zeros([num_tokens, hidden_dim],
dtype=self.params_dtype,
device="npu")
dist.all_gather_into_tensor(final_hidden_states, hidden_states,
self.tp_group)
hidden_states = final_hidden_states
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

if self.n_shared_experts is not None:
hidden_states = hidden_states + shared_output

return hidden_states.view(num_tokens, hidden_dim)


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):

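Note (not part of the diff): a minimal, single-process sketch of the token bookkeeping forward_etp relies on in the MC2 decode path above. The input is chunked across TP ranks along the token dimension, each rank runs its experts on its own chunk, and an all-gather restores the full [num_tokens, hidden_dim] tensor. tp_size, fake_expert, and the shapes are illustrative assumptions; torch.cat stands in for dist.all_gather_into_tensor.

import torch

# Illustrative stand-ins; the real code shards tokens across NPU ranks and
# restores them with dist.all_gather_into_tensor over the TP group.
tp_size = 4
num_tokens, hidden_dim = 16, 32          # num_tokens divisible by tp_size here
hidden_states = torch.randn(num_tokens, hidden_dim)

def fake_expert(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for self.experts(...) running on one rank.
    return x * 2.0

# Each rank keeps only its own chunk of the tokens (MC2 decode path).
chunks = torch.chunk(hidden_states, tp_size, dim=0)
per_rank_out = [fake_expert(chunks[rank]) for rank in range(tp_size)]

# all_gather_into_tensor concatenates the per-rank results back into a
# [num_tokens, hidden_dim] tensor; torch.cat plays that role on one process.
final_hidden_states = torch.cat(per_rank_out, dim=0)
assert final_hidden_states.shape == (num_tokens, hidden_dim)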
108 changes: 100 additions & 8 deletions vllm_ascend/ops/fused_moe.py
@@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py

import os
from typing import Callable, Optional

import torch
@@ -25,6 +26,7 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
determine_expert_map)
@@ -329,6 +331,7 @@ def fused_experts(
num_experts = w1.shape[0]
dtype = hidden_states.dtype
device = hidden_states.device
topk_weights = topk_weights.to(dtype)
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
# ], "Only float32, float16, and bfloat16 are supported"

@@ -761,16 +764,17 @@ def __init__(
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group

else:
            # Adjust TP size for DP attention.
            # Functionality not tested yet; may be removed in the future.

self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
self.moe_parallel_config.ep_rank = 0
self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
self.moe_parallel_config.ep_size = 1

self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
self.local_num_experts, self.expert_map = (self.global_num_experts,
None)

self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
@@ -808,6 +812,7 @@ def __init__(
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size

self.ep_group = get_ep_group()
self.quant_method.create_weights(layer=self, **moe_quant_params)

def forward(self,
@@ -817,6 +822,10 @@ def forward(self,
enable_force_load_balance: bool = False,
top_k=None):
assert self.quant_method is not None
        # ep=1, etp=tp: expert parallelism is disabled, so take the
        # expert-tensor-parallel (ETP) path.
if self.ep_group.world_size == 1:
return self.forward_etp(hidden_states, router_logits, is_prefill,
top_k)

if top_k:
real_top_k = top_k
@@ -852,3 +861,86 @@ def forward(self,
final_hidden_states)

return final_hidden_states

def forward_etp(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
top_k=None):
assert self.quant_method is not None

if top_k:
real_top_k = top_k
else:
real_top_k = self.top_k

        # Which communication scheme is used for each case when dp_size > 1:
        #               MC2   ag/rs   broadcast/all_reduce
        # prefill_req    x      x             √
        # decode_req     √      x             √
        # graph_mode     √      √             x
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif int(os.environ.get("USING_LCCL_COM",
'0')) == 1: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
router_logits, 0, False)
elif self.enable_graph_mode and not is_prefill:
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
else:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_dp_cpu)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill)

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif int(os.environ.get("USING_LCCL_COM",
'0')) == 1: # type: ignore
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
final_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
elif self.enable_graph_mode and not is_prefill:
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
final_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
else:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
all_hidden_states = get_dp_group().all_reduce(
final_hidden_states)
final_hidden_states = all_hidden_states[start:end, :]

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states
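
Note (not part of the diff): a small sketch of the cumulative-token bookkeeping used by the non-MC2 DP branch above. cu_tokens_across_dp_cpu holds cumulative token counts across DP ranks; after the experts run on the multicast batch and the results are all-reduced, each rank slices its own rows back out with the start/end computation shown in forward_etp. The token counts below are made-up example values.

import torch

# Example cumulative token counts for 4 DP ranks
# (rank 0 owns 3 tokens, rank 1 owns 5, rank 2 owns 2, rank 3 owns 4).
cu_tokens_across_dp_cpu = torch.tensor([3, 8, 10, 14])

def local_slice(dp_rank: int) -> tuple[int, int]:
    # Mirrors the start/end computation in forward_etp above.
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    return start, end

assert [local_slice(r) for r in range(4)] == [(0, 3), (3, 8), (8, 10), (10, 14)]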
1 change: 1 addition & 0 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -340,6 +340,7 @@ def fused_experts(hidden_states: torch.Tensor,
num_experts = w1.shape[0]
dtype = hidden_states.dtype
device = hidden_states.device
topk_weights = topk_weights.to(dtype)

if expert_map is not None:
# Generate token indices and flatten
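Note (not part of the diff): the added cast keeps the routing weights in the activation dtype. The rationale below is an assumption about the NPU fused-MoE kernels rather than something stated in the diff: the gate's softmax typically produces float32 weights while activations are bfloat16/float16, so casting avoids a dtype mismatch when expert outputs are weighted.

import torch

hidden_states = torch.randn(8, 64, dtype=torch.bfloat16)
router_logits = torch.randn(8, 16)                      # float32 gate output
topk_weights, topk_ids = torch.softmax(router_logits, dim=-1).topk(2, dim=-1)

# Without the cast, topk_weights stays float32 while expert outputs are
# bfloat16; the cast keeps the weighting in the activation dtype.
topk_weights = topk_weights.to(hidden_states.dtype)
assert topk_weights.dtype == torch.bfloat16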