Modularize fused experts and integrate PPLX kernels #15956
Merged by simon-mo on May 14, 2025: 205 commits from neuralmagic:modular-fused-experts into vllm-project:main.
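In broad strokes, the refactor splits a fused-MoE forward pass into a prepare/finalize stage (token dispatch and combine, e.g. backed by the PPLX all-to-all kernels for expert parallelism) and an expert-compute stage (Triton, DeepGEMM, or CUTLASS), composed behind a single interface. A minimal sketch of that shape follows; the class and method names are illustrative assumptions based on the commit messages, not vLLM's exact API.

```python
# Illustrative sketch of the modular split described by the commit trail
# ("refactor dispatch/combine stuff", "rename dispatch combine -> prepare
# finalize"). Names are assumptions, not vLLM's actual classes.
from abc import ABC, abstractmethod

import torch


class PrepareAndFinalize(ABC):
    """Moves tokens to their experts and back (e.g. a PPLX all-to-all)."""

    @abstractmethod
    def prepare(self, hidden_states: torch.Tensor,
                topk_ids: torch.Tensor) -> torch.Tensor:
        ...

    @abstractmethod
    def finalize(self, expert_output: torch.Tensor, topk_ids: torch.Tensor,
                 topk_weights: torch.Tensor) -> torch.Tensor:
        ...


class FusedExperts(ABC):
    """Runs the per-expert GEMMs (Triton, DeepGEMM, CUTLASS, ...)."""

    @abstractmethod
    def apply(self, dispatched: torch.Tensor, w1: torch.Tensor,
              w2: torch.Tensor) -> torch.Tensor:
        ...


class ModularFusedMoE:
    """Composes one prepare/finalize backend with one expert-compute backend."""

    def __init__(self, prepare_finalize: PrepareAndFinalize,
                 experts: FusedExperts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def forward(self, hidden_states, w1, w2, topk_ids, topk_weights):
        # Dispatch tokens to experts, run the expert GEMMs, then combine.
        dispatched = self.prepare_finalize.prepare(hidden_states, topk_ids)
        expert_out = self.experts.apply(dispatched, w1, w2)
        return self.prepare_finalize.finalize(expert_out, topk_ids,
                                              topk_weights)
```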
Commits (205):
- 72ee0c4 (bnellnm): moe refactoring
- 24ca1f8 (bnellnm): module deepgemm moe working
- 1281d8d (bnellnm): working deep gemm, wip cutlass
- 9cac3d1 (bnellnm): working cutlass
- 08e3f07 (bnellnm): deepgemm working again
- b46beb3 (bnellnm): cutlass working again
- 80b3e20 (bnellnm): cutlass working again
- a8911e8 (bnellnm): fix inplace, format and name cleanups
- 01125b5 (bnellnm): fix inplace, format + name cleanups
- 4207795 (bnellnm): test improvements
- 5e445bc (bnellnm): make modular triton classes, fix edge cases
- a530fe3 (bnellnm): fix outplace bug
- 5ec0f7c (bnellnm): refactor dispatch/combine stuff
- 1b25145 (bnellnm): initial pplx dispatch/combine class
- 377dfd0 (bnellnm): merge triton dispatch into standard, add some comments
- e0fd915 (bnellnm): format
- 92da2f7 (bnellnm): comments
- bec3835 (bnellnm): fix linter
- ac8158b (bnellnm): fix more linter stuff
- b5d08aa (bnellnm): cleanup for review
- 3925993 (bnellnm): review comments
- 0ad6d68 (bnellnm): forgot return
- 97ac838 (bnellnm): add dp_rank_num_tokens to DPMetadata
- 65b3169 (bnellnm): better check for fp8 in _fp8_permute
- 8b06b48 (bnellnm): updates
- 04fec22 (bnellnm): fix merge issues
- bc4f7b0 (bnellnm): fix lint
- 35a1381 (bnellnm): add pplx tests
- 9fb396b (bnellnm): lint
- 92a9305 (bnellnm): undo random lint changes
- 0ddd5f9 (bnellnm): more lint
- 6cd718a (bnellnm): more lint nonsense
- dcd5926 (tlrmchlsmth): WIP torch while
- 5c3d8b5 (tlrmchlsmth): wip
- 59aeb5d (tlrmchlsmth): wip
- 9baf725 (tlrmchlsmth): wip
- d7b5240 (tlrmchlsmth): wip
- f6c87da (tlrmchlsmth): WIP integration
- 692008b (bnellnm): Add test for deep gemm matmul
- a707ba0 (bnellnm): fix matmul test
- c35423d (bnellnm): running
- 02c9c07 (bnellnm): wip
- f56b199 (bnellnm): wip
- 3da73b6 (bnellnm): debugging
- 1b2ace5 (bnellnm): debugging
- 0666fe8 (bnellnm): fix
- 47a3789 (bnellnm): update deep gemm
- 66a7db0 (bnellnm): update deep gemm + small test case
- 24d22db (bnellnm): wip
- 1498c7d (bnellnm): wip
- 5f0e563 (bnellnm): problem with scores
- d446e2e (bnellnm): some passing tests
- 2b3a848 (bnellnm): some passing tests
- 3cba397 (bnellnm): topk > 1 doesn't work. prune oom-ing tests
- 91bff40 (bnellnm): fix indices
- f7658b4 (bnellnm): enable more tests
- 673a5f2 (bnellnm): format
- 5b40f71 (bnellnm): use fused_topk for unit test
- f0315e9 (bnellnm): every other block correct
- 3d6b792 (bnellnm): working
- 9a01d43 (bnellnm): enable more tests
- da45726 (bnellnm): working tests w/permute
- c4a89fd (bnellnm): cleanups
- f8779ad (bnellnm): wip
- 0b3ff3d (bnellnm): not crashing
- e6a9c50 (bnellnm): baseline working integration
- 252115f (bnellnm): add allow_deep_gemm flag
- 53b7301 (bnellnm): wip
- c87af3d (bnellnm): better
- fe6799b (bnellnm): fix some stuff
- 5921a4b (bnellnm): fix more stuff
- 1a7b675 (bnellnm): cleanups
- e2828f6 (bnellnm): some integration tests working
- f2d0bbe (bnellnm): almost all tests passing
- 3eb2185 (bnellnm): cleanup temp construction a bit
- 297ac81 (bnellnm): fix rest of tests
- 81b48ec (bnellnm): cleanups + format
- 42e1699 (bnellnm): do more of output computation in place
- 70947dd (bnellnm): add env var
- 25aef1f (bnellnm): formatting, remove some blocking restrictions
- 719362a (bnellnm): wip
- 38dc3cf (bnellnm): fix resizing of output
- 3e8591e (bnellnm): fix resizing of output
- 916bfe1 (bnellnm): fixes
- 2d534ae (bnellnm): aligned chunking working for deep gemm
- 5da8846 (bnellnm): unaligned chunking for deep gemm
- 4726f6f (bnellnm): cleanup wip
- 7495946 (bnellnm): clean up some blocking stuff
- 2ea300a (bnellnm): clean up some blocking stuff
- 9752886 (bnellnm): tweaks
- 9f71c94 (bnellnm): fix rebase
- d312986 (bnellnm): rebase
- 93dfaf3 (bnellnm): refactoring + minor perf improvements
- e2ebf14 (bnellnm): refactoring + perf tweaks
- 6caebc0 (bnellnm): remove debugging cruft
- 2f459a3 (bnellnm): cache resize refactoring
- c88a17f (bnellnm): cleanups
- 2f56ff9 (bnellnm): format
- 48d071f (bnellnm): revert test.txt, fix mypy errors
- a51970b (bnellnm): review comments
- 6676f24 (bnellnm): review comments
- be58664 (bnellnm): clean up use_dg flags
- a52f17a (bnellnm): remove check for aligned M
- f22b693 (bnellnm): rebase + clean up test
- 549a9fe (bnellnm): fix format
- 8a72a9c (tlrmchlsmth): Clean up diff
- 005c18d (ilmarkov): [Distributed] Add custom allreduce support for ROCM (#14125)
- c98aa16 (yma11): [Bugfix][Model] fix mllama multi-image (#14883)
- 2e7db9a (bnellnm): module deepgemm moe working
- f86e516 (bnellnm): working deep gemm, wip cutlass
- b9d5e60 (bnellnm): working cutlass
- e252cdf (bnellnm): deepgemm working again
- 802203b (bnellnm): fix inplace, format and name cleanups
- 1e63491 (bnellnm): test improvements
- 16c2583 (bnellnm): make modular triton classes, fix edge cases
- 5d88a64 (bnellnm): refactor dispatch/combine stuff
- fa69484 (bnellnm): initial pplx dispatch/combine class
- 27e92fb (bnellnm): merge triton dispatch into standard, add some comments
- 3381df0 (bnellnm): format
- 734b06c (bnellnm): cleanup for review
- 834ea30 (bnellnm): hacking
- 4504a8e (bnellnm): hacking
- 456ecc5 (bnellnm): init stuff
- 4087600 (bnellnm): call super ctor + fix random stuff
- bdb28ff (tlrmchlsmth): fix use_ep bug
- 1b6c4a2 (tlrmchlsmth): fixes
- 9474cd7 (bnellnm): get a bit further
- 0a80345 (bnellnm): hacking in dispatch_combine
- 7a0d68b (bnellnm): hook up some wires
- bcf237c (bnellnm): seems to be working
- ee86b51 (bnellnm): wip
- 4e22d15 (bnellnm): batched moe test
- c76c988 (bnellnm): simple test
- e3385da (bnellnm): cleanup
- 71f7361 (bnellnm): test pplx w/naive implementation
- b1c40b7 (bnellnm): test pplx w/naive implementation
- 3054ec2 (bnellnm): hack fix for chunking loop
- be9a445 (bnellnm): wip. add pplx unit test
- ce67d8d (bnellnm): work on unit test
- 95cd250 (bnellnm): dispatch/combine unit test
- 56f8a6d (bnellnm): forgot file
- 4f7b6c9 (bnellnm): somewhat working unit test
- add77e4 (bnellnm): wip
- 9f7cc1e (bnellnm): fix test
- 27de9fe (bnellnm): some cleanup
- e4642cd (bnellnm): wip
- 15aa7df (bnellnm): wip
- 66c497f (bnellnm): undo random changes
- c5fec1a (bnellnm): merge
- 320805e (bnellnm): tweak
- 5dc242f (bnellnm): revert hack
- 0b7f124 (bnellnm): fixes
- 6f192ec (bnellnm): pplx update
- 1bffb6b (bnellnm): varun's fixes
- 489c7df (bnellnm): varun's fixes
- 8253ede (bnellnm): tweak bound_m
- 43ed0ae (bnellnm): run linter
- 66558c7 (bnellnm): more lint stuff
- 8098a69 (bnellnm): add guards for pplx import
- 51ea5a3: fix forward_chunked
- d0fe7b5 (bnellnm): fix more lint
- 138ffc2 (bnellnm): cleanups
- 269cccd (bnellnm): cleanups + lint, layer.py wip
- cbecf66 (bnellnm): fix parallel_state lint
- 02f8201 (bnellnm): fix M=1 pplx test
- 38f5b03 (bnellnm): fix M=1 pplx test
- 829df83 (bnellnm): fix M=1 pplx test
- 680c00e (bnellnm): lint
- 3e61124 (bnellnm): remove valid pplx check
- 1cc3950 (bnellnm): semi-working cudagraphs
- aaefc27 (bnellnm): fix reference implementations
- 3b72bc5 (bnellnm): wip ref impl
- 6bb6983 (bnellnm): improve ref impl
- 909e0e5 (bnellnm): wip
- c12dae1 (bnellnm): fix merge
- b294ccd (bnellnm): fix merge
- 250f1b7: wip
- 5318218: zero out attn outputs during profile run
- 2e4be06 (bnellnm): lint
- 70b9264 (bnellnm): lint
- 69dbd31 (bnellnm): revert lint changes to requirements/test.txt
- 31166d9 (bnellnm): revert lint changes to compiler_interface.py
- 62a0896 (bnellnm): fix merge
- cdef4c6 (bnellnm): fix more lint errors
- 9f0ea4f (bnellnm): fix lint
- 6c0e085 (bnellnm): cosmetic changes
- 54113c2 (bnellnm): fix test
- 9f8e241 (bnellnm): fix test
- a674762 (bnellnm): Varun's fixes/cleanups
- 43e229c (bnellnm): review comments + cudagraph debugging
- ca2ff26 (bnellnm): fix merge + add comments
- 93dd74f (bnellnm): lint
- b5be324 (bnellnm): rename dispatch combine -> prepare finalize
- 9b97c83 (bnellnm): review comments, only initialize pplx if EP is enabled
- d6e801e (bnellnm): fix test when pplx is missing + minor tweaks
- 9461d73 (bnellnm): rename StandardPrepareAndFinalize
- 980262f (bnellnm): review comments
- 1cb6b1d (bnellnm): merge
- c5adb68 (bnellnm): disable pplx for quantized types
- 40ebc47 (bnellnm): revert MOE_DP_CHUNK_SIZE
- 484fc83 (bnellnm): revert some bad changes
- c4086d7 (bnellnm): rebase + fix some tests
- 3f10988: relax test_batched_moe tolerances
- 23cf129: Remove redundant tp_size setting in dbrx
- 1f91cfd (bnellnm): fix merge
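The commit trail above converges on a batched expert format: tokens routed to each expert are gathered into a [num_experts, max_tokens_per_expert, hidden] buffer along with a per-expert token count, which is exactly the layout the new unit test below exercises. The following is a rough sketch of that grouping step in plain PyTorch; the function name and shapes are assumptions for illustration, not vLLM's actual dispatch code.

```python
# Illustrative only: gather tokens into the batched [E, max_tokens, hidden]
# layout used by the batched MoE kernels. Not vLLM's real dispatch path.
import torch


def group_tokens_by_expert(tokens: torch.Tensor,      # [num_tokens, hidden]
                           expert_ids: torch.Tensor,  # [num_tokens]
                           num_experts: int,
                           max_tokens_per_expert: int):
    hidden = tokens.size(1)
    batched = torch.zeros(num_experts, max_tokens_per_expert, hidden,
                          dtype=tokens.dtype, device=tokens.device)
    counts = torch.zeros(num_experts, dtype=torch.int32, device=tokens.device)
    for e in range(num_experts):
        rows = tokens[expert_ids == e]
        n = min(rows.size(0), max_tokens_per_expert)
        batched[e, :n] = rows[:n]
        counts[e] = n
    # "counts" plays the role of num_expert_tokens in the test below.
    return batched, counts
```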
The PR adds a new unit test file (114 lines, hunk @@ -0,0 +1,114 @@) that checks the batched MoE Triton matmul against a per-expert PyTorch reference:

```python
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 10
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:

    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):

    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)

    test_output = tensors.C
    ref_output = test_output.clone()

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]
    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        test_output,
        tensors.num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
```
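For a quick local check of a single configuration without running the full parameter sweep, one option is to append a small driver like the one below to the test file (a hypothetical addition, not part of the PR's diff); it requires a CUDA device and a vLLM build that provides invoke_moe_batched_triton_kernel.

```python
# Hypothetical smoke-test driver, not part of the PR. Append to the test
# file above (or run after importing it) on a machine with a CUDA device.
if __name__ == "__main__":
    test_batched_mm(num_experts=16,
                    max_tokens_per_expert=64,
                    K=128,
                    N=256,
                    dtype=torch.float16)
    print("single-config batched MM check passed")
```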