13 | 13 | from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache |
14 | 14 | from flashinfer.prefill import trtllm_batch_context_with_kv_cache |
15 | 15 |
16 | | -import vllm.envs as envs |
17 | 16 | from vllm import _custom_ops as ops |
18 | 17 | from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, |
19 | 18 | AttentionType) |
@@ -228,8 +227,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], |
228 | 227 | self.q_data_type = self.kv_cache_dtype |
229 | 228 | else: |
230 | 229 | self.kv_cache_dtype = self.kv_cache_spec.dtype |
231 | | - self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or |
232 | | - (self.num_qo_heads // self.num_kv_heads > 4)) |
233 | 230 |
234 | 231 | self._cascade_wrapper = None # Wrapper for cascade attention |
235 | 232 |
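For reference, the heuristic deleted above enabled tensor cores only when forced via the `VLLM_FLASHINFER_FORCE_TENSOR_CORES` env var or when the grouped-query ratio exceeded 4. A minimal sketch of that old decision, using Llama-3-8B's head counts purely as example values:

```python
# Sketch of the removed heuristic, for reference only.
num_qo_heads, num_kv_heads = 32, 8  # e.g. Llama-3-8B (example values)

# Old behavior: tensor cores only when forced by the env var or when the
# GQA group size (query heads per KV head) exceeded 4.
force = False  # stand-in for envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES
use_tensor_cores = force or (num_qo_heads // num_kv_heads > 4)  # False here

# New behavior: always True (see use_tensor_cores=True below).
```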
@@ -308,7 +305,11 @@ def _get_decode_wrapper(self, |
308 | 305 | paged_kv_indptr_buffer=paged_kv_indptr, |
309 | 306 | paged_kv_indices_buffer=paged_kv_indices, |
310 | 307 | paged_kv_last_page_len_buffer=paged_kv_last_page_len, |
311 | | - use_tensor_cores=self.use_tensor_cores) |
| 308 | +             # Tensor cores are enabled by default because the performance is
| 309 | +             # at least as good as CUDA cores for all attention ops on
| 310 | +             # recent GPUs.
| 311 | + use_tensor_cores=True, |
| 312 | + ) |
312 | 313 |
313 | 314 | # save the decode wrapper |
314 | 315 | if use_cudagraph: |
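The effect on wrapper construction looks like the following sketch. This is not the vLLM wiring itself; the workspace size is an illustrative assumption, while `BatchDecodeWithPagedKVCacheWrapper` and its `use_tensor_cores` keyword are FlashInfer's public API:

```python
import torch
import flashinfer

# Illustrative workspace; vLLM sizes and caches this buffer itself.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace,
    kv_layout="NHD",
    # Always on after this change; no env var or head-ratio check.
    use_tensor_cores=True,
)
```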
@@ -984,52 +985,29 @@ def fast_plan_decode( |
984 | 985 | self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, |
985 | 986 | non_blocking=True) |
986 | 987 |
987 | | - if self.use_tensor_cores: |
988 | | - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") |
989 | | - |
990 | | - try: |
991 | | - # Make sure we pass exactly 15 arguments for tensor core version |
992 | | - self._plan_info = self._cached_module.plan( |
993 | | - self._float_workspace_buffer, |
994 | | - self._int_workspace_buffer, |
995 | | - self._pin_memory_int_workspace_buffer, |
996 | | - qo_indptr_host, |
997 | | - indptr_cpu, |
998 | | - seq_lens_cpu, |
999 | | - batch_size, # total_num_rows |
1000 | | - batch_size, |
1001 | | - num_qo_heads, |
1002 | | - num_kv_heads, |
1003 | | - page_size, |
1004 | | - self.is_cuda_graph_enabled, |
1005 | | - head_dim, |
1006 | | - head_dim, |
1007 | | - False, # causal |
1008 | | - ) |
1009 | | - except Exception as e: |
1010 | | - raise RuntimeError(f"Error in tensor core plan: {e}") from e |
1011 | | - else: |
1012 | | - try: |
1013 | | - # Make sure we pass exactly 15 arguments for standard version |
1014 | | - self._plan_info = self._cached_module.plan( |
1015 | | - self._float_workspace_buffer, |
1016 | | - self._int_workspace_buffer, |
1017 | | - self._pin_memory_int_workspace_buffer, |
1018 | | - indptr_cpu, |
1019 | | - batch_size, |
1020 | | - num_qo_heads, |
1021 | | - num_kv_heads, |
1022 | | - page_size, |
1023 | | - self.is_cuda_graph_enabled, |
1024 | | - window_left, |
1025 | | - logits_soft_cap, |
1026 | | - head_dim, |
1027 | | - head_dim, |
1028 | | - torch.empty(0, dtype=q_data_type), |
1029 | | - torch.empty(0, dtype=kv_data_type), |
1030 | | - ) |
1031 | | - except Exception as e: |
1032 | | - raise RuntimeError(f"Error in standard plan: {e}") from e |
| 988 | + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") |
| 989 | + |
| 990 | + try: |
| 991 | + # Make sure we pass exactly 15 arguments for tensor core version |
| 992 | + self._plan_info = self._cached_module.plan( |
| 993 | + self._float_workspace_buffer, |
| 994 | + self._int_workspace_buffer, |
| 995 | + self._pin_memory_int_workspace_buffer, |
| 996 | + qo_indptr_host, |
| 997 | + indptr_cpu, |
| 998 | + seq_lens_cpu, |
| 999 | + batch_size, # total_num_rows |
| 1000 | + batch_size, |
| 1001 | + num_qo_heads, |
| 1002 | + num_kv_heads, |
| 1003 | + page_size, |
| 1004 | + self.is_cuda_graph_enabled, |
| 1005 | + head_dim, |
| 1006 | + head_dim, |
| 1007 | + False, # causal |
| 1008 | + ) |
| 1009 | + except Exception as e: |
| 1010 | + raise RuntimeError(f"Error in tensor core plan: {e}") from e |
1033 | 1011 |
1034 | 1012 | self._pos_encoding_mode = pos_encoding_mode |
1035 | 1013 | self._window_left = window_left |
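With `use_tensor_cores` always true, only the tensor-core planning branch survives in `fast_plan_decode`; the CUDA-core plan call, with its different argument list, is gone. For callers of the public API nothing changes. Continuing the hypothetical sketch above with made-up shapes (batch of 2, one KV page per request):

```python
# Two requests, one page each; all values are illustrative.
kv_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([5, 9], dtype=torch.int32, device="cuda")

decode_wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads=32,
    num_kv_heads=8,
    head_dim=128,
    page_size=16,
    q_data_type=torch.float16,
)
# With use_tensor_cores=True, FlashInfer plans via its tensor-core
# (prefill) kernels, matching the single plan() path kept above.
```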