Skip to content

Commit ba2470c

Browse files
authored
feat: add finalize_moe_allreduce from trtllm (#1159)
<!-- .github/pull_request_template.md --> ## 📌 Description - add finalize_moe_allreduce from trtllm ## 🔍 Related Issues NVIDIA/TensorRT-LLM#4756 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 15b3e65 commit ba2470c

File tree

5 files changed

+687
-4
lines changed

5 files changed

+687
-4
lines changed

csrc/trtllm_moe_allreduce_fusion.cu

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,57 @@ void trtllm_moe_allreduce_fusion(
8181
});
8282
}
8383

84+
// Fused finalize-MoE allreduce + residual RMSNorm entry point.
//
// Fills a MoeFinalizeAllReduceFusionParams struct from the tensor arguments and
// launches the fused kernel via moefinalize_allreduce_fusion_op. `norm_out` and
// `residual_out` are written in place; no quantized outputs are produced here.
//
// Throws (via TORCH_CHECK) if `expanded_idx_to_permuted_idx` is not int32 or if
// the kernel launch reports a CUDA error.
void trtllm_moe_finalize_allreduce_fusion(
    at::Tensor const& allreduce_in, at::Tensor const& residual_in, at::Tensor const& norm_weight,
    at::Tensor const& expanded_idx_to_permuted_idx, at::Tensor& norm_out, at::Tensor& residual_out,
    bool launch_with_pdl, at::Tensor& workspace, int64_t const world_rank, int64_t const world_size,
    double const eps, std::optional<at::Tensor> const& shared_expert_output,
    std::optional<at::Tensor> const& expert_scale_factor) {
  // Dispatch on the residual dtype; c_type is the element type the kernel sees.
  DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.scalar_type(), c_type, [&] {
    MoeFinalizeAllReduceFusionParams<c_type> params;

    int const hidden_size = residual_in.size(-1);
    // Last dim of the index tensor is the number of experts routed per token.
    int const experts_per_token = expanded_idx_to_permuted_idx.size(-1);

    // This path emits no quantized output or scale tensors.
    params.quant_out = nullptr;
    params.scale_out = nullptr;

    params.nranks = static_cast<int>(world_size);
    params.rank = static_cast<int>(world_rank);
    // size: num_token * hidden_dim
    params.size = residual_in.numel();
    params.hidden_dim = hidden_size;

    // workspace: AR scratch space, interpreted as an array of raw pointers.
    params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
    params.rms_gamma = norm_weight.data_ptr();
    params.rms_eps = static_cast<float>(eps);
    params.residual_in = residual_in.data_ptr();
    params.stream = at::cuda::getCurrentCUDAStream(norm_weight.get_device());

    // MoE-reduction specific inputs; optional tensors map to nullptr when absent.
    params.top_k = experts_per_token;
    params.allreduce_in = allreduce_in.data_ptr();
    params.expert_scale_factor =
        expert_scale_factor ? expert_scale_factor->data_ptr() : nullptr;
    TORCH_CHECK(expanded_idx_to_permuted_idx.scalar_type() == at::ScalarType::Int,
                "expanded_idx_to_permuted_idx must be int32");
    params.expanded_idx_to_permuted_idx =
        static_cast<int32_t*>(expanded_idx_to_permuted_idx.data_ptr());
    params.shared_expert_output =
        shared_expert_output ? shared_expert_output->data_ptr() : nullptr;

    // In-place output tensors.
    params.norm_out = norm_out.mutable_data_ptr();
    params.residual_out = residual_out.mutable_data_ptr();

    auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl);
    TORCH_CHECK(status == cudaSuccess, "moefinalize_allreduce_fusion_op failed with error code ",
                cudaGetErrorString(status));
  });
}
84134
// Register both MoE allreduce fusion entry points on this torch extension library.
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("trtllm_moe_allreduce_fusion", &trtllm_moe_allreduce_fusion);
  m.def("trtllm_moe_finalize_allreduce_fusion", &trtllm_moe_finalize_allreduce_fusion);
}

flashinfer/comm.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class AllReduceFusionOp:
6666
RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6
6767
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7
6868
MOE_ALLREDUCE_RESIDUAL_RMS_NORM = 8
69+
MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 9
6970

7071

7172
class AllReduceFusionPattern:
@@ -599,12 +600,48 @@ def trtllm_moe_allreduce_fusion(
599600
scale_out,
600601
)
601602

603+
@register_custom_op(
    "flashinfer::trtllm_moe_finalize_allreduce_fusion",
    mutates_args=["residual_out", "norm_out"],
)
def trtllm_moe_finalize_allreduce_fusion(
    allreduce_in: torch.Tensor,
    residual_in: torch.Tensor,
    norm_weight: torch.Tensor,
    expanded_idx_to_permuted_idx: torch.Tensor,
    norm_out: torch.Tensor,
    residual_out: torch.Tensor,
    launch_with_pdl: bool,
    workspace: torch.Tensor,
    world_rank: int,
    world_size: int,
    eps: float,
    shared_expert_output: Optional[torch.Tensor],
    expert_scale_factor: Optional[torch.Tensor],
) -> None:
    """Custom-op shim over the compiled finalize-MoE allreduce fusion kernel.

    Forwards all arguments positionally to the JIT-compiled extension;
    ``norm_out`` and ``residual_out`` are mutated in place.
    """
    op = module.trtllm_moe_finalize_allreduce_fusion
    op(
        allreduce_in,
        residual_in,
        norm_weight,
        expanded_idx_to_permuted_idx,
        norm_out,
        residual_out,
        launch_with_pdl,
        workspace,
        world_rank,
        world_size,
        eps,
        shared_expert_output,
        expert_scale_factor,
    )
602638
return SimpleNamespace(
603639
trtllm_lamport_initialize=trtllm_lamport_initialize,
604640
trtllm_lamport_initialize_all=trtllm_lamport_initialize_all,
605641
trtllm_custom_all_reduce=trtllm_custom_all_reduce,
606642
trtllm_allreduce_fusion=trtllm_allreduce_fusion,
607643
trtllm_moe_allreduce_fusion=trtllm_moe_allreduce_fusion,
644+
trtllm_moe_finalize_allreduce_fusion=trtllm_moe_finalize_allreduce_fusion,
608645
)
609646

610647

@@ -1088,3 +1125,35 @@ def trtllm_moe_allreduce_fusion(
10881125
quant_out=quant_out,
10891126
scale_out=scale_out,
10901127
)
1128+
1129+
1130+
def trtllm_moe_finalize_allreduce_fusion(
    allreduce_in: torch.Tensor,
    residual_in: torch.Tensor,
    norm_weight: torch.Tensor,
    expanded_idx_to_permuted_idx: torch.Tensor,
    norm_out: torch.Tensor,
    residual_out: torch.Tensor,
    workspace_ptrs: torch.Tensor,
    launch_with_pdl: bool,
    world_rank: int,
    world_size: int,
    eps: float,
    shared_expert_output: Optional[torch.Tensor],
    expert_scale_factor: Optional[torch.Tensor],
) -> None:
    """Fused finalize-MoE allreduce + residual RMSNorm (TensorRT-LLM kernels).

    Thin public wrapper that delegates to the registered
    ``trtllm_moe_finalize_allreduce_fusion`` custom op from the trtllm comm
    module. ``norm_out`` and ``residual_out`` are written in place.

    Args:
        allreduce_in: Input tensor to be all-reduced.
        residual_in: Residual input; its last dim is the hidden dimension and
            its dtype selects the kernel instantiation.
        norm_weight: RMSNorm gamma weights.
        expanded_idx_to_permuted_idx: int32 index map; last dim is top-k.
        norm_out: Output tensor for the normalized result (mutated).
        residual_out: Output tensor for the updated residual (mutated).
        workspace_ptrs: Allreduce scratch-space pointer tensor.
        launch_with_pdl: Launch with programmatic dependent launch enabled.
        world_rank: Rank of this process in the communicator.
        world_size: Number of ranks participating in the allreduce.
        eps: RMSNorm epsilon.
        shared_expert_output: Optional shared-expert output tensor.
        expert_scale_factor: Optional per-expert scale factors.
    """
    comm_module = get_trtllm_comm_module()
    comm_module.trtllm_moe_finalize_allreduce_fusion(
        allreduce_in=allreduce_in,
        residual_in=residual_in,
        norm_weight=norm_weight,
        expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
        norm_out=norm_out,
        residual_out=residual_out,
        workspace=workspace_ptrs,
        launch_with_pdl=launch_with_pdl,
        world_rank=world_rank,
        world_size=world_size,
        eps=eps,
        shared_expert_output=shared_expert_output,
        expert_scale_factor=expert_scale_factor,
    )

include/flashinfer/comm/trtllm_allreduce.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ enum class AllReduceFusionOp : int8_t {
7474
RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6,
7575
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7,
7676
MOE_ALLREDUCE_RESIDUAL_RMS_NORM = 8,
77+
MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 9,
7778
};
7879

7980
template <typename T>

0 commit comments

Comments (0)