flashinfer/aot.py (5 changes: 3 additions & 2 deletions)
@@ -10,7 +10,7 @@

from .activation import act_func_def_str, gen_act_and_mul_module
from .cascade import gen_cascade_module
-from .comm import gen_comm_module
+from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
from .fp4_quantization import gen_fp4_quantization_sm100_module
from .fused_moe import gen_fused_moe_sm100_module
from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -327,7 +327,8 @@ def gen_all_modules(

    jit_specs += [
        gen_cascade_module(),
-        gen_comm_module(),
+        gen_vllm_comm_module(),
+        gen_trtllm_comm_module(),
        gen_norm_module(),
        gen_page_module(),
        gen_quantization_module(),
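This hunk splits the old comm module into separate vLLM and TRT-LLM entries in the AOT module list. A minimal sketch of building just these two modules ahead of time, assuming only the JitSpec.build_and_load() API that appears in flashinfer/comm.py below:

    # Sketch: AOT-build the two comm modules introduced by this split.
    from flashinfer.comm import gen_trtllm_comm_module, gen_vllm_comm_module

    for spec in (gen_vllm_comm_module(), gen_trtllm_comm_module()):
        # build_and_load() JIT-compiles the CUDA sources listed in each spec.
        spec.build_and_load()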
flashinfer/comm.py (64 changes: 41 additions & 23 deletions)
@@ -264,12 +264,20 @@ def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudart = CudaRTLibrary()


-def gen_comm_module() -> JitSpec:
+def gen_vllm_comm_module() -> JitSpec:
    return gen_jit_spec(
-        "comm",
+        "vllm_comm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "comm_pybind.cu",
            jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
        ],
Comment on lines +267 to +273 (Contributor, severity: medium):

Consider renaming gen_vllm_comm_module to gen_comm_module_vllm for better consistency with gen_trtllm_comm_module.

Suggested change:

    def gen_comm_module_vllm() -> JitSpec:
        return gen_jit_spec(
            "vllm_comm",
            [
                jit_env.FLASHINFER_CSRC_DIR / "comm_pybind.cu",
                jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu",
            ],
        )
    )


+def gen_trtllm_comm_module() -> JitSpec:
+    return gen_jit_spec(
+        "trtllm_comm",
+        [
+            jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu",
+            jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu",
+            jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu",
@@ -279,8 +287,8 @@ def gen_comm_module():


@functools.cache
-def get_comm_module():
-    module = gen_comm_module().build_and_load()
+def get_vllm_comm_module():
+    module = gen_vllm_comm_module().build_and_load()

    # torch library for all
    @register_custom_op(
@@ -333,6 +341,21 @@ def all_reduce(
    ) -> None:
        module.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas)

+    return SimpleNamespace(
+        init_custom_ar=init_custom_ar,
+        dispose=dispose,
+        get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta,
+        register_buffer=register_buffer,
+        register_graph_buffers=register_graph_buffers,
+        meta_size=meta_size,
+        all_reduce=all_reduce,
+    )
Comment on lines +344 to +352 (Contributor, severity: medium):

Consider extracting the common arguments passed to SimpleNamespace into a dictionary or a separate function to reduce redundancy and improve readability.

Suggested change:

    common_args = {
        "init_custom_ar": init_custom_ar,
        "dispose": dispose,
        "get_graph_buffer_ipc_meta": get_graph_buffer_ipc_meta,
        "register_buffer": register_buffer,
        "register_graph_buffers": register_graph_buffers,
        "meta_size": meta_size,
        "all_reduce": all_reduce,
    }
    return SimpleNamespace(**common_args)



+@functools.cache
+def get_trtllm_comm_module():
+    module = gen_trtllm_comm_module().build_and_load()
+
+    @register_custom_op(
+        "flashinfer::trtllm_lamport_initialize", mutates_args=["buffer"]
+    )
@@ -577,13 +600,6 @@ def trtllm_moe_allreduce_fusion(
        )

    return SimpleNamespace(
-        init_custom_ar=init_custom_ar,
-        dispose=dispose,
-        get_graph_buffer_ipc_meta=get_graph_buffer_ipc_meta,
-        register_buffer=register_buffer,
-        register_graph_buffers=register_graph_buffers,
-        meta_size=meta_size,
-        all_reduce=all_reduce,
        trtllm_lamport_initialize=trtllm_lamport_initialize,
        trtllm_lamport_initialize_all=trtllm_lamport_initialize_all,
        trtllm_custom_all_reduce=trtllm_custom_all_reduce,
@@ -595,11 +611,13 @@
def init_custom_ar(
    ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
) -> int:
-    return get_comm_module().init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
+    return get_vllm_comm_module().init_custom_ar(
+        ipc_tensors, rank_data, rank, full_nvlink
+    )
Comment on lines +614 to +616 (Contributor, severity: medium):

Consider using named arguments for improved readability:

    return get_vllm_comm_module().init_custom_ar(
        ipc_tensors=ipc_tensors, rank_data=rank_data, rank=rank, full_nvlink=full_nvlink
    )



def dispose(fa: int) -> None:
-    get_comm_module().dispose(fa)
+    get_vllm_comm_module().dispose(fa)


def all_reduce(
@@ -621,27 +639,27 @@ def all_reduce(
        num_ctas: The number of CTAs to use for the all reduce.
            CTA upper bound: 36. Generally, we can saturate the bandwidth even
            with a small number of SMs.
    """
-    get_comm_module().all_reduce(
+    get_vllm_comm_module().all_reduce(
        fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas
    )
Comment on lines 639 to 644 (Contributor, severity: medium):

Consider using descriptive variable names for fa, inp, out, reg_buffer, reg_buffer_sz_bytes, and num_ctas to improve code clarity. For example, registered_allocator, input_tensor, output_tensor, registered_buffer, registered_buffer_size_bytes, and number_of_ctas.

    get_vllm_comm_module().all_reduce(
        registered_allocator, input_tensor, output_tensor, registered_buffer,
        registered_buffer_size_bytes, num_ctas
    )



def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
-    return get_comm_module().get_graph_buffer_ipc_meta(fa)
+    return get_vllm_comm_module().get_graph_buffer_ipc_meta(fa)


def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
-    return get_comm_module().register_buffer(fa, fake_ipc_ptrs)
+    return get_vllm_comm_module().register_buffer(fa, fake_ipc_ptrs)


def register_graph_buffers(
    fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
-    get_comm_module().register_graph_buffers(fa, handles, offsets)
+    get_vllm_comm_module().register_graph_buffers(fa, handles, offsets)


def meta_size() -> int:
-    return get_comm_module().meta_size()
+    return get_vllm_comm_module().meta_size()
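The vLLM-style wrappers above all dispatch to the cached vllm_comm module. A minimal sketch of the intended call order on a single rank; the IPC pointer lists, tensors, and buffer sizes are hypothetical placeholders (real allocations would come from create_shared_buffer below), not working defaults:

    import torch
    from flashinfer import comm

    # Placeholders: ipc_tensors, reg_buffer_ptrs, inp, out, reg_buffer, and
    # reg_buffer_sz_bytes stand in for real IPC-shared allocations; the
    # rank_data size is illustrative only.
    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    fa = comm.init_custom_ar(ipc_tensors, rank_data, rank=0, full_nvlink=True)
    comm.register_buffer(fa, reg_buffer_ptrs)
    comm.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas=16)
    comm.dispose(fa)  # release the custom all-reduce handle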


def create_shared_buffer(
@@ -913,7 +931,7 @@ def pad_up(x, y):


def trtllm_lamport_initialize(buffer_ptr: int, size: int, dtype: torch.dtype) -> None:
-    get_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype)
+    get_trtllm_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype)


def trtllm_lamport_initialize_all(
@@ -923,7 +941,7 @@
    size: int,
    dtype: torch.dtype,
) -> None:
-    get_comm_module().trtllm_lamport_initialize_all(
+    get_trtllm_comm_module().trtllm_lamport_initialize_all(
        buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype
    )
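The TRT-LLM wrappers now route to the trtllm_comm module instead. A minimal sketch of Lamport-initializing three device buffers through the wrapper above, assuming size is the element count and that plain device tensors stand in for the real IPC-shared allocations:

    import torch
    from flashinfer import comm

    # Placeholder buffers; real code would use IPC-shared workspace memory.
    bufs = [torch.empty(1 << 20, dtype=torch.float16, device="cuda") for _ in range(3)]
    comm.trtllm_lamport_initialize_all(
        bufs[0].data_ptr(), bufs[1].data_ptr(), bufs[2].data_ptr(),
        size=bufs[0].numel(), dtype=torch.float16,
    )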

@@ -952,7 +970,7 @@ def trtllm_custom_all_reduce(
    lamport_peer_comm_buffer_ptrs_1: Optional[torch.Tensor],
    lamport_peer_comm_buffer_ptrs_2: Optional[torch.Tensor],
) -> None:
-    get_comm_module().trtllm_custom_all_reduce(
+    get_trtllm_comm_module().trtllm_custom_all_reduce(
        inp,
        out,
        tp_size,
@@ -1001,7 +1019,7 @@ def trtllm_allreduce_fusion(
    scale_factor: Optional[float],
    layout_code: Optional[FP4QuantizationSFLayout],
) -> None:
-    get_comm_module().trtllm_allreduce_fusion(
+    get_trtllm_comm_module().trtllm_allreduce_fusion(
        allreduce_in=allreduce_in,
        world_size=world_size,
        world_rank=world_rank,
@@ -1048,7 +1066,7 @@ def trtllm_moe_allreduce_fusion(
    quant_out: Optional[torch.Tensor],
    scale_out: Optional[torch.Tensor],
) -> None:
-    get_comm_module().trtllm_moe_allreduce_fusion(
+    get_trtllm_comm_module().trtllm_moe_allreduce_fusion(
        world_size=world_size,
        world_rank=world_rank,
        token_num=token_num,
File renamed without changes.