Merged
13 changes: 8 additions & 5 deletions vllm/compilation/collective_fusion.py
```diff
@@ -20,10 +20,12 @@
 from .vllm_inductor_pass import VllmInductorPass
 
 if find_spec("flashinfer"):
-    import flashinfer.comm as flashinfer_comm
-
-    flashinfer_comm = (flashinfer_comm if hasattr(
-        flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    try:
+        import flashinfer.comm as flashinfer_comm
+        flashinfer_comm = (flashinfer_comm if hasattr(
+            flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    except ImportError:
+        flashinfer_comm = None
 else:
     flashinfer_comm = None
 from vllm.platforms import current_platform
```
Comment on lines +23 to +27 (Contributor, severity: medium)

To avoid shadowing the outer-scope flashinfer_comm variable, it would be clearer to import the module as comm and assign it to flashinfer_comm only if the trtllm_allreduce_fusion attribute exists.

Suggested change:

```diff
-    try:
-        import flashinfer.comm as flashinfer_comm
-        flashinfer_comm = (flashinfer_comm if hasattr(
-            flashinfer_comm, "trtllm_allreduce_fusion") else None)
-    except ImportError:
+    try:
+        import flashinfer.comm as comm
+        if hasattr(comm, "trtllm_allreduce_fusion"):
+            flashinfer_comm = comm
+    except ImportError:
+        flashinfer_comm = None
```
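As a self-contained sketch of this guarded-import idiom (a standalone illustration, not the vLLM source; initializing the name to None up front keeps it bound on every path, including when the attribute check fails):

```python
from importlib.util import find_spec

# Treat the optional dependency as unavailable until every check passes.
flashinfer_comm = None

if find_spec("flashinfer"):
    try:
        # Bind the module to a private name first, so the public name is
        # never left pointing at a module that lacks the fusion kernel.
        import flashinfer.comm as comm
        if hasattr(comm, "trtllm_allreduce_fusion"):
            flashinfer_comm = comm
    except ImportError:
        # "flashinfer" is present but its comm submodule failed to import.
        pass
```

Callers can then gate on a single check, flashinfer_comm is not None, which is exactly what the constructor in the second hunk does.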

```diff
@@ -411,7 +413,8 @@ def __init__(self, config: VllmConfig, max_token_num: int):
         use_fp32_lamport = self.model_dtype == torch.float32
         if flashinfer_comm is None:
             logger.warning(
-                "Flashinfer is not installed, skipping allreduce fusion pass")
+                "Flashinfer is not installed or comm module not found, "
+                "skipping allreduce fusion pass")
             return
         # Check if the world size is supported
         if self.tp_size not in _FI_MAX_SIZES:
```
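For context, the early-return guard that consumes this flag looks roughly like the condensed sketch below; the class name, constructor signature, second warning message, and _FI_MAX_SIZES values are illustrative stand-ins, not the actual vLLM definitions:

```python
import logging

import torch

logger = logging.getLogger(__name__)

flashinfer_comm = None                        # set by the guarded import above
_FI_MAX_SIZES = {2: 8192, 4: 8192, 8: 8192}   # illustrative world-size table


class AllReduceFusionPass:
    def __init__(self, model_dtype: torch.dtype, tp_size: int) -> None:
        self.model_dtype = model_dtype
        self.tp_size = tp_size
        self.enabled = False
        self.use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            # Covers both "flashinfer missing" and "flashinfer.comm missing".
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass")
            return
        # Check if the world size is supported
        if self.tp_size not in _FI_MAX_SIZES:
            logger.warning(
                "Allreduce fusion disabled for unsupported world size %d",
                self.tp_size)
            return
        self.enabled = True
```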