#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>

// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
  #include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif
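
// Thin wrapper around the dynamic 4-bit matmul kernel so the rest of the
// file has a single call site; non-AArch64 builds fail loudly here rather
// than producing wrong results.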
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
                        int64_t group_size_eff, int64_t in_features,
                        int64_t out_features) {
#if defined(__aarch64__)
  return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
                                                in_features, out_features);
#else
  TORCH_CHECK(false,
              "dynamic 4-bit int MoE path requires AArch64 (ARM64); "
              "_dyn_quant_matmul_4bit is unavailable on this architecture");
  return {};
#endif
}

enum ActivationKind : int64_t {
  SwiGLU_Gu = 0,  // act = SiLU(g) * u
  SwiGLUOAI = 1,  // act = (u + 1) * g * sigmoid(alpha * g), with clamping
  SiLU = 2        // handled as SiLU(g) * u below (vLLM maps silu to SiluAndMul)
};

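// CPU mixture-of-experts layer over dynamically quantized 4-bit weights.
//   x               : [T, H] activations
//   topk_ids        : [T, K] selected expert ids per token
//   topk_weights    : [T, K] router gate values per token
//   w13_packed      : E packed gate/up projections, stacked along dim 0
//   w2_packed       : E packed down projections, stacked along dim 0
//   H / I / I2      : hidden size, intermediate size, I2 == 2 * I
//   group_size      : quantization group size, or -1 for per-row groups
//   activation_kind : one of ActivationKind above
// Returns a [T, H] tensor with each token's expert outputs gate-weighted
// and summed.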
torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
    int64_t activation_kind) {
  TORCH_CHECK(x.dim() == 2, "x must be 2D");
  TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
              "topk tensors must be [T, K]");
  TORCH_CHECK(
      w13_packed.size(0) == w2_packed.size(0),
      "w13_packed and w2_packed must have same number of experts in dim 0");
  TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");

  const int64_t T = x.size(0);           // tokens
  const int64_t K = topk_ids.size(1);    // experts selected per token
  const int64_t E = w13_packed.size(0);  // number of experts
  const int64_t N = T * K;               // total (token, expert) assignments

  auto x_c = x.contiguous();
  auto ids_c = topk_ids.contiguous();
  auto gates_c = topk_weights.to(at::kFloat).contiguous();

  // Bucket (token, expert) assignments by expert: count per-expert sizes,
  // then turn the counts into exclusive prefix-sum offsets.
  c10::SmallVector<int64_t, 64> counts(E, 0);  // stack-allocated for small E
  {
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    for (int64_t i = 0; i < N; ++i) {
      const int64_t e_id = ids_ptr[i];
      TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
      counts[e_id]++;
    }
  }
  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);  // offsets[E] == N
  for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];

  auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
  auto expert_gates = at::empty({offsets[E]}, gates_c.options());
  {
    // Scatter token indices and gate values into contiguous per-expert
    // buckets; cursor[e] is the next free slot within expert e's bucket.
    c10::SmallVector<int64_t, 64> cursor(E, 0);
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    const auto* gts_ptr = gates_c.data_ptr<float>();
    auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
    auto* gate_ptr = expert_gates.data_ptr<float>();

    for (int64_t t = 0; t < T; ++t) {
      const int64_t base = t * K;
      for (int64_t k = 0; k < K; ++k) {
        const int64_t idx = base + k;
        const int64_t e = ids_ptr[idx];
        const int64_t p = offsets[e] + (cursor[e]++);
        tok_ptr[p] = t;
        gate_ptr[p] = gts_ptr[idx];
      }
    }
  }

  // group_size == -1 selects a single quantization group spanning the whole
  // input dimension of each matmul (H for w13, I for w2).
  const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
  const int64_t g_eff_2 = (group_size != -1) ? group_size : I;

  // Per-expert outputs, filled in parallel and concatenated afterwards.
  std::vector<torch::Tensor> y_list(E);

  at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
    for (int64_t e = e_begin; e < e_end; ++e) {
      const int64_t te = counts[e];
      if (te == 0) {
        y_list[e] = at::empty({0, H}, x_c.options());
        continue;
      }

      const int64_t start = offsets[e];

      auto sel_tokens =
          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
      auto gates_e =
          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);

      if (apply_router_weight_on_input) {
        x_e = x_e.mul(gates_e.unsqueeze(1));
      }

      auto w13_e = w13_packed.select(/*dim=*/0, e);
      auto w2_e = w2_packed.select(/*dim=*/0, e);

      // Fused gate/up projection: [te, H] -> [te, I2].
      auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);

      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);

      torch::Tensor act;
      if (activation_kind == ActivationKind::SwiGLUOAI) {
        constexpr double kAlpha = 1.702;  // GPT-OSS default
        constexpr double kLimit = 7.0;    // GPT-OSS default
        auto gate_c = at::clamp_max(g_part, kLimit);
        auto up_c = at::clamp(u_part, -kLimit, kLimit);
        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
        act = up_c.add(1.0).mul(glu);
      } else {  // SiLU and SwiGLU_Gu: vLLM maps "silu" to SiluAndMul()
        act = at::silu(g_part).mul(u_part);
      }

      // Down projection: [te, I] -> [te, H].
      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);

      if (!apply_router_weight_on_input) {
        y = y.mul(gates_e.unsqueeze(1));
      }

      // Store per-expert result.
      y_list[e] = y;
    }
  });

  // Concatenate per-expert outputs (row order matches expert_tokens) and
  // scatter-add rows back to their source tokens.
  auto Y_all = at::cat(y_list, /*dim=*/0);
  auto out = at::zeros({T, H}, x.options());
  out.index_add_(/*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);

  return out;
}
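
// Illustrative call (hypothetical names and values; assumes w13_packed and
// w2_packed were packed per expert in the layout _dyn_quant_matmul_4bit
// expects and stacked along dim 0):
//
//   auto y = dynamic_4bit_int_moe_cpu(
//       x, topk_ids, topk_weights, w13_packed, w2_packed,
//       /*H=*/hidden, /*I=*/inter, /*I2=*/2 * inter,
//       /*group_size=*/32, /*apply_router_weight_on_input=*/false,
//       /*activation_kind=*/ActivationKind::SwiGLU_Gu);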