-
Notifications
You must be signed in to change notification settings - Fork 580
refactor: communication module #1162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -264,12 +264,20 @@ def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: | |||||||||||||||||||||||||||||||||||||||||
| cudart = CudaRTLibrary() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def gen_comm_module() -> JitSpec: | ||||||||||||||||||||||||||||||||||||||||||
| def gen_vllm_comm_module() -> JitSpec: | ||||||||||||||||||||||||||||||||||||||||||
| return gen_jit_spec( | ||||||||||||||||||||||||||||||||||||||||||
| "comm", | ||||||||||||||||||||||||||||||||||||||||||
| "vllm_comm", | ||||||||||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||||||||
| jit_env.FLASHINFER_CSRC_DIR / "comm_pybind.cu", | ||||||||||||||||||||||||||||||||||||||||||
| jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu", | ||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def gen_trtllm_comm_module() -> JitSpec: | ||||||||||||||||||||||||||||||||||||||||||
| return gen_jit_spec( | ||||||||||||||||||||||||||||||||||||||||||
| "trtllm_comm", | ||||||||||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||||||||
| jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu", | ||||||||||||||||||||||||||||||||||||||||||
| jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu", | ||||||||||||||||||||||||||||||||||||||||||
| jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu", | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -279,8 +287,8 @@ def gen_comm_module() -> JitSpec: | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @functools.cache | ||||||||||||||||||||||||||||||||||||||||||
| def get_comm_module(): | ||||||||||||||||||||||||||||||||||||||||||
| module = gen_comm_module().build_and_load() | ||||||||||||||||||||||||||||||||||||||||||
| def get_vllm_comm_module(): | ||||||||||||||||||||||||||||||||||||||||||
| module = gen_vllm_comm_module().build_and_load() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # torch library for all | ||||||||||||||||||||||||||||||||||||||||||
| @register_custom_op( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -333,6 +341,21 @@ def all_reduce( | |||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| module.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return SimpleNamespace( | ||||||||||||||||||||||||||||||||||||||||||
| init_custom_ar=init_custom_ar, | ||||||||||||||||||||||||||||||||||||||||||
| dispose=dispose, | ||||||||||||||||||||||||||||||||||||||||||
| get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta, | ||||||||||||||||||||||||||||||||||||||||||
| register_buffer=register_buffer, | ||||||||||||||||||||||||||||||||||||||||||
| register_graph_buffers=register_graph_buffers, | ||||||||||||||||||||||||||||||||||||||||||
| meta_size=meta_size, | ||||||||||||||||||||||||||||||||||||||||||
| all_reduce=all_reduce, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+344
to
+352
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider extracting the common arguments passed to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @functools.cache | ||||||||||||||||||||||||||||||||||||||||||
| def get_trtllm_comm_module(): | ||||||||||||||||||||||||||||||||||||||||||
| module = gen_trtllm_comm_module().build_and_load() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @register_custom_op( | ||||||||||||||||||||||||||||||||||||||||||
| "flashinfer::trtllm_lamport_initialize", mutates_args=["buffer"] | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -577,13 +600,6 @@ def trtllm_moe_allreduce_fusion( | |||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return SimpleNamespace( | ||||||||||||||||||||||||||||||||||||||||||
| init_custom_ar=init_custom_ar, | ||||||||||||||||||||||||||||||||||||||||||
| dispose=dispose, | ||||||||||||||||||||||||||||||||||||||||||
| get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta, | ||||||||||||||||||||||||||||||||||||||||||
| register_buffer=register_buffer, | ||||||||||||||||||||||||||||||||||||||||||
| register_graph_buffers=register_graph_buffers, | ||||||||||||||||||||||||||||||||||||||||||
| meta_size=meta_size, | ||||||||||||||||||||||||||||||||||||||||||
| all_reduce=all_reduce, | ||||||||||||||||||||||||||||||||||||||||||
| trtllm_lamport_initialize=trtllm_lamport_initialize, | ||||||||||||||||||||||||||||||||||||||||||
| trtllm_lamport_initialize_all=trtllm_lamport_initialize_all, | ||||||||||||||||||||||||||||||||||||||||||
| trtllm_custom_all_reduce=trtllm_custom_all_reduce, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -595,11 +611,13 @@ def trtllm_moe_allreduce_fusion( | |||||||||||||||||||||||||||||||||||||||||
| def init_custom_ar( | ||||||||||||||||||||||||||||||||||||||||||
| ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool | ||||||||||||||||||||||||||||||||||||||||||
| ) -> int: | ||||||||||||||||||||||||||||||||||||||||||
| return get_comm_module().init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink) | ||||||||||||||||||||||||||||||||||||||||||
| return get_vllm_comm_module().init_custom_ar( | ||||||||||||||||||||||||||||||||||||||||||
| ipc_tensors, rank_data, rank, full_nvlink | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+614
to
+616
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using named arguments for improved readability, e.g., return get_vllm_comm_module().init_custom_ar(ipc_tensors=ipc_tensors, rank_data=rank_data, rank=rank, full_nvlink=full_nvlink) |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def dispose(fa: int) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().dispose(fa) | ||||||||||||||||||||||||||||||||||||||||||
| get_vllm_comm_module().dispose(fa) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def all_reduce( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -621,27 +639,27 @@ def all_reduce( | |||||||||||||||||||||||||||||||||||||||||
| num_ctas: The number of CTAs to use for the all reduce. | ||||||||||||||||||||||||||||||||||||||||||
| CTA upper bounds: 36. Generally, we can saturate the bandwidth even with small amount the SMs. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().all_reduce( | ||||||||||||||||||||||||||||||||||||||||||
| get_vllm_comm_module().all_reduce( | ||||||||||||||||||||||||||||||||||||||||||
| fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
639
to
644
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using descriptive variable names for get_vllm_comm_module().all_reduce(
registered_allocator, input_tensor, output_tensor, registered_buffer, registered_buffer_size_bytes, num_ctas
) |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]: | ||||||||||||||||||||||||||||||||||||||||||
| return get_comm_module().get_graph_buffer_ipc_meta(fa) | ||||||||||||||||||||||||||||||||||||||||||
| return get_vllm_comm_module().get_graph_buffer_ipc_meta(fa) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| return get_comm_module().register_buffer(fa, fake_ipc_ptrs) | ||||||||||||||||||||||||||||||||||||||||||
| return get_vllm_comm_module().register_buffer(fa, fake_ipc_ptrs) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def register_graph_buffers( | ||||||||||||||||||||||||||||||||||||||||||
| fa: int, handles: List[List[int]], offsets: List[List[int]] | ||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().register_graph_buffers(fa, handles, offsets) | ||||||||||||||||||||||||||||||||||||||||||
| get_vllm_comm_module().register_graph_buffers(fa, handles, offsets) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def meta_size() -> int: | ||||||||||||||||||||||||||||||||||||||||||
| return get_comm_module().meta_size() | ||||||||||||||||||||||||||||||||||||||||||
| return get_vllm_comm_module().meta_size() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def create_shared_buffer( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -913,7 +931,7 @@ def pad_up(x, y): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def trtllm_lamport_initialize(buffer_ptr: int, size: int, dtype: torch.dtype) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype) | ||||||||||||||||||||||||||||||||||||||||||
| get_trtllm_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def trtllm_lamport_initialize_all( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -923,7 +941,7 @@ def trtllm_lamport_initialize_all( | |||||||||||||||||||||||||||||||||||||||||
| size: int, | ||||||||||||||||||||||||||||||||||||||||||
| dtype: torch.dtype, | ||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().trtllm_lamport_initialize_all( | ||||||||||||||||||||||||||||||||||||||||||
| get_trtllm_comm_module().trtllm_lamport_initialize_all( | ||||||||||||||||||||||||||||||||||||||||||
| buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -952,7 +970,7 @@ def trtllm_custom_all_reduce( | |||||||||||||||||||||||||||||||||||||||||
| lamport_peer_comm_buffer_ptrs_1: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||
| lamport_peer_comm_buffer_ptrs_2: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().trtllm_custom_all_reduce( | ||||||||||||||||||||||||||||||||||||||||||
| get_trtllm_comm_module().trtllm_custom_all_reduce( | ||||||||||||||||||||||||||||||||||||||||||
| inp, | ||||||||||||||||||||||||||||||||||||||||||
| out, | ||||||||||||||||||||||||||||||||||||||||||
| tp_size, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1001,7 +1019,7 @@ def trtllm_allreduce_fusion( | |||||||||||||||||||||||||||||||||||||||||
| scale_factor: Optional[float], | ||||||||||||||||||||||||||||||||||||||||||
| layout_code: Optional[FP4QuantizationSFLayout], | ||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().trtllm_allreduce_fusion( | ||||||||||||||||||||||||||||||||||||||||||
| get_trtllm_comm_module().trtllm_allreduce_fusion( | ||||||||||||||||||||||||||||||||||||||||||
| allreduce_in=allreduce_in, | ||||||||||||||||||||||||||||||||||||||||||
| world_size=world_size, | ||||||||||||||||||||||||||||||||||||||||||
| world_rank=world_rank, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1048,7 +1066,7 @@ def trtllm_moe_allreduce_fusion( | |||||||||||||||||||||||||||||||||||||||||
| quant_out: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||
| scale_out: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| get_comm_module().trtllm_moe_allreduce_fusion( | ||||||||||||||||||||||||||||||||||||||||||
| get_trtllm_comm_module().trtllm_moe_allreduce_fusion( | ||||||||||||||||||||||||||||||||||||||||||
| world_size=world_size, | ||||||||||||||||||||||||||||||||||||||||||
| world_rank=world_rank, | ||||||||||||||||||||||||||||||||||||||||||
| token_num=token_num, | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider renaming
gen_vllm_comm_moduletogen_comm_module_vllmfor better consistency withgen_trtllm_comm_module.