Merged

Commits
205 commits
72ee0c4
moe refactoring
bnellnm Apr 1, 2025
24ca1f8
module deepgemm moe working
bnellnm Apr 1, 2025
1281d8d
working deep gemm, wip cutlass
bnellnm Apr 2, 2025
9cac3d1
working cutlass
bnellnm Apr 2, 2025
08e3f07
deepgemm working again
bnellnm Apr 2, 2025
b46beb3
cutlass working again
bnellnm Apr 2, 2025
80b3e20
cutlass working again
bnellnm Apr 2, 2025
a8911e8
fix inplace, format and name cleanups
bnellnm Apr 2, 2025
01125b5
fix inplace, format + name cleanups
bnellnm Apr 2, 2025
4207795
test improvements
bnellnm Apr 3, 2025
5e445bc
make modular triton classes, fix edge cases
bnellnm Apr 3, 2025
a530fe3
fix outplace bug
bnellnm Apr 3, 2025
5ec0f7c
refactor dispatch/combine stuff
bnellnm Apr 3, 2025
1b25145
initial pplx dispatch/combine class
bnellnm Apr 3, 2025
377dfd0
merge triton dispatch into standard, add some comments
bnellnm Apr 3, 2025
e0fd915
format
bnellnm Apr 3, 2025
92da2f7
comments
bnellnm Apr 3, 2025
bec3835
fix linter
bnellnm Apr 3, 2025
ac8158b
fix more linter stuff
bnellnm Apr 3, 2025
b5d08aa
cleanup for review
bnellnm Apr 3, 2025
3925993
review comments
bnellnm Apr 4, 2025
0ad6d68
forgot return
bnellnm Apr 4, 2025
97ac838
add dp_rank_num_tokens to DPMetadata
bnellnm Apr 4, 2025
65b3169
better check for fp8 in _fp8_permute
bnellnm Apr 4, 2025
8b06b48
updates
bnellnm Apr 28, 2025
04fec22
fix merge issues
bnellnm Apr 29, 2025
bc4f7b0
fix lint
bnellnm Apr 29, 2025
35a1381
add pplx tests
bnellnm Apr 29, 2025
9fb396b
lint
bnellnm Apr 29, 2025
92a9305
undo random lint changes
bnellnm Apr 29, 2025
0ddd5f9
more lint
bnellnm Apr 29, 2025
6cd718a
more lint nonsense
bnellnm Apr 29, 2025
dcd5926
WIP torch while
tlrmchlsmth Mar 15, 2025
5c3d8b5
wip
tlrmchlsmth Mar 25, 2025
59aeb5d
wip
tlrmchlsmth Mar 25, 2025
9baf725
wip
tlrmchlsmth Mar 27, 2025
d7b5240
wip
tlrmchlsmth Mar 27, 2025
f6c87da
WIP integration
tlrmchlsmth Mar 28, 2025
692008b
Add test for deep gemm matmul
bnellnm Feb 26, 2025
a707ba0
fix matmul test
bnellnm Feb 27, 2025
c35423d
running
bnellnm Feb 27, 2025
02c9c07
wip
bnellnm Feb 27, 2025
f56b199
wip
bnellnm Feb 28, 2025
3da73b6
debugging
bnellnm Feb 28, 2025
1b2ace5
debugging
bnellnm Feb 28, 2025
0666fe8
fix
bnellnm Feb 28, 2025
47a3789
update deep gemm
bnellnm Feb 28, 2025
66a7db0
update deep gemm + small test case
bnellnm Mar 1, 2025
24d22db
wip
bnellnm Mar 2, 2025
1498c7d
wip
bnellnm Mar 2, 2025
5f0e563
problem with scores
bnellnm Mar 2, 2025
d446e2e
some passing tests
bnellnm Mar 3, 2025
2b3a848
some passing tests
bnellnm Mar 3, 2025
3cba397
topk > 1 doesn't work. prune oom-ing tests
bnellnm Mar 3, 2025
91bff40
fix indices
bnellnm Mar 3, 2025
f7658b4
enable more tests
bnellnm Mar 3, 2025
673a5f2
format
bnellnm Mar 3, 2025
5b40f71
use fused_topk for unit test
bnellnm Mar 4, 2025
f0315e9
every other block correct
bnellnm Mar 5, 2025
3d6b792
working
bnellnm Mar 5, 2025
9a01d43
enable more tests
bnellnm Mar 5, 2025
da45726
working tests w/permute
bnellnm Mar 5, 2025
c4a89fd
cleanups
bnellnm Mar 5, 2025
f8779ad
wip
bnellnm Mar 6, 2025
0b3ff3d
not crashing
bnellnm Mar 6, 2025
e6a9c50
baseline working integration
bnellnm Mar 6, 2025
252115f
add allow_deep_gemm flag
bnellnm Mar 6, 2025
53b7301
wip
bnellnm Mar 7, 2025
c87af3d
better
bnellnm Mar 7, 2025
fe6799b
fix some stuff
bnellnm Mar 8, 2025
5921a4b
fix more stuff
bnellnm Mar 8, 2025
1a7b675
cleanups
bnellnm Mar 8, 2025
e2828f6
some integration tests working
bnellnm Mar 8, 2025
f2d0bbe
almost all tests passing
bnellnm Mar 10, 2025
3eb2185
cleanup temp construction a bit
bnellnm Mar 10, 2025
297ac81
fix rest of tests
bnellnm Mar 10, 2025
81b48ec
cleanups + format
bnellnm Mar 10, 2025
42e1699
do more of output computation in place
bnellnm Mar 10, 2025
70947dd
add env var
bnellnm Mar 10, 2025
25aef1f
formatting, remove some blocking restrictions
bnellnm Mar 12, 2025
719362a
wip
bnellnm Mar 12, 2025
38dc3cf
fix resizing of output
bnellnm Mar 12, 2025
3e8591e
fix resizing of output
bnellnm Mar 12, 2025
916bfe1
fixes
bnellnm Mar 12, 2025
2d534ae
aligned chunking working for deep gemm
bnellnm Mar 12, 2025
5da8846
unaligned chunking for deep gemm
bnellnm Mar 13, 2025
4726f6f
cleanup wip
bnellnm Mar 13, 2025
7495946
clean up some blocking stuff
bnellnm Mar 13, 2025
2ea300a
clean up some blocking stuff
bnellnm Mar 13, 2025
9752886
tweaks
bnellnm Mar 14, 2025
9f71c94
fix rebase
bnellnm Mar 15, 2025
d312986
rebase
bnellnm Mar 17, 2025
93dfaf3
refactoring + minor perf improvements
bnellnm Mar 21, 2025
e2ebf14
refactoring + perf tweaks
bnellnm Mar 22, 2025
6caebc0
remove debugging cruft
bnellnm Mar 24, 2025
2f459a3
cache resize refactoring
bnellnm Mar 24, 2025
c88a17f
cleanups
bnellnm Mar 25, 2025
2f56ff9
format
bnellnm Mar 25, 2025
48d071f
revert test.txt, fix mypy errors
bnellnm Mar 25, 2025
a51970b
review comments
bnellnm Mar 26, 2025
6676f24
review comments
bnellnm Mar 27, 2025
be58664
clean up use_dg flags
bnellnm Mar 27, 2025
a52f17a
remove check for aligned M
bnellnm Mar 27, 2025
f22b693
rebase + clean up test
bnellnm Mar 28, 2025
549a9fe
fix format
bnellnm Mar 28, 2025
8a72a9c
Clean up diff
tlrmchlsmth Mar 31, 2025
005c18d
[Distributed] Add custom allreduce support for ROCM (#14125)
ilmarkov Apr 1, 2025
c98aa16
[Bugfix][Model] fix mllama multi-image (#14883)
yma11 Apr 1, 2025
2e7db9a
module deepgemm moe working
bnellnm Apr 1, 2025
f86e516
working deep gemm, wip cutlass
bnellnm Apr 2, 2025
b9d5e60
working cutlass
bnellnm Apr 2, 2025
e252cdf
deepgemm working again
bnellnm Apr 2, 2025
802203b
fix inplace, format and name cleanups
bnellnm Apr 2, 2025
1e63491
test improvements
bnellnm Apr 3, 2025
16c2583
make modular triton classes, fix edge cases
bnellnm Apr 3, 2025
5d88a64
refactor dispatch/combine stuff
bnellnm Apr 3, 2025
fa69484
initial pplx dispatch/combine class
bnellnm Apr 3, 2025
27e92fb
merge triton dispatch into standard, add some comments
bnellnm Apr 3, 2025
3381df0
format
bnellnm Apr 3, 2025
734b06c
cleanup for review
bnellnm Apr 3, 2025
834ea30
hacking
bnellnm Apr 4, 2025
4504a8e
hacking
bnellnm Apr 7, 2025
456ecc5
init stuff
bnellnm Apr 7, 2025
4087600
call super ctor + fix random stuff
bnellnm Apr 7, 2025
bdb28ff
fix use_ep bug
tlrmchlsmth Apr 7, 2025
1b6c4a2
fixes
tlrmchlsmth Apr 7, 2025
9474cd7
get a bit further
bnellnm Apr 7, 2025
0a80345
hacking in dispatch_combine
bnellnm Apr 9, 2025
7a0d68b
hook up some wires
bnellnm Apr 10, 2025
bcf237c
seems to be working
bnellnm Apr 10, 2025
ee86b51
wip
bnellnm Apr 11, 2025
4e22d15
batched moe test
bnellnm Apr 14, 2025
c76c988
simple test
bnellnm Apr 15, 2025
e3385da
cleanup
bnellnm Apr 15, 2025
71f7361
test pplx w/naive implementation
bnellnm Apr 15, 2025
b1c40b7
test pplx w/naive implementation
bnellnm Apr 15, 2025
3054ec2
hack fix for chunking loop
bnellnm Apr 15, 2025
be9a445
wip. add pplx unit test
bnellnm Apr 16, 2025
ce67d8d
work on unit test
bnellnm Apr 17, 2025
95cd250
dispatch/combine unit test
bnellnm Apr 17, 2025
56f8a6d
forgot file
bnellnm Apr 17, 2025
4f7b6c9
somewhat working unit test
bnellnm Apr 18, 2025
add77e4
wip
bnellnm Apr 18, 2025
9f7cc1e
fix test
bnellnm Apr 18, 2025
27de9fe
some cleanup
bnellnm Apr 19, 2025
e4642cd
wip
bnellnm Apr 19, 2025
15aa7df
wip
bnellnm Apr 29, 2025
66c497f
undo random changes
bnellnm Apr 29, 2025
c5fec1a
merge
bnellnm Apr 29, 2025
320805e
tweak
bnellnm Apr 29, 2025
5dc242f
revert hack
bnellnm Apr 29, 2025
0b7f124
fixes
bnellnm Apr 29, 2025
6f192ec
pplx update
bnellnm Apr 29, 2025
1bffb6b
varun's fixes
bnellnm Apr 29, 2025
489c7df
varun's fixes
bnellnm Apr 29, 2025
8253ede
tweak bound_m
bnellnm Apr 29, 2025
43ed0ae
run linter
bnellnm Apr 29, 2025
66558c7
more lint stuff
bnellnm Apr 29, 2025
8098a69
add guards for pplx import
bnellnm Apr 30, 2025
51ea5a3
fix forward_chunked
Apr 30, 2025
d0fe7b5
fix more lint
bnellnm Apr 30, 2025
138ffc2
cleanups
bnellnm Apr 30, 2025
269cccd
cleanups + lint, layer.py wip
bnellnm Apr 30, 2025
cbecf66
fix parallel_state lint
bnellnm Apr 30, 2025
02f8201
fix M=1 pplx test
bnellnm May 1, 2025
38f5b03
fix M=1 pplx test
bnellnm May 1, 2025
829df83
fix M=1 pplx test
bnellnm May 1, 2025
680c00e
lint
bnellnm May 1, 2025
3e61124
remove valid pplx check
bnellnm May 1, 2025
1cc3950
semi-working cudagraphs
bnellnm May 2, 2025
aaefc27
fix reference implementations
bnellnm May 2, 2025
3b72bc5
wip ref impl
bnellnm May 5, 2025
6bb6983
improve ref impl
bnellnm May 6, 2025
909e0e5
wip
bnellnm May 6, 2025
c12dae1
fix merge
bnellnm May 6, 2025
b294ccd
fix merge
bnellnm May 6, 2025
250f1b7
wip
May 1, 2025
5318218
zero out attn outputs during profile run
May 7, 2025
2e4be06
lint
bnellnm May 7, 2025
70b9264
lint
bnellnm May 7, 2025
69dbd31
revert lint changes to requirements/test.txt
bnellnm May 7, 2025
31166d9
revert lint changes to compiler_interface.py
bnellnm May 7, 2025
62a0896
fix merge
bnellnm May 7, 2025
cdef4c6
fix more lint errors
bnellnm May 7, 2025
9f0ea4f
fix lint
bnellnm May 7, 2025
6c0e085
cosmetic changes
bnellnm May 8, 2025
54113c2
fix test
bnellnm May 8, 2025
9f8e241
fix test
bnellnm May 8, 2025
a674762
Varun's fixes/cleanups
bnellnm May 9, 2025
43e229c
review comments + cudagraph debugging
bnellnm May 12, 2025
ca2ff26
fix merge + add comments
bnellnm May 12, 2025
93dd74f
lint
bnellnm May 12, 2025
b5be324
rename dispatch combine -> prepare finalize
bnellnm May 12, 2025
9b97c83
review comments, only initialize pplx if EP is enabled
bnellnm May 12, 2025
d6e801e
fix test when pplx is missing + minor tweaks
bnellnm May 13, 2025
9461d73
rename StandardPrepareAndFinalize
bnellnm May 13, 2025
980262f
review comments
bnellnm May 13, 2025
1cb6b1d
merge
bnellnm May 13, 2025
c5adb68
disable pplx for quantized types
bnellnm May 13, 2025
40ebc47
revert MOE_DP_CHUNK_SIZE
bnellnm May 13, 2025
484fc83
revert some bad changes
bnellnm May 13, 2025
c4086d7
rebase + fix some tests
bnellnm May 13, 2025
3f10988
relax test_batched_moe tolerances
May 14, 2025
23cf129
Remove redundant tp_size setting in dbrx
May 14, 2025
1f91cfd
fix merge
bnellnm May 14, 2025
3 changes: 3 additions & 0 deletions csrc/activation_kernels.cu
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
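The three added lines make the activation kernels a no-op when the input has zero tokens, which can happen when a rank has no work in a step (e.g. under data/expert parallelism); launching with an empty grid would otherwise be an invalid CUDA launch configuration. A minimal sketch of the case this guards, using the existing `SiluAndMul` layer (the layer and import path are not part of this diff and are assumed unchanged):

```python
# Minimal sketch (assumes a CUDA device and vLLM's existing SiluAndMul layer).
import torch

from vllm.model_executor.layers.activation import SiluAndMul

act = SiluAndMul()

# Zero tokens, hidden size 4096: the input to SiluAndMul is [num_tokens, 2 * d].
x = torch.empty((0, 2 * 4096), dtype=torch.float16, device="cuda")

# With the guard above, the CUDA kernel returns immediately instead of being
# launched with grid size 0; the result is just an empty [0, 4096] tensor.
y = act(x)
assert y.shape == (0, 4096)
```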
14 changes: 14 additions & 0 deletions csrc/dispatch_utils.h
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
8 changes: 4 additions & 4 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}

if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
Expand All @@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");

VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
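With the unsigned cases added to the dispatch macros, `moe_align_block_size` now accepts `uint16`/`uint32`/`uint64` `topk_ids` in addition to the signed integer types. A rough sketch of driving it with `uint32` indices; the `vllm._custom_ops` wrapper name, argument order, and output-buffer sizing below follow the existing helpers and are assumptions, not something introduced by this diff:

```python
# Rough sketch: feed uint32 topk_ids through moe_align_block_size.
# Assumes vllm._custom_ops.moe_align_block_size(topk_ids, num_experts,
# block_size, sorted_token_ids, expert_ids, num_tokens_post_pad).
import torch

from vllm import _custom_ops as ops

num_tokens, topk, num_experts, block_size = 256, 2, 16, 64

# Expert ids produced as int32 and reinterpreted as uint32 (same bit pattern
# for values in [0, num_experts)).
topk_ids = torch.randint(0, num_experts, (num_tokens, topk),
                         dtype=torch.int32, device="cuda").view(torch.uint32)

# Worst-case padded length: every expert's slice padded up to block_size.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_num_tokens_padded + block_size - 1) // block_size,
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                         expert_ids, num_tokens_post_pad)
```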
63 changes: 45 additions & 18 deletions csrc/moe/topk_softmax_kernels.cu
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{

using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);

template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);

if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}
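Since `topkGatingSoftmaxKernelLauncher` is now templated on `IndType` and `topk_softmax` branches on the scalar type of `topk_indices`, callers may pass either `int32` or `uint32` index tensors. A small sketch, assuming the existing `vllm._custom_ops.topk_softmax` wrapper (not changed by this diff):

```python
# Small sketch: topk_softmax writing uint32 expert indices.
# Assumes the existing wrapper ops.topk_softmax(topk_weights, topk_ids,
# token_expert_indices, gating_output).
import torch

from vllm import _custom_ops as ops

num_tokens, num_experts, topk = 128, 8, 2

gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float32, device="cuda")
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk,
                                   dtype=torch.int32, device="cuda")

# Previously this had to be int32; uint32 now takes the second branch above.
topk_ids = torch.empty(num_tokens, topk, dtype=torch.uint32, device="cuda")

ops.topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)
```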
22 changes: 16 additions & 6 deletions examples/offline_inference/data_parallel.py
@@ -65,11 +65,17 @@ def parse_args():
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args()


def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
max_tokens=[16, 20][global_dp_rank % 2])

# Create an LLM.
llm = LLM(model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=True,
enable_expert_parallel=True)
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port,
tp_size))
tp_size, args.enforce_eager,
args.trust_remote_code))
proc.start()
procs.append(proc)
exit_code = 0
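With the new flags the data-parallel example no longer forces eager mode. For instance (flag names other than the two added here are taken from the existing script and assumed unchanged), `python examples/offline_inference/data_parallel.py --model <hf-model> --dp-size 2 --tp-size 1 --trust-remote-code` runs with CUDA graphs enabled, while adding `--enforce-eager` restores the previous behavior.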
114 changes: 114 additions & 0 deletions tests/kernels/moe/test_batched_moe.py
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
N: int


@dataclass
class BatchedMMTensors:
A: torch.Tensor # [E, max_tokens, K]
B: torch.Tensor # [E, K, N] - column major
C: torch.Tensor # [E, max_tokens, N]
num_expert_tokens: torch.Tensor # [E]

@staticmethod
def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:

num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)

for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):

config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)

test_output = tensors.C
ref_output = test_output.clone()

compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
test_output,
tensors.num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
None,
# Quantization schemes
False,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})

ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)

rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]

torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
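Assuming a CUDA-capable device (the test allocates everything on `"cuda"`), the new file can be exercised on its own with `pytest tests/kernels/moe/test_batched_moe.py`. The float16/bfloat16 cases use looser tolerances (6e-2) than float32 (1e-2), presumably to allow for reduced-precision accumulation in the batched Triton kernel.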