2 changes: 2 additions & 0 deletions vllm_ascend/envs.py
@@ -36,6 +36,8 @@
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
"VLLM_ENABLE_MC2":
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
"VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER":
lambda: bool(int(os.getenv("VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER", '0'))),
"USING_LCCL_COM":
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
"SOC_VERSION":
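The new switch follows the same pattern as VLLM_ENABLE_MC2: a "0"/"1" string parsed into a bool, off by default. Because w8a8_dynamic.py below snapshots it into a module-level constant at import time, it has to be set before that module is imported; a minimal usage sketch (process-level, equivalent to exporting the variable in the launch shell):

import os

# Hypothetical usage sketch: enable the fixed all-to-all buffer path before
# vllm_ascend is imported so the module-level constant picks it up.
os.environ["VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER"] = "1"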
3 changes: 3 additions & 0 deletions vllm_ascend/ops/fused_moe.py
@@ -767,6 +767,9 @@ def __init__(
self.e_score_correction_bias = e_score_correction_bias
self.expert_map = None
self.activation = activation
self.max_model_len = vllm_config.model_config.max_model_len
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs * (
dp_size if dp_size is not None else get_dp_group().world_size)

if self.ep_size > 1:
# Create a tensor of size num_experts filled with -1
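The two attributes added to the MoE layer are consumed later by the fixed-buffer path to size its all-to-all buffers. A quick arithmetic check of global_batch_size, with assumed example values:

# Assumed example values, not recommendations: each DP rank schedules at
# most max_num_seqs sequences and there are dp_size data-parallel ranks.
max_num_seqs, dp_size = 32, 4
global_batch_size = max_num_seqs * dp_size  # -> 128 sequences across DP ranks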
207 changes: 207 additions & 0 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -27,6 +27,71 @@
from vllm_ascend.ops.fused_moe import select_experts

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER: bool = envs_ascend.VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER


def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
max_row_per_ep_rank: int, num_tokens: int,
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
original_total_elements = num_tokens * top_k
device = topk_ids.device
original_dtype = topk_ids.dtype

if original_total_elements == 0:
output_len = ep_size * max_row_per_ep_rank
topk_ids_pad = torch.full((output_len, ),
expert_num,
dtype=original_dtype,
device=device)
unpad_indices = torch.full((original_total_elements, ),
-1,
dtype=torch.long,
device=device)
return topk_ids_pad, unpad_indices

experts_per_ep_rank_val = expert_num // ep_size
if experts_per_ep_rank_val == 0:
raise ValueError(
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
"Ensure expert_num >= ep_size.")

assigned_ep_rank = (topk_ids.float() /
experts_per_ep_rank_val).to(original_dtype)
indices_arange = torch.arange(topk_ids.shape[0], device=device)

is_new_segment = torch.cat((torch.tensor([True], device=device),
assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
temp_start_markers = torch.full_like(indices_arange,
-1,
dtype=indices_arange.dtype)
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
indices_in_rec_cond_list_for_all = cumsum_kept - 1
unpad_indices = torch.where(
is_kept_mask, indices_in_rec_cond_list_for_all,
torch.tensor(-1, device=device, dtype=torch.long))
output_len = ep_size * max_row_per_ep_rank
topk_ids_pad = torch.full((output_len, ),
expert_num,
dtype=original_dtype,
device=device)
if topk_ids.shape[0] > 0:
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
temp_pad_buffer = torch.full((output_len + 1, ),
expert_num,
dtype=original_dtype,
device=device)
output_len_tensor = torch.tensor(output_len,
dtype=torch.long,
device=device)
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
output_len_tensor)
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
topk_ids_pad = temp_pad_buffer[:output_len]
return topk_ids_pad, unpad_indices
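A minimal CPU-side sketch of the padding/truncation behaviour above, using made-up values (expert_num, ep_size, max_row_per_ep_rank and the token counts are illustrative only):

import torch

# Tiny illustrative example: 4 experts split across 2 EP ranks, each rank
# granted a fixed window of 2 rows; topk_ids is already sorted by expert.
topk_ids = torch.tensor([0, 0, 1, 2, 3, 3])
topk_ids_pad, unpad_indices = process_topk_ids(
    topk_ids, expert_num=4, ep_size=2, max_row_per_ep_rank=2,
    num_tokens=3, top_k=2)
# The third row destined for each rank overflows its window and is dropped:
# topk_ids_pad  -> tensor([0, 0, 2, 3])   (unused slots would hold expert_num)
# unpad_indices -> tensor([0, 1, -1, 2, 3, -1])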


def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -325,6 +390,134 @@ def fused_experts_with_all2all(
return final_hidden_states


def fused_experts_with_all2all_with_fixed_buffer(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
max_model_len: int,
global_batch_size: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

num_tokens, _ = hidden_states.shape
device = hidden_states.device

global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
device=device).view(top_k,
-1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)

max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
max_model_len // ep_group.world_size +
1) * top_k * 2
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
expanded_expert_idx, global_num_experts, ep_group.world_size,
max_row_per_ep_rank, num_tokens, top_k)
hidden_states_pad_idx = torch.zeros(
expert_idx_buffer_scatter.shape,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
non_pad_len = torch.sum(
(expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
hidden_states_pad_idx[
expert_idx_buffer_scatter != global_num_experts] = torch.arange(
non_pad_len,
dtype=expert_idx_buffer_scatter.dtype,
device=hidden_states.device)

hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
expert_idx_buffer_gather = torch.empty_like(
expert_idx_buffer_scatter,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
hidden_states_buffer_gather = torch.empty_like(
hidden_states_buffer_scatter,
dtype=hidden_states_buffer_scatter.dtype,
device=hidden_states_buffer_scatter.device)
dist.all_to_all_single(expert_idx_buffer_gather,
expert_idx_buffer_scatter,
group=ep_group.device_group)
dist.all_to_all_single(hidden_states_buffer_gather,
hidden_states_buffer_scatter,
group=ep_group.device_group)
mask = expert_idx_buffer_gather != global_num_experts
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
global_num_experts // ep_group.world_size)
hidden_states = hidden_states_buffer_gather[mask]
idx_type = local_expert_idx.dtype
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
group_list_type = 0

hidden_states_wrapper = [hidden_states]
del hidden_states

hidden_states = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
w2_scale,
expert_tokens,
group_list_type=group_list_type)

resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
hidden_states = hidden_states[resorted_idx]
hidden_states_scatter = torch.zeros(
(mask.shape[0], hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states_scatter[mask] = hidden_states
hidden_states_gatter = torch.empty_like(
hidden_states_scatter,
dtype=hidden_states_scatter.dtype,
device=hidden_states_scatter.device)
dist.all_to_all_single(hidden_states_gatter,
hidden_states_scatter,
group=ep_group.device_group)
hidden_states_gatter = hidden_states_gatter[
expert_idx_buffer_scatter != global_num_experts]
if hidden_states_gatter.shape[0] != row_idx_len:
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states[unpad_indices != -1] = hidden_states_gatter
else:
hidden_states = hidden_states_gatter
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)

if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
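Note that max_row_per_ep_rank above is derived purely from scheduler limits (global batch size, max model length, top_k) and the EP world size, not from the live token count, so the tensors exchanged by dist.all_to_all_single keep the same shape on every step. A rough sizing sketch with assumed example values:

# Assumed example values, not recommendations.
global_batch_size, ep_world_size, max_model_len, top_k = 64, 8, 4096, 8
max_row_per_ep_rank = (-(-global_batch_size // ep_world_size) *
                       max_model_len // ep_world_size + 1) * top_k * 2
# -> (8 * 4096 // 8 + 1) * 8 * 2 = 65552 rows reserved per EP rank,
#    independent of how many tokens are actually in flight.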


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -644,6 +837,20 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER and expert_map is not None:
return fused_experts_with_all2all_with_fixed_buffer(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
max_model_len=layer.max_model_len,
global_batch_size=layer.global_batch_size,
expert_map=expert_map,
ep_group=self.ep_group)
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into fused_moe module.