
Commit e3ae076

Merge pull request vllm-project#14 from ROCm/fused_topK_softmax
enable fused topK_softmax kernel for hip path
2 parents 83ce7b2 + 41e348a commit e3ae076

File tree

4 files changed: +137, -95 lines changed


csrc/cuda_compat.h

Lines changed: 2 additions & 0 deletions

@@ -18,8 +18,10 @@
 
 #ifndef USE_ROCM
 #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
 #else
 #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width)
 #endif
 
 #ifndef USE_ROCM

csrc/moe/topk_softmax_kernels.cu

Lines changed: 17 additions & 10 deletions

@@ -19,15 +19,22 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "csrc/cuda_compat.h"
 
-#include <cub/cub.cuh>
-#include <cub/util_type.cuh>
+#ifndef USE_ROCM
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
 
 namespace vllm {
 namespace moe {
 
-static constexpr int WARP_SIZE = 32;
-
 /// Aligned array type
 template <
     typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
     #pragma unroll
     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
     {
-        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+        thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
     }
 
     // From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
     #pragma unroll
     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
     {
-        row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
+        row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
     }
 
     // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
     #pragma unroll
     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
     {
-        float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
-        int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
+        float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
+        int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
 
         // We want lower indices to "win" in every thread so we break ties this way
         if (other_max > max_val || (other_max == max_val && other_expert < expert))
@@ -383,7 +390,7 @@ struct TopkConstants
 {
     static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
     static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
-    static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
     static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
     static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
     static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
 
-    static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
+    static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
     using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
     static constexpr int VPT = Constants::VPT;
     static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

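What the fused kernel computes, in PyTorch terms, is a row-wise softmax over the gating logits followed by a top-k selection per token. The sketch below is only an illustrative, unfused reference for that semantics (it mirrors the is_hip() fallback that the fused_moe.py change further down deletes); it is not the kernel itself.

    import torch

    def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
        # Softmax over the expert dimension, then keep the top-k experts per token.
        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
        return topk_weights, topk_ids

The fused CUDA/HIP kernel additionally fills a token_expert_indicies tensor, which the current Python caller discards.
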
setup.py

Lines changed: 3 additions & 5 deletions

@@ -362,15 +362,13 @@ def _read_requirements(filename: str) -> List[str]:
 
 ext_modules = []
 
-if _is_cuda():
-    ext_modules.append(CMakeExtension(name="vllm._moe_C"))
-
-if _install_punica():
-    ext_modules.append(CMakeExtension(name="vllm._punica_C"))
+if _is_cuda() and _install_punica():
+    ext_modules.append(CMakeExtension(name="vllm._punica_C"))
 
 if not _is_neuron():
     ext_modules.append(CMakeExtension(name="vllm._C"))
     ext_modules.append(CMakeExtension(name="vllm._custom_C"))
+    ext_modules.append(CMakeExtension(name="vllm._moe_C"))
 
 package_data = {
     "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]

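Since vllm._moe_C is now built for every non-Neuron target instead of only when _is_cuda() is true, a quick way to sanity-check a fresh ROCm or CUDA install is to confirm the extension imports and exposes the binding that fused_moe.py calls. A minimal check (assuming the build completed):

    import importlib

    moe_c = importlib.import_module("vllm._moe_C")
    # topk_softmax is the entry point used by fused_moe on both CUDA and ROCm.
    assert hasattr(moe_c, "topk_softmax")
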
vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 115 additions & 80 deletions

@@ -8,9 +8,9 @@
 import triton
 import triton.language as tl
 
+import vllm._moe_C as moe_kernels
 from vllm._C import ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -108,8 +108,8 @@ def fused_moe_kernel(
                       offs_k[None, :] * stride_ak)
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (b_ptr + off_experts * stride_be +
+              (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -121,10 +121,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] &
+            (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
         b = tl.load(b_ptrs,
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
@@ -144,8 +146,8 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = (c_ptr + stride_cm * offs_token[:, None] +
+              stride_cn * offs_cn[None, :])
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
@@ -193,31 +195,46 @@ def moe_align_block_size(
     sorted_ids = torch.empty(
         (topk_ids.numel() + num_experts * (block_size - 1), ),
         dtype=torch.int32,
-        device=topk_ids.device)
-    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+        device=topk_ids.device,
+    )
+    expert_ids = torch.empty(
+        (topk_ids.numel() + num_experts, ),
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
     sorted_ids.fill_(topk_ids.numel())
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    ops.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any]) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+        "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), )
 
     fused_moe_kernel[grid](
         A,
@@ -310,8 +327,8 @@ def fused_moe(
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert (hidden_states.shape[0] == gating_output.shape[0]
+            ), "Number of tokens mismatch"
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -323,34 +340,26 @@ def fused_moe(
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
+    topk_weights = torch.empty(M,
+                               topk,
+                               dtype=torch.float32,
+                               device=hidden_states.device)
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
@@ -367,48 +376,74 @@ def fused_moe(
     else:
         # Else use the default config
         config = {
-            'BLOCK_SIZE_M': 64,
-            'BLOCK_SIZE_N': 64,
-            'BLOCK_SIZE_K': 32,
-            'GROUP_SIZE_M': 8
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
 
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 16,
-                'BLOCK_SIZE_N': 32,
-                'BLOCK_SIZE_K': 64,
-                'GROUP_SIZE_M': 1
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
             }
 
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
+        topk_ids, config["BLOCK_SIZE_M"], E)
 
-    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, False,
-                            topk_ids.shape[1], config)
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+    )
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, True, 1,
-                            config)
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+    )
 
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)

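With the is_hip() branch gone, the fused top-k softmax path is exercised identically on CUDA and ROCm. Below is a minimal usage sketch of the call as it now appears in fused_moe; the shapes and values are illustrative assumptions, and a GPU build of vllm._moe_C is required.

    import torch
    import vllm._moe_C as moe_kernels

    M, E, topk = 4, 8, 2  # tokens, experts, experts routed per token (illustrative)
    gating_output = torch.randn(M, E, device="cuda")

    topk_weights = torch.empty(M, topk, dtype=torch.float32, device="cuda")
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device="cuda")
    token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device="cuda")

    # Fills topk_weights and topk_ids in place, mirroring the fused_moe code above.
    moe_kernels.topk_softmax(topk_weights, topk_ids, token_expert_indicies,
                             gating_output.float())

    # Optional renormalization, matching fused_moe's renormalize branch.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)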