
Commit 1d353b6

[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)

Authored by Pavani Majety
Signed-off-by: Pavani Majety <pmajety@nvidia.com>

1 parent 3496274
commit 1d353b6

File tree: 5 files changed, +32 / -65 lines

  benchmarks/kernels/benchmark_trtllm_decode_attention.py
  tests/kernels/attention/test_flashinfer.py
  tests/kernels/attention/test_flashinfer_trtllm_attention.py
  vllm/envs.py
  vllm/v1/attention/backends/flashinfer.py
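
The net effect of the commit, as a before/after sketch. The function names below are illustrative only; the expressions mirror the removed and added lines in the diffs that follow: the GQA group-size heuristic and the VLLM_FLASHINFER_FORCE_TENSOR_CORES override are replaced by a constant.

# Illustrative sketch only; names are hypothetical, expressions mirror the diff.
import os

def use_tensor_cores_before(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Old behavior: env-var override, otherwise a GQA group-size heuristic.
    force = bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0")))
    return force or (num_qo_heads // num_kv_heads) > 4

def use_tensor_cores_after(num_qo_heads: int, num_kv_heads: int) -> bool:
    # New behavior: tensor cores are always enabled for the decode wrapper.
    return True

# Example: a GQA model with 32 query heads and 8 KV heads (group size 4)
# previously fell back to the CUDA-core decode path; it now uses tensor cores.
os.environ.pop("VLLM_FLASHINFER_FORCE_TENSOR_CORES", None)
assert use_tensor_cores_before(32, 8) is False
assert use_tensor_cores_after(32, 8) is True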

benchmarks/kernels/benchmark_trtllm_decode_attention.py
Lines changed: 1 addition & 1 deletion

@@ -110,7 +110,7 @@ def benchmark_decode(
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
+        use_tensor_cores=True,
     )
     wrapper.plan(
         kv_indptr,
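
For context, a minimal end-to-end usage sketch of the FlashInfer decode wrapper as the benchmark now constructs it. The shapes, dtypes, and the exact plan()/run() argument order below are based on a general reading of the flashinfer API and may vary between versions; treat this as an approximation, not authoritative usage.

# Minimal sketch; shapes and plan()/run() signatures are approximations.
import torch
import flashinfer

num_seqs, num_qo_heads, num_kv_heads, head_dim, page_size = 4, 32, 8, 128, 16
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Tensor cores are now enabled unconditionally, regardless of the
# query-head / KV-head ratio.
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=True)

# One KV-cache page per sequence, purely for illustration.
kv_indptr = torch.arange(num_seqs + 1, dtype=torch.int32, device="cuda")
kv_indices = torch.arange(num_seqs, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((num_seqs,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size,
             q_data_type=torch.float16)

kv_cache = torch.randn(num_seqs, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
query = torch.randn(num_seqs, num_qo_heads, head_dim,
                    dtype=torch.float16, device="cuda")
output = wrapper.run(query, kv_cache)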

tests/kernels/attention/test_flashinfer.py
Lines changed: 2 additions & 4 deletions

@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
-                                           use_tensor_cores=(
-                                               (num_query_heads//num_kv_heads) > 4)
-                                           )
+                                           use_tensor_cores=True)
     wrapper.plan(
         kv_indptr,
         kv_indices,

@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
-    use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+    use_tensor_cores = True
     kv_cache_dtype = torch.float8_e4m3fn

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
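
The second hunk touches the fp8 KV-cache variant of the test. As background, a rough sketch of how a float8_e4m3fn paged KV cache can be produced from a higher-precision one; the layout and the per-tensor scaling scheme here are assumptions for illustration and are not taken from the test.

# Illustrative only: assumed NHD paged-KV layout (num_pages, 2, page_size,
# num_kv_heads, head_size) and a single per-tensor scale.
import torch

num_pages, page_size, num_kv_heads, head_size = 64, 16, 8, 128
kv_fp16 = torch.randn(num_pages, 2, page_size, num_kv_heads, head_size,
                      dtype=torch.float16)

fp8_max = 448.0  # largest finite value representable in float8_e4m3fn
k_scale = kv_fp16[:, 0].abs().amax().float() / fp8_max
v_scale = kv_fp16[:, 1].abs().amax().float() / fp8_max

kv_fp8 = torch.empty_like(kv_fp16, dtype=torch.float8_e4m3fn)
kv_fp8[:, 0] = (kv_fp16[:, 0].float() / k_scale).to(torch.float8_e4m3fn)
kv_fp8[:, 1] = (kv_fp16[:, 1].float() / v_scale).to(torch.float8_e4m3fn)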

tests/kernels/attention/test_flashinfer_trtllm_attention.py
Lines changed: 1 addition & 3 deletions

@@ -136,9 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline(

     # Baseline Decode
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer,
-        kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
+        workspace_buffer, kv_layout, use_tensor_cores=True)
     wrapper.plan(kv_indptr,
                  kv_indices,
                  kv_last_page_lens,

vllm/envs.py
Lines changed: 0 additions & 7 deletions

@@ -42,7 +42,6 @@
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
-    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""

@@ -465,11 +464,6 @@ def get_vllm_port() -> Optional[int]:
    lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

-    # If set, vllm will force flashinfer to use tensor cores;
-    # otherwise will use heuristic based on model architecture.
-    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
-    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
-
    # Pipeline stage partition strategy
    "VLLM_PP_LAYER_PARTITION":
    lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),

@@ -1221,7 +1215,6 @@ def compute_hash() -> str:
        "VLLM_USE_AITER_UNIFIED_ATTENTION",
        "VLLM_ATTENTION_BACKEND",
        "VLLM_USE_FLASHINFER_SAMPLER",
-        "VLLM_FLASHINFER_FORCE_TENSOR_CORES",
        "VLLM_DISABLED_KERNELS",
        "VLLM_USE_DEEP_GEMM",
        "VLLM_USE_TRTLLM_FP4_GEMM",

vllm/v1/attention/backends/flashinfer.py
Lines changed: 28 additions & 50 deletions

@@ -13,7 +13,6 @@
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache

-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)

@@ -228,8 +227,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             self.q_data_type = self.kv_cache_dtype
         else:
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
-                                 (self.num_qo_heads // self.num_kv_heads > 4))

         self._cascade_wrapper = None  # Wrapper for cascade attention

@@ -308,7 +305,11 @@ def _get_decode_wrapper(self,
                 paged_kv_indptr_buffer=paged_kv_indptr,
                 paged_kv_indices_buffer=paged_kv_indices,
                 paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                use_tensor_cores=self.use_tensor_cores)
+                # Tensor cores are enabled by default because the perf would be
+                # at least as good as CUDA cores for all attention ops on the
+                # latest GPUs.
+                use_tensor_cores=True,
+            )

         # save the decode wrapper
         if use_cudagraph:

@@ -984,52 +985,29 @@ def fast_plan_decode(
         self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                                non_blocking=True)

-        if self.use_tensor_cores:
-            qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_cpu,
-                    seq_lens_cpu,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in tensor core plan: {e}") from e
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_cpu,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    torch.empty(0, dtype=q_data_type),
-                    torch.empty(0, dtype=kv_data_type),
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}") from e
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
+        try:
+            # Make sure we pass exactly 15 arguments for tensor core version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                indptr_cpu,
+                seq_lens_cpu,
+                batch_size,  # total_num_rows
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim,
+                head_dim,
+                False,  # causal
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in tensor core plan: {e}") from e

         self._pos_encoding_mode = pos_encoding_mode
         self._window_left = window_left
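
fast_plan_decode keeps replanning cheap in the persistent-buffer / CUDA-graph path: host-side metadata is staged into pre-allocated buffers and a single cached plan call is made, now unconditionally the tensor-core variant. Below is a generic sketch of the staging pattern visible in the context lines above; the buffer names and sizes are illustrative, not the module's actual fields.

# Illustrative pattern only: pinned host staging plus an async copy into a
# persistent device buffer, as done around the plan() call above.
import torch

device_buf = torch.zeros(4096, dtype=torch.int32, device="cuda")  # persistent
host_vals = torch.arange(129, dtype=torch.int32).pin_memory()     # e.g. an indptr

# A non-blocking copy from pinned memory avoids a host/device sync inside
# the (re)planning path and can overlap with other GPU work.
device_buf[: host_vals.numel()].copy_(host_vals, non_blocking=True)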
