Skip to content

Commit 8736030

Browse files
authored
[V1] Use FlashInfer by default on Blackwell GPUs (#19118)
1 parent aa49f14 commit 8736030

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

vllm/platforms/cuda.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,21 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
229229
logger.info_once("Using Triton backend on V1 engine.")
230230
return ("vllm.v1.attention.backends."
231231
"triton_attn.TritonAttentionBackend")
232+
if cls.is_device_capability(100):
233+
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
234+
try:
235+
import flashinfer # noqa: F401
236+
logger.info_once(
237+
"Using FlashInfer backend on V1 engine by default for "
238+
"Blackwell (SM 10.0) GPUs.")
239+
return ("vllm.v1.attention.backends."
240+
"flashinfer.FlashInferBackend")
241+
except ImportError:
242+
logger.info_once(
243+
"FlashInfer failed to import for V1 engine on "
244+
"Blackwell (SM 10.0) GPUs; it is recommended to "
245+
"install FlashInfer for better performance.")
246+
pass
232247
if cls.has_device_capability(80):
233248
logger.info_once("Using Flash Attention backend on V1 engine.")
234249
return ("vllm.v1.attention.backends."

vllm/platforms/interface.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,30 @@ def has_device_capability(
228228

229229
return current_capability.to_int() >= capability
230230

231+
@classmethod
232+
def is_device_capability(
233+
cls,
234+
capability: Union[tuple[int, int], int],
235+
device_id: int = 0,
236+
) -> bool:
237+
"""
238+
Test whether this platform has exactly the specified device capability.
239+
240+
The `capability` argument can either be:
241+
242+
- A tuple `(major, minor)`.
243+
- An integer `<major><minor>`. (See
244+
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
245+
"""
246+
current_capability = cls.get_device_capability(device_id=device_id)
247+
if current_capability is None:
248+
return False
249+
250+
if isinstance(capability, tuple):
251+
return current_capability == capability
252+
253+
return current_capability.to_int() == capability
254+
231255
@classmethod
232256
def get_device_name(cls, device_id: int = 0) -> str:
233257
"""Get the name of a device."""

0 commit comments

Comments
 (0)