@@ -264,12 +264,20 @@ def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
 cudart = CudaRTLibrary()
 
 
-def gen_comm_module() -> JitSpec:
+def gen_vllm_comm_module() -> JitSpec:
     return gen_jit_spec(
-        "comm",
+        "vllm_comm",
         [
             jit_env.FLASHINFER_CSRC_DIR / "comm_pybind.cu",
             jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
+        ],
+    )
+
+
+def gen_trtllm_comm_module() -> JitSpec:
+    return gen_jit_spec(
+        "trtllm_comm",
+        [
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu",
@@ -279,8 +287,8 @@ def gen_comm_module() -> JitSpec:
 
 
 @functools.cache
-def get_comm_module():
-    module = gen_comm_module().build_and_load()
+def get_vllm_comm_module():
+    module = gen_vllm_comm_module().build_and_load()
 
     # torch library for all
     @register_custom_op(
@@ -333,6 +341,21 @@ def all_reduce(
     ) -> None:
         module.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas)
 
+    return SimpleNamespace(
+        init_custom_ar=init_custom_ar,
+        dispose=dispose,
+        get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta,
+        register_buffer=register_buffer,
+        register_graph_buffers=register_graph_buffers,
+        meta_size=meta_size,
+        all_reduce=all_reduce,
+    )
+
+
+@functools.cache
+def get_trtllm_comm_module():
+    module = gen_trtllm_comm_module().build_and_load()
+
     @register_custom_op(
         "flashinfer::trtllm_lamport_initialize", mutates_args=["buffer"]
     )
@@ -577,13 +600,6 @@ def trtllm_moe_allreduce_fusion(
         )
 
     return SimpleNamespace(
-        init_custom_ar=init_custom_ar,
-        dispose=dispose,
-        get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta,
-        register_buffer=register_buffer,
-        register_graph_buffers=register_graph_buffers,
-        meta_size=meta_size,
-        all_reduce=all_reduce,
         trtllm_lamport_initialize=trtllm_lamport_initialize,
         trtllm_lamport_initialize_all=trtllm_lamport_initialize_all,
         trtllm_custom_all_reduce=trtllm_custom_all_reduce,
@@ -595,11 +611,13 @@ def trtllm_moe_allreduce_fusion(
 def init_custom_ar(
     ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
 ) -> int:
-    return get_comm_module().init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
+    return get_vllm_comm_module().init_custom_ar(
+        ipc_tensors, rank_data, rank, full_nvlink
+    )
 
 
 def dispose(fa: int) -> None:
-    get_comm_module().dispose(fa)
+    get_vllm_comm_module().dispose(fa)
 
 
 def all_reduce(
@@ -621,27 +639,27 @@ def all_reduce(
         num_ctas: The number of CTAs to use for the all reduce.
             CTA upper bounds: 36. Generally, we can saturate the bandwidth even with small amount the SMs.
     """
-    get_comm_module().all_reduce(
+    get_vllm_comm_module().all_reduce(
         fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas
     )
 
 
 def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
-    return get_comm_module().get_graph_buffer_ipc_meta(fa)
+    return get_vllm_comm_module().get_graph_buffer_ipc_meta(fa)
 
 
 def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
-    return get_comm_module().register_buffer(fa, fake_ipc_ptrs)
+    return get_vllm_comm_module().register_buffer(fa, fake_ipc_ptrs)
 
 
 def register_graph_buffers(
     fa: int, handles: List[List[int]], offsets: List[List[int]]
 ) -> None:
-    get_comm_module().register_graph_buffers(fa, handles, offsets)
+    get_vllm_comm_module().register_graph_buffers(fa, handles, offsets)
 
 
 def meta_size() -> int:
-    return get_comm_module().meta_size()
+    return get_vllm_comm_module().meta_size()
 
 
 def create_shared_buffer(
@@ -913,7 +931,7 @@ def pad_up(x, y):
 
 
 def trtllm_lamport_initialize(buffer_ptr: int, size: int, dtype: torch.dtype) -> None:
-    get_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype)
+    get_trtllm_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype)
 
 
 def trtllm_lamport_initialize_all(
@@ -923,7 +941,7 @@ def trtllm_lamport_initialize_all(
     size: int,
     dtype: torch.dtype,
 ) -> None:
-    get_comm_module().trtllm_lamport_initialize_all(
+    get_trtllm_comm_module().trtllm_lamport_initialize_all(
         buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype
     )
 
@@ -952,7 +970,7 @@ def trtllm_custom_all_reduce(
     lamport_peer_comm_buffer_ptrs_1: Optional[torch.Tensor],
     lamport_peer_comm_buffer_ptrs_2: Optional[torch.Tensor],
 ) -> None:
-    get_comm_module().trtllm_custom_all_reduce(
+    get_trtllm_comm_module().trtllm_custom_all_reduce(
         inp,
         out,
         tp_size,
@@ -1001,7 +1019,7 @@ def trtllm_allreduce_fusion(
     scale_factor: Optional[float],
     layout_code: Optional[FP4QuantizationSFLayout],
 ) -> None:
-    get_comm_module().trtllm_allreduce_fusion(
+    get_trtllm_comm_module().trtllm_allreduce_fusion(
         allreduce_in=allreduce_in,
         world_size=world_size,
         world_rank=world_rank,
@@ -1048,7 +1066,7 @@ def trtllm_moe_allreduce_fusion(
     quant_out: Optional[torch.Tensor],
     scale_out: Optional[torch.Tensor],
 ) -> None:
-    get_comm_module().trtllm_moe_allreduce_fusion(
+    get_trtllm_comm_module().trtllm_moe_allreduce_fusion(
         world_size=world_size,
         world_rank=world_rank,
         token_num=token_num,
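
For reference, a minimal usage sketch of the split: after this change, the vLLM-style custom all-reduce wrappers lazily build the "vllm_comm" JIT module, while the TRT-LLM fused all-reduce wrappers build "trtllm_comm". This is an illustrative sketch only, not part of the diff; the `flashinfer.comm` import path, the two helper functions, and all buffer/handle arguments (IPC tensors, rank data, registered scratch buffer, Lamport buffer pointer) are assumed placeholders that the caller prepares elsewhere.

```python
# Hypothetical sketch of which JIT module backs each public wrapper after the
# split. Distributed/IPC setup is assumed to have been done and is not shown.
from typing import List

import torch
import flashinfer.comm as comm  # assumed import path for the wrappers above


def vllm_style_allreduce(
    ipc_tensors: List[int],
    rank_data: torch.Tensor,
    rank: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    reg_buffer_ptr: int,
    reg_buffer_sz_bytes: int,
) -> None:
    # First call here triggers build_and_load() of the "vllm_comm" module.
    fa = comm.init_custom_ar(ipc_tensors, rank_data, rank, True)
    try:
        # num_ctas=16 is an arbitrary illustrative value (upper bound is 36).
        comm.all_reduce(fa, inp, out, reg_buffer_ptr, reg_buffer_sz_bytes, 16)
    finally:
        comm.dispose(fa)


def trtllm_lamport_setup(buffer_ptr: int, num_elems: int) -> None:
    # First call here triggers build_and_load() of the "trtllm_comm" module;
    # the vLLM module is never built if only the TRT-LLM path is used.
    comm.trtllm_lamport_initialize(buffer_ptr, num_elems, torch.bfloat16)
```

One effect of the split worth noting: code that only uses one family of kernels no longer pays the JIT compile time for the other, since each `get_*_comm_module()` is cached and built independently.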