Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions tests/ut/quantization/test_w8a8_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all


class TestAscendW8A8FusedMoEMethod(TestBase):

def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)

@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_fused_experts_with_all2all(self, mock_moe_init_routing,
mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu,
mock_grouped_matmul,
mock_moe_re_routing,
mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder

result = fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))
129 changes: 81 additions & 48 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,29 @@ def fused_experts_with_mc2(
return hidden_states, shared_output


def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
num_tokens, _ = hidden_states.shape
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)

expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
1, 0).contiguous().view(-1))
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
global_expert_tokens = global_expert_tokens.to(torch.int32)
quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales


# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
Expand All @@ -358,50 +381,54 @@ def fused_experts_with_all2all(

num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
device = hidden_states.device

if expert_map is not None:
global_num_experts = len(expert_map) + global_redundant_expert_num
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)

global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-1).sum(-1)

gather_sizes = torch.empty_like(scatter_sizes)
dist.all_to_all_single(gather_sizes,
scatter_sizes,
group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
gather_size_list = gather_sizes.cpu().tolist()

expanded_expert_idx = expanded_expert_idx % local_num_experts
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
scatter_size_list,
gather_size_list)
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
scatter_size_list,
gather_size_list)

sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)

hidden_states = hidden_states[sorted_idx]
group_list_type = 0
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
hidden_states,
expert_idx=topk_ids.to(torch.int32),
active_num=0,
expert_capacity=0,
expert_num=global_num_experts,
drop_pad_mode=0,
expert_tokens_num_mode=2,
expert_tokens_before_capacity_flag=False,
quant_mode=1,
)
else:
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
hidden_states, top_k, topk_ids, global_num_experts)

gather_sizes = global_expert_tokens.new_empty(
global_expert_tokens.shape[0])
dist.all_to_all_single(gather_sizes, global_expert_tokens)

token_counts_combined = torch.stack(
[gather_sizes, global_expert_tokens], dim=0)
token_counts_combined = token_counts_combined.view(
2, ep_group.world_size, -1).sum(dim=2)
token_counts_combined_cpu = token_counts_combined.to(
torch.device("cpu"), non_blocking=True).numpy()
all_tokens = gather_sizes.sum()

gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
quantized_tokens.shape[1])
dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
gather_size_list = token_counts_combined_cpu[1]
scatter_size_list = token_counts_combined_cpu[0]

dist.all_to_all_single(gathered_tokens, quantized_tokens,
scatter_size_list, gather_size_list)
dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
gather_size_list)

hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
gathered_tokens,
gather_sizes.view(ep_group.world_size, -1),
per_token_scales=dynamic_scale)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
Expand All @@ -419,6 +446,7 @@ def fused_experts_with_all2all(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
dynamic_scale = None

# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(
Expand All @@ -428,14 +456,19 @@ def fused_experts_with_all2all(
w2,
w2_scale,
expert_tokens, #16
dynamic_scale=dynamic_scale,
group_list_type=group_list_type)

if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
hidden_states = hidden_states[resorted_idx]
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
gather_size_list,
scatter_size_list)
reordered_outputs = torch.index_select(
hidden_states,
dim=0,
# Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
index=inverse_indices.to(torch.float32).argsort().to(torch.int32))

hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
dist.all_to_all_single(hidden_states, reordered_outputs,
gather_size_list, scatter_size_list)

final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
Expand All @@ -444,8 +477,8 @@ def fused_experts_with_all2all(
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
export_for_source_row=None,
drop_pad_mode=2)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
Expand Down
Loading