
Commit 15b3e65

refactor: communication module (#1162)
## 📌 Description

Separate communication kernels from different sources: the vLLM-derived custom all-reduce kernels now build as a `vllm_comm` JIT module, and the TensorRT-LLM allreduce/fusion kernels build as `trtllm_comm`.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent f230eb6 commit 15b3e65

File tree

3 files changed: 44 additions, 25 deletions

β€Žflashinfer/aot.pyβ€Ž

Lines changed: 3 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 from .activation import act_func_def_str, gen_act_and_mul_module
 from .cascade import gen_cascade_module
-from .comm import gen_comm_module
+from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
 from .fp4_quantization import gen_fp4_quantization_sm100_module
 from .fused_moe import gen_fused_moe_sm100_module
 from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -327,7 +327,8 @@ def gen_all_modules(
 
     jit_specs += [
         gen_cascade_module(),
-        gen_comm_module(),
+        gen_vllm_comm_module(),
+        gen_trtllm_comm_module(),
         gen_norm_module(),
         gen_page_module(),
         gen_quantization_module(),
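
For reference, a minimal sketch of what this split means for module generation: the two comm JIT specs can now be generated and compiled independently. This is only a sketch; it assumes a FlashInfer checkout at this commit, and uses only the names introduced in this diff and in the flashinfer/comm.py diff below.

```python
# Sketch only: build the two communication JIT modules separately.
# Assumes FlashInfer at this commit, where flashinfer.comm exposes
# gen_vllm_comm_module and gen_trtllm_comm_module.
from flashinfer.comm import gen_trtllm_comm_module, gen_vllm_comm_module

# Each generator returns a JitSpec; build_and_load() compiles the sources
# listed in that spec and loads the resulting extension.
vllm_comm = gen_vllm_comm_module().build_and_load()      # comm_pybind.cu, custom_all_reduce.cu
trtllm_comm = gen_trtllm_comm_module().build_and_load()  # trtllm_allreduce*.cu kernels
```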

β€Žflashinfer/comm.pyβ€Ž

Lines changed: 41 additions & 23 deletions
@@ -264,12 +264,20 @@ def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
 cudart = CudaRTLibrary()
 
 
-def gen_comm_module() -> JitSpec:
+def gen_vllm_comm_module() -> JitSpec:
     return gen_jit_spec(
-        "comm",
+        "vllm_comm",
         [
             jit_env.FLASHINFER_CSRC_DIR / "comm_pybind.cu",
             jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
+        ],
+    )
+
+
+def gen_trtllm_comm_module() -> JitSpec:
+    return gen_jit_spec(
+        "trtllm_comm",
+        [
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu",
@@ -279,8 +287,8 @@ def gen_comm_module() -> JitSpec:
 
 
 @functools.cache
-def get_comm_module():
-    module = gen_comm_module().build_and_load()
+def get_vllm_comm_module():
+    module = gen_vllm_comm_module().build_and_load()
 
     # torch library for all
     @register_custom_op(
@@ -333,6 +341,21 @@ def all_reduce(
     ) -> None:
         module.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas)
 
+    return SimpleNamespace(
+        init_custom_ar=init_custom_ar,
+        dispose=dispose,
+        get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta,
+        register_buffer=register_buffer,
+        register_graph_buffers=register_graph_buffers,
+        meta_size=meta_size,
+        all_reduce=all_reduce,
+    )
+
+
+@functools.cache
+def get_trtllm_comm_module():
+    module = gen_trtllm_comm_module().build_and_load()
+
     @register_custom_op(
         "flashinfer::trtllm_lamport_initialize", mutates_args=["buffer"]
     )
@@ -577,13 +600,6 @@ def trtllm_moe_allreduce_fusion(
         )
 
     return SimpleNamespace(
-        init_custom_ar=init_custom_ar,
-        dispose=dispose,
-        get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta,
-        register_buffer=register_buffer,
-        register_graph_buffers=register_graph_buffers,
-        meta_size=meta_size,
-        all_reduce=all_reduce,
         trtllm_lamport_initialize=trtllm_lamport_initialize,
         trtllm_lamport_initialize_all=trtllm_lamport_initialize_all,
         trtllm_custom_all_reduce=trtllm_custom_all_reduce,
@@ -595,11 +611,13 @@ def trtllm_moe_allreduce_fusion(
 def init_custom_ar(
     ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
 ) -> int:
-    return get_comm_module().init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
+    return get_vllm_comm_module().init_custom_ar(
+        ipc_tensors, rank_data, rank, full_nvlink
+    )
 
 
 def dispose(fa: int) -> None:
-    get_comm_module().dispose(fa)
+    get_vllm_comm_module().dispose(fa)
 
 
 def all_reduce(
@@ -621,27 +639,27 @@ def all_reduce(
         num_ctas: The number of CTAs to use for the all reduce.
         CTA upper bounds: 36. Generally, we can saturate the bandwidth even with small amount the SMs.
     """
-    get_comm_module().all_reduce(
+    get_vllm_comm_module().all_reduce(
         fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas
     )
 
 
 def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
-    return get_comm_module().get_graph_buffer_ipc_meta(fa)
+    return get_vllm_comm_module().get_graph_buffer_ipc_meta(fa)
 
 
 def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
-    return get_comm_module().register_buffer(fa, fake_ipc_ptrs)
+    return get_vllm_comm_module().register_buffer(fa, fake_ipc_ptrs)
 
 
 def register_graph_buffers(
     fa: int, handles: List[List[int]], offsets: List[List[int]]
 ) -> None:
-    get_comm_module().register_graph_buffers(fa, handles, offsets)
+    get_vllm_comm_module().register_graph_buffers(fa, handles, offsets)
 
 
 def meta_size() -> int:
-    return get_comm_module().meta_size()
+    return get_vllm_comm_module().meta_size()
 
 
 def create_shared_buffer(
@@ -913,7 +931,7 @@ def pad_up(x, y):
 
 
 def trtllm_lamport_initialize(buffer_ptr: int, size: int, dtype: torch.dtype) -> None:
-    get_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype)
+    get_trtllm_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype)
 
 
 def trtllm_lamport_initialize_all(
@@ -923,7 +941,7 @@ def trtllm_lamport_initialize_all(
     size: int,
     dtype: torch.dtype,
 ) -> None:
-    get_comm_module().trtllm_lamport_initialize_all(
+    get_trtllm_comm_module().trtllm_lamport_initialize_all(
         buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype
     )
 
@@ -952,7 +970,7 @@ def trtllm_custom_all_reduce(
     lamport_peer_comm_buffer_ptrs_1: Optional[torch.Tensor],
     lamport_peer_comm_buffer_ptrs_2: Optional[torch.Tensor],
 ) -> None:
-    get_comm_module().trtllm_custom_all_reduce(
+    get_trtllm_comm_module().trtllm_custom_all_reduce(
         inp,
         out,
         tp_size,
@@ -1001,7 +1019,7 @@ def trtllm_allreduce_fusion(
     scale_factor: Optional[float],
     layout_code: Optional[FP4QuantizationSFLayout],
 ) -> None:
-    get_comm_module().trtllm_allreduce_fusion(
+    get_trtllm_comm_module().trtllm_allreduce_fusion(
         allreduce_in=allreduce_in,
         world_size=world_size,
         world_rank=world_rank,
@@ -1048,7 +1066,7 @@ def trtllm_moe_allreduce_fusion(
     quant_out: Optional[torch.Tensor],
     scale_out: Optional[torch.Tensor],
 ) -> None:
-    get_comm_module().trtllm_moe_allreduce_fusion(
+    get_trtllm_comm_module().trtllm_moe_allreduce_fusion(
         world_size=world_size,
         world_rank=world_rank,
         token_num=token_num,
The third file was renamed without content changes.
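
The wrappers in flashinfer/comm.py keep their public signatures; only the backing JIT module changes. The vLLM-derived custom all-reduce ops now resolve `get_vllm_comm_module()`, while the TensorRT-LLM fused ops resolve `get_trtllm_comm_module()`, so using one backend no longer forces compilation of the other's sources. The sketch below illustrates the `@functools.cache` lazy-build pattern both loaders use; the names in it are hypothetical stand-ins, not FlashInfer APIs.

```python
# Illustration of the @functools.cache lazy-build pattern used by
# get_vllm_comm_module()/get_trtllm_comm_module(). Everything below is a
# hypothetical stand-in; no real compilation happens.
import functools
from types import SimpleNamespace


def _fake_build_and_load(name: str) -> SimpleNamespace:
    # Stand-in for gen_*_comm_module().build_and_load(): pretend to compile
    # the backend's .cu sources and return an object exposing its kernels.
    print(f"building {name} ...")
    return SimpleNamespace(name=name)


@functools.cache
def get_vllm_like_module() -> SimpleNamespace:
    # Compiled at most once per process, like get_vllm_comm_module().
    return _fake_build_and_load("vllm_comm")


@functools.cache
def get_trtllm_like_module() -> SimpleNamespace:
    # Compiled at most once per process, like get_trtllm_comm_module().
    return _fake_build_and_load("trtllm_comm")


if __name__ == "__main__":
    # Repeated calls reuse the cached module; each backend builds independently.
    assert get_vllm_like_module() is get_vllm_like_module()
    assert get_trtllm_like_module().name == "trtllm_comm"
```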
