
Commit 31259d2

Varun Sundar Rabindranath authored and committed
[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (vllm-project#24588)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: charlifu <charlifu@amd.com>
1 parent 39f7d28 commit 31259d2

File tree

6 files changed: +274, -75 lines changed


vllm/model_executor/layers/fused_moe/config.py

Lines changed: 21 additions & 1 deletion
@@ -288,7 +288,11 @@ def use_int4_w4a16(self) -> bool:
 
     @property
     def use_mxfp4_w4a4(self) -> bool:
-        return self.quant_dtype == "mxfp4"
+        return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
+
+    @property
+    def use_mxfp4_w4a16(self) -> bool:
+        return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
 
     @property
     def use_nvfp4_w4a4(self) -> bool:
@@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
     )
 
 
+def mxfp4_w4a16_moe_quant_config(
+        w1_scale: Union[torch.Tensor, "PrecisionConfig"],
+        w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for unquantized activations and mxfp4 weights.
+    """
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(),
+        _a2=FusedMoEQuantDesc(),
+        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
+        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
+    )
+
+
 def mxfp4_w4a4_moe_quant_config(
         w1_scale: Union[torch.Tensor, "PrecisionConfig"],
         w2_scale: Union[torch.Tensor, "PrecisionConfig"],
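
The two predicates above differ only in the activation dtype: use_mxfp4_w4a4 requires both activations and weights in mxfp4, while the new use_mxfp4_w4a16 expects unquantized activations with mxfp4 weights. A minimal usage sketch of the new helper (not from this commit; the scale tensors are illustrative placeholders, where real per-block scales or PrecisionConfig objects come from the weight loader):

    import torch

    from vllm.model_executor.layers.fused_moe.config import (
        mxfp4_w4a16_moe_quant_config)

    # Placeholder scales with made-up shapes; real mxfp4 scales are produced
    # by the checkpoint loader (one scale per 32-element block).
    w1_scale = torch.ones(128, 90, dtype=torch.uint8)
    w2_scale = torch.ones(128, 90, dtype=torch.uint8)

    cfg = mxfp4_w4a16_moe_quant_config(w1_scale=w1_scale, w2_scale=w2_scale)
    assert cfg.use_mxfp4_w4a16      # activations unquantized, weights mxfp4
    assert not cfg.use_mxfp4_w4a4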

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 18 additions & 0 deletions
@@ -11,13 +11,31 @@
     TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
+from vllm.utils import round_up
 
 
 class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     """
     Prepare/Finalize using DeepEP High-Throughput kernels.
     """
 
+    @staticmethod
+    def maybe_roundup_layer_hidden_size(hidden_size: int,
+                                        dtype: torch.dtype) -> int:
+        # Round up the hidden size so it is compatible with DeepEP
+        # High-Throughput kernels: the intranode kernels copy data in units
+        # of 32 (warp-size) int4 elements, so round up the hidden size to
+        # respect this. For example, an input hidden size of 2880 with dtype
+        # torch.bfloat16 is rounded up to 3072.
+        hidden_size_bytes = hidden_size * dtype.itemsize
+        xfer_atom_size = 512  # 32 * 16 (sizeof(int4))
+        if hidden_size_bytes % xfer_atom_size == 0:
+            return hidden_size
+
+        hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
+        return hidden_size_bytes // dtype.itemsize
+
     def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
                  dp_size: int, rank_expert_offset: int):
         super().__init__()
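
The round-up arithmetic in maybe_roundup_layer_hidden_size can be checked by hand for the example in the comment (hidden size 2880, bfloat16). A plain-Python sketch of the same computation:

    hidden_size = 2880
    itemsize = 2                # torch.bfloat16 is 2 bytes per element
    xfer_atom_size = 512        # 32 warp lanes * 16 bytes (one int4 each)

    hidden_size_bytes = hidden_size * itemsize            # 5760 bytes
    # 5760 is not a multiple of 512: ceil(5760 / 512) = 12 transfer atoms.
    rounded = -(-hidden_size_bytes // xfer_atom_size) * xfer_atom_size
    assert rounded == 6144
    assert rounded // itemsize == 3072   # matches the docstring example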

vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py

Lines changed: 140 additions & 34 deletions
@@ -9,7 +9,8 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
+from vllm.triton_utils import tl, triton
 from vllm.utils import has_triton_kernels
 
 logger = init_logger(__name__)
@@ -19,13 +20,55 @@
     import triton_kernels.swiglu
     from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                            matmul_ogs)
-    from triton_kernels.routing import routing
+    from triton_kernels.routing import (RoutingData, routing,
+                                        routing_from_bitmatrix)
+    from triton_kernels.tensor import Bitmatrix
 except (ModuleNotFoundError, AttributeError) as e:
     logger.error(
         "Failed to import Triton kernels. Please make sure your triton "
         "version is compatible. Error: %s", e)
 
 
+@triton.jit
+def pack_bitmatrix(
+    bitmatrix,
+    topk_ids,
+    n_rows,  # n_rows in bitmatrix / topk_ids
+    bm_cols: tl.constexpr,  # n int32_t bitpacks in bitmatrix
+    n_expts_act,  # num_topk
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """
+    Packs topk_ids into a bitmatrix.
+    code reference:
+    https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
+    """
+    pid_m = tl.program_id(0)
+    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offsets_k = tl.arange(0, BLOCK_SIZE_K)
+    offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
+    mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
+    indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
+    div = indices // 32
+    rem = indices % 32
+    one = tl.cast(1, tl.uint32)
+
+    # Iterate through all the relevant bitmatrix columns.
+    for i in range(bm_cols):
+        # When BLOCK_SIZE_K=32, offs is just the column index.
+        offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
+        # All topks that need to go into this column have the correct bit
+        # set. Other bits are 0. x is a 3D tensor.
+        x = tl.where(div[:, :, None] == offs[None, None, :],
+                     (one << rem)[:, :, None], 0)
+        # OR-reduce x across the topk axis to get a single int32_t bitpack.
+        y = tl.reduce_or(x, axis=1)
+        bitmatrix_ptrs = bitmatrix + offsets_m[:,
+                                               None] * bm_cols + offs[None, :]
+        tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
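
For intuition about what pack_bitmatrix produces, here is a plain-PyTorch reference (not part of the commit): row r of the bitmatrix has bit e % 32 of word e // 32 set iff expert e appears in topk_ids[r]; -1 entries are ignored.

    import torch

    def pack_bitmatrix_ref(topk_ids: torch.Tensor,
                           num_experts: int) -> torch.Tensor:
        # topk_ids: (n_rows, num_topk), -1 marks an invalid slot.
        # Returns (n_rows, ceil(num_experts / 32)) int64 bitpacks.
        n_rows, num_topk = topk_ids.shape
        bm_cols = (num_experts + 31) // 32
        bitmatrix = torch.zeros(n_rows, bm_cols, dtype=torch.int64)
        valid = topk_ids >= 0
        ids = topk_ids.clamp(min=0)
        div, rem = ids // 32, ids % 32
        bits = (torch.ones_like(ids) << rem) * valid  # one bit per topk slot
        for col in range(bm_cols):
            col_bits = torch.where(div == col, bits, torch.zeros_like(bits))
            acc = torch.zeros(n_rows, dtype=torch.int64)
            for k in range(num_topk):  # OR-reduce over the topk axis
                acc |= col_bits[:, k]
            bitmatrix[:, col] = acc
        return bitmatrix

    # Two rows, top-2 of 40 experts: expert 35 lands in bit 3 of word 1.
    bm = pack_bitmatrix_ref(torch.tensor([[3, 35], [0, -1]]), num_experts=40)
    assert bm.tolist() == [[8, 8], [1, 0]]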
@@ -124,48 +167,99 @@ def triton_kernel_fused_experts(
     return intermediate_cache3
 
 
-class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+def make_routing_data(
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    num_local_experts: int,
+) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+
+    topk_ids = topk_ids.to(torch.int16)
+    topk_weights = topk_weights.to(torch.bfloat16)
+
+    n_rows, num_topk = topk_ids.size()
+
+    BLOCK_SIZE_M = 512
+    BLOCK_SIZE_K = 32
+
+    bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K)  # n_bitpacks
+    bitmatrix = torch.zeros((n_rows, bm_cols),
+                            dtype=torch.uint32,
+                            device=topk_ids.device)
+
+    grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
+    pack_bitmatrix[grid](
+        bitmatrix,
+        topk_ids,
+        n_rows,
+        bm_cols,
+        num_topk,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+
+    bitmatrix_shape = [n_rows, bm_cols * 32]
+    bitmatrix_shape_max = [n_rows, None]
+    bitmatrix = Bitmatrix(bitmatrix,
+                          shape=bitmatrix_shape,
+                          shape_max=bitmatrix_shape_max,
+                          scratchpad=None)
+
+    # matmul_ogs expects invalid topk_weights to be -1s
+    topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
+    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
+        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
+
+    return routing_data, gather_indx, scatter_indx
+
+
+class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        super().__init__(quant_config)
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Weight application and reduction happens in the fused_experts kernel.
+        return TopKWeightAndReduceNoOP()
 
-    def __init__(
+    def _make_routing_data(
         self,
-        max_num_tokens: int,
-        num_dispatchers: int,
-        quant_config: FusedMoEQuantConfig,
-    ):
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_local_experts: int,
+    ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+        return make_routing_data(topk_ids, topk_weights, num_local_experts)
+
+
+class OAITritonExperts(BaseOAITritonExperts):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
         super().__init__(quant_config)
-        self.max_num_tokens = max_num_tokens
-        self.num_dispatchers = num_dispatchers
 
     @property
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (mk.FusedMoEActivationFormat.BatchedExperts,
-                mk.FusedMoEActivationFormat.BatchedExperts)
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
 
     def supports_chunking(self) -> bool:
-        return False
-
-    def supports_expert_map(self) -> bool:
-        return False
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return True
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
         topk: int, global_num_experts: int, local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # workspaces are allocated inside the kernel
-        assert a.dim() == 2
-        num_dp = self.num_dispatchers
-        num_experts = local_num_experts
-        max_num_tokens = self.max_num_tokens
-        workspace2 = (0, 0, 0)
-        output = (num_experts, max_num_tokens * num_dp, N)
-        return (output, workspace2, output, a.dtype)
+        workspace1 = (M, K)
+        workspace2 = (0, 0)
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
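
A hedged usage sketch for the routing path above (assumes a CUDA device and a compatible triton_kernels install; shapes are illustrative, and a real router would produce distinct experts per token):

    import torch

    from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
        make_routing_data)

    num_tokens, num_topk, num_local_experts = 8, 4, 32
    topk_ids = torch.randint(0, num_local_experts, (num_tokens, num_topk),
                             dtype=torch.int32, device="cuda")
    topk_weights = torch.rand(num_tokens, num_topk, device="cuda")

    # Packs the bitmatrix on-device, then derives RoutingData plus
    # gather/scatter indices for matmul_ogs.
    routing_data, gather_indx, scatter_indx = make_routing_data(
        topk_ids, topk_weights, num_local_experts)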
@@ -185,17 +279,29 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        return triton_kernel_fused_experts(
-            output,
+        if expert_map is not None:
+            topk_ids = expert_map[topk_ids]
+
+        local_num_experts = w1.size(0)
+        if global_num_experts == -1:
+            global_num_experts = local_num_experts
+
+        routing_data, gather_indx, scatter_indx = self._make_routing_data(
+            topk_ids, topk_weights, local_num_experts)
+
+        experts_output = triton_kernel_fused_experts(
+            None,
             hidden_states,
             w1,
             w2,
-            routing_data=None,
-            gather_indx=None,
-            scatter_indx=None,
+            routing_data,
+            gather_indx,
+            scatter_indx,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
+            global_num_experts=local_num_experts,
+            expert_map=None,  # applied already
             a1q_scale=a1q_scale)
+
+        output.copy_(experts_output, non_blocking=True)
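
The expert_map gather at the top of apply converts global expert ids to this rank's local ids in one indexing op; experts owned by other ranks map to -1, which the routing path treats as invalid (their weights are forced to -1 before matmul_ogs). A small sketch under an assumed 2-rank, 8-expert layout:

    import torch

    # Hypothetical EP layout: this rank owns global experts 4..7.
    expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])

    topk_ids = torch.tensor([[4, 1], [7, 5]])  # global ids from the router
    local_ids = expert_map[topk_ids]           # same gather as in apply()
    assert local_ids.tolist() == [[0, -1], [3, 1]]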
