
Commit 724cb7f

nikhil-arm authored and gjc0824 committed
[fix]: add Arm 4bit fused moe support (#23809)
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: gaojc <1055866782@qq.com>
1 parent 5be6509 commit 724cb7f

7 files changed, 488 insertions(+), 11 deletions(-)

cmake/cpu_extension.cmake

Lines changed: 2 additions & 1 deletion
@@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
-    "csrc/cpu/torch_bindings.cpp")
+    "csrc/cpu/torch_bindings.cpp"
+    "csrc/moe/dynamic_4bit_int_moe_cpu.cpp")

 if (AVX512_FOUND AND NOT AVX512_DISABLED)
   set(VLLM_EXT_SRC

csrc/cpu/torch_bindings.cpp

Lines changed: 10 additions & 0 deletions
@@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " int tp_rank, int blocksparse_local_blocks,"
       " int blocksparse_vert_stride, int blocksparse_block_size,"
       " int blocksparse_head_sliding_step) -> ()");
+
   ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);

+  ops.def(
+      "dynamic_4bit_int_moe("
+      "Tensor x, Tensor topk_ids, Tensor topk_weights,"
+      "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
+      "int group_size, bool apply_router_weight_on_input, int activation_kind"
+      ") -> Tensor");
+
+  ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
+
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
csrc/moe/dynamic_4bit_int_moe_cpu.cpp

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>

// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif

inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
                        int64_t group_size_eff, int64_t in_features,
                        int64_t out_features) {
#if defined(__aarch64__)
  return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
                                                in_features, out_features);
#else
  TORCH_CHECK(false,
              "dynamic 4-bit int MoE path requires AArch64 (ARM64); "
              "_dyn_quant_matmul_4bit is unavailable on this architecture");
  return {};
#endif
}

enum ActivationKind : int64_t {
  SwiGLU_Gu = 0,  // act = SiLU(g) * u
  SwiGLUOAI = 1,  // clamped GPT-OSS SwiGLU variant (see branch below)
  SiLU = 2        // SiLU
};

torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
    int64_t activation_kind) {
  TORCH_CHECK(x.dim() == 2, "x must be 2D");
  TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
              "topk tensors must be [T, K]");
  TORCH_CHECK(
      w13_packed.size(0) == w2_packed.size(0),
      "w13_packed and w2_packed must have same number of experts in dim 0");
  TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");

  const int64_t T = x.size(0);
  const int64_t K = topk_ids.size(1);
  const int64_t E = w13_packed.size(0);
  const int64_t N = T * K;

  auto x_c = x.contiguous();
  auto ids_c = topk_ids.contiguous();
  auto gates_c = topk_weights.to(at::kFloat).contiguous();

  // Bucket tokens -> experts: count the (token, k) pairs owned by each expert.
  c10::SmallVector<int64_t, 64> counts(
      E, 0);  // SmallVector uses stack allocation for small E
  {
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    for (int64_t i = 0; i < N; ++i) {
      const int64_t e_id = ids_ptr[i];
      TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
      counts[e_id]++;
    }
  }
  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);  // prefix sums (E + 1 entries)
  for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];

  auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
  auto expert_gates = at::empty({offsets[E]}, gates_c.options());
  {
    c10::SmallVector<int64_t, 64> cursor(E, 0);
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    const auto* gts_ptr = gates_c.data_ptr<float>();
    auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
    auto* gate_ptr = expert_gates.data_ptr<float>();

    for (int64_t t = 0; t < T; ++t) {
      const int64_t base = t * K;
      for (int64_t k = 0; k < K; ++k) {
        const int64_t idx = base + k;
        const int64_t e = ids_ptr[idx];
        const int64_t p = offsets[e] + (cursor[e]++);
        tok_ptr[p] = t;
        gate_ptr[p] = gts_ptr[idx];
      }
    }
  }

  const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
  const int64_t g_eff_2 = (group_size != -1) ? group_size : I;

  // Per-expert outputs filled in parallel
  std::vector<torch::Tensor> y_list(E);
  y_list.resize(E);

  at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
    for (int64_t e = e_begin; e < e_end; ++e) {
      const int64_t te = counts[e];
      if (te == 0) {
        y_list[e] = at::empty({0, H}, x_c.options());
        continue;
      }

      const int64_t start = offsets[e];

      auto sel_tokens =
          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
      auto gates_e =
          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);

      if (apply_router_weight_on_input) {
        x_e = x_e.mul(gates_e.unsqueeze(1));
      }

      auto w13_e = w13_packed.select(/*dim=*/0, e);
      auto w2_e = w2_packed.select(/*dim=*/0, e);

      // W13: fused gate/up projection, output [te, 2*I]
      auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);

      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);

      torch::Tensor act;
      if (activation_kind == ActivationKind::SwiGLUOAI) {  // clamped variant
        constexpr double kAlpha = 1.702;  // GPT-OSS default
        constexpr double kLimit = 7.0;    // GPT-OSS default
        auto gate_c = at::clamp_max(g_part, kLimit);
        auto up_c = at::clamp(u_part, -kLimit, kLimit);
        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
        act = up_c.add(1.0).mul(glu);
      } else {  // SiLU and SwiGLU_Gu; vLLM maps "silu" to SiluAndMul()
        act = at::silu(g_part).mul(u_part);
      }

      // W2: down projection back to hidden size H
      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);

      if (!apply_router_weight_on_input) {
        y = y.mul(gates_e.unsqueeze(1));
      }

      // Store per-expert result
      y_list[e] = y;
    }
  });

  // Concatenate all expert outputs to match expert_tokens order
  auto Y_all = at::cat(y_list, /*dim=*/0);
  auto out = at::zeros({T, H}, x.options());
  out =
      at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);

  return out;
}
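
The two activation branches above reduce to simple element-wise math. As a purely illustrative reference (not part of the commit), a small PyTorch sketch mirroring the kernel's constants:

import torch

def swiglu_gu(g: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # Default path (SwiGLU_Gu / SiLU): SiLU(g) * u, as in vLLM's SiluAndMul.
    return torch.nn.functional.silu(g) * u

def swiglu_oai(g: torch.Tensor, u: torch.Tensor,
               alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # GPT-OSS clamped variant: clamp both halves, gate with sigmoid(alpha * g),
    # and shift the up-projection by 1, matching the SwiGLUOAI branch above.
    g = g.clamp(max=limit)
    u = u.clamp(-limit, limit)
    glu = g * torch.sigmoid(alpha * g)
    return (u + 1.0) * glu

y13 = torch.randn(4, 8)      # stand-in for a [tokens, 2*I] W13 output
g, u = y13.chunk(2, dim=-1)  # gate half and up half
print(swiglu_gu(g, u).shape, swiglu_oai(g, u).shape)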

csrc/ops.h

Lines changed: 6 additions & 0 deletions
@@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                         const std::optional<torch::Tensor>& has_initial_state,
                         const torch::Tensor& ssm_states, int64_t pad_slot_id);

+torch::Tensor dynamic_4bit_int_moe_cpu(
+    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
+    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
+    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
+    int64_t activation_kind);
+
 using fptr_t = int64_t;
 fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,

vllm/model_executor/layers/fused_moe/cpu_fused_moe.py

Lines changed: 9 additions & 6 deletions
@@ -98,13 +98,16 @@ def select_experts(
             e_score_correction_bias=e_score_correction_bias)
     elif custom_routing_function is None:
         assert scoring_func == "softmax"
-        topk_weights = torch.nn.functional.softmax(router_logits,
-                                                   dim=1,
-                                                   dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
+        topk_logit_vals, topk_idx = torch.topk(router_logits,
+                                               k=top_k,
+                                               dim=-1,
+                                               sorted=False)
         if renormalize:
-            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-        return topk_weights, topk_ids.to(torch.int32)
+            topk_vals = torch.softmax(topk_logit_vals, dim=-1)
+        else:
+            logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
+            topk_vals = (topk_logit_vals - logZ).exp()
+        return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
     else:
         return custom_routing_function(hidden_states=hidden_states,
                                        gating_output=router_logits,
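
The rewritten routing branch avoids materialising a full-width softmax when renormalize=False: the softmax value at each top-k position equals exp(topk_logit - logsumexp(logits)), and top-k over logits selects the same indices as top-k over probabilities since softmax is monotone. A short, illustrative check of that identity (not part of the commit):

import torch

logits = torch.randn(3, 8)
vals, idx = torch.topk(logits, k=2, dim=-1, sorted=False)

# Old path: full-width softmax, then gather the probabilities at the top-k ids.
gathered = torch.softmax(logits, dim=-1, dtype=torch.float32).gather(-1, idx)

# New path: exponentiate only the top-k logits after subtracting logsumexp.
direct = (vals - torch.logsumexp(logits, dim=-1, keepdim=True)).exp()

torch.testing.assert_close(gathered, direct.to(torch.float32))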

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 0 additions & 2 deletions
@@ -69,8 +69,6 @@ def eplb_map_to_physical_and_record(
     if is_rocm_aiter_moe_enabled():
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
             rocm_aiter_grouped_topk as grouped_topk)
-    elif current_platform.is_cpu():
-        pass
     else:
         from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
     if current_platform.is_tpu():
