
Commit 1eef0f8

Update to tuned 2d kernel, add some mamba kernels (#26)
This PR adds the tunable 2D attention kernel (similar to vllm-project/vllm#20690), tuned using the micro-benchmarks already carried out for H100 and MI300.

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
1 parent 3422877 commit 1eef0f8

File tree

27 files changed: +3713, -217 lines
  • ibm-triton-lib/ibm_triton_lib
    • backend
    • kernels
      • dejavu_data/dejavu_0.7/triton_3.3.0
        • cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3
          • _selective_scan_update_kernel/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default
          • kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227
            • code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default
            • code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default
            • code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default
        • rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X
          • _selective_scan_update_kernel/autotune_config-90178d0ab8e71db9cd16710d562763dd010643f28cd21980d5064c3ab782ecaa/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default
          • kernel_unified_attention_2d
            • autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227
              • code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default
              • code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default
            • autotune_config-eff99677f7c0c1715ee99c9f1c8cf2a597630dd934ea82c3a3f4cdcd26d2e859/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default
    • utils
  • scripts
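
The dejavu_data entries above are keyed in layers: Triton version, CUDA or ROCm build, GPU model, then hashes identifying the autotune config, the kernel code version, the tuned features, and the kernel config space. A rough sketch of how such a lookup directory is composed, purely for illustration (build_cache_dir and the <hash> placeholders are assumptions, not part of the library):

# Illustrative sketch of the directory layout visible in the file tree above.
# build_cache_dir and the "<hash>" placeholders are assumptions for illustration.
from pathlib import Path


def build_cache_dir(root, platform, gpu, kernel,
                    autotune_cfg, code_version, tune_features, kernel_configs):
    return (Path(root) / platform / f"gpu_{gpu}" / kernel
            / f"autotune_config-{autotune_cfg}"
            / f"code_version-{code_version}"
            / f"tune_features-{tune_features}"
            / f"kernel_configs-{kernel_configs}"
            / "default")


print(build_cache_dir("dejavu_data/dejavu_0.7/triton_3.3.0", "cuda_12.4",
                      "NVIDIA_H100_80GB_HBM3", "kernel_unified_attention_2d",
                      "<hash>", "<hash>", "<hash>", "<hash>"))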


Dockerfile

Lines changed: 4 additions & 1 deletion
@@ -240,10 +240,13 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/uv \
     git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness && cd lm-evaluation-harness && uv pip install .
 
+RUN git clone --depth 1 https://github.com/IBM/fmwork.git
+
 ENV STORE_TEST_RESULT_PATH=/results
 
-# copy vllm benchmarks
+# copy vllm benchmarks and tests
 COPY vllm/benchmarks benchmarks
+COPY vllm/tests tests
 COPY ShareGPT_V3_unfiltered_cleaned_split.json ShareGPT_V3_unfiltered_cleaned_split.json
 
 # Copy thid-party kernels and insert into path

ibm-triton-lib/ibm_triton_lib/backend/platform.py

Lines changed: 44 additions & 19 deletions
@@ -27,8 +27,13 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 
+
 from vllm.platforms import Platform, PlatformEnum
-from vllm.platforms.cuda import CudaPlatform
+
+if not torch.version.hip:
+    from vllm.platforms.cuda import CudaPlatform
+else:
+    from vllm.platforms.rocm import RocmPlatform
 
 
 from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -46,21 +51,41 @@
 torch.backends.cuda.enable_cudnn_sdp(False)
 
 
-# CudaPlatform is a constant, not a class, but it dynamically decdes between Nvml and NonNVML class
-# so we should inherit from this
-class TritonPlatform(CudaPlatform):
-
-    @classmethod
-    def get_attn_backend_cls(
-        cls,
-        selected_backend,
-        head_size,
-        dtype,
-        kv_cache_dtype,
-        block_size,
-        use_v1,
-        use_mla,
-    ) -> str:
-        if not envs.VLLM_USE_V1:
-            raise RuntimeError("vllm-triton-backend plugin only supports vLLM V1")
-        return "ibm_triton_lib.backend.triton_attn.TritonAttentionBackend"
+if not torch.version.hip:
+    # CudaPlatform is a constant, not a class, but it dynamically decdes between Nvml and NonNVML class
+    # so we should inherit from this
+    class TritonPlatform(CudaPlatform):
+
+        @classmethod
+        def get_attn_backend_cls(
+            cls,
+            selected_backend,
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            use_v1,
+            use_mla,
+        ) -> str:
+            if not envs.VLLM_USE_V1:
+                raise RuntimeError("vllm-triton-backend plugin only supports vLLM V1")
+            return "ibm_triton_lib.backend.triton_attn.TritonAttentionBackend"
+
+else:
+
+    class TritonPlatform(RocmPlatform):
+
+        @classmethod
+        def get_attn_backend_cls(
+            cls,
+            selected_backend,
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            use_v1,
+            use_mla,
+        ) -> str:
+            if not envs.VLLM_USE_V1:
+                raise RuntimeError("vllm-triton-backend plugin only supports vLLM V1")
+            return "ibm_triton_lib.backend.triton_attn.TritonAttentionBackend"

ibm-triton-lib/ibm_triton_lib/backend/triton_attn.py

Lines changed: 68 additions & 86 deletions
@@ -40,12 +40,9 @@
     AttentionMetadata,
     AttentionType,
 )
-from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
-from vllm.attention.ops.paged_attn import PagedAttention
 from ibm_triton_lib.kernels import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
@@ -72,6 +69,8 @@ class TritonAttentionMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
+    avg_query_len: int
+    avg_seq_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
@@ -97,6 +96,8 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_avg_query_len: int
+        local_avg_seq_len: int
         local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -139,6 +140,9 @@ def build(
         block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
 
+        avg_seq_len = int(self.runner.seq_lens_np[:num_reqs].mean())
+        avg_query_len = int(self.runner.query_start_loc_np[num_reqs] / num_reqs)
+
         block_table.slot_mapping[:num_actual_tokens].copy_(
             block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True
         )
@@ -170,14 +174,18 @@ def build(
                 self.runner.device, non_blocking=True
             )
             local_max_query_len = seqlens_q_local_np.max()
+            local_avg_query_len = int(seqlens_q_local_np[num_reqs] / num_reqs)
             local_max_seq_len = virt_k_seqlens_np.max()
+            local_avg_seq_len = int(virt_k_seqlens_np[num_reqs] / num_reqs)
 
             local_attn_metadata = TritonAttentionMetadata.LocalAttentionMetadata(
                 local_query_start_loc=local_query_start_loc,
                 local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table_tensor,
                 local_max_query_len=local_max_query_len,
                 local_max_seq_len=local_max_seq_len,
+                local_avg_query_len=local_avg_query_len,
+                local_avg_seq_len=local_avg_seq_len,
                 local_scheduler_metadata=None,
             )
 
@@ -213,6 +221,8 @@ def build(
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            avg_query_len=avg_query_len,
+            avg_seq_len=avg_seq_len,
         )
         return attn_metadata
 
@@ -227,10 +237,22 @@ class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
 
+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes."
+            )
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN_VLLM_V1"
@@ -304,12 +326,7 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by TritonAttention. "
-                f"Supported head sizes are: {support_head_sizes}."
-            )
+        TritonAttentionBackend.validate_head_size(head_size)
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError(
@@ -331,7 +348,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: FlashAttentionMetadata,
+        attn_metadata: TritonAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -369,41 +386,23 @@ def forward(
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.
 
-        use_prefill_decode_attn = self.force_prefill_decode_attn
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        if use_prefill_decode_attn:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size
-            )
-        else:
-            key_cache, value_cache = kv_cache.unbind(0)
+        key_cache, value_cache = kv_cache.unbind(0)
 
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            if use_prefill_decode_attn:
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
@@ -433,56 +432,39 @@ def forward(
             max_seqlen_q = local_metadata.local_max_query_len
             max_seqlen_k = local_metadata.local_max_seq_len
             block_table = local_metadata.local_block_table
+            avg_seqlen_q = local_metadata.local_avg_query_len
+            avg_seqlen_k = local_metadata.local_avg_seq_len
         else:
             cu_seqlens_q = attn_metadata.query_start_loc
             seqused_k = attn_metadata.seq_lens
             max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
             block_table = attn_metadata.block_table
-
-        if use_prefill_decode_attn:
-            # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(
-                query=query[:num_actual_tokens],
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                output=output[:num_actual_tokens],
-                kv_cache_dtype=self.kv_cache_dtype,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_table=block_table,
-                query_start_loc=cu_seqlens_q,
-                seq_lens=seqused_k,
-                max_seq_len=max_seqlen_k,
-                max_query_len=max_seqlen_q,
-                k_scale=layer._k_scale,
-                v_scale=layer._v_scale,
-                alibi_slopes=self.alibi_slopes,
-                sliding_window=self.sliding_window[0],
-                sm_scale=self.scale,
-            )
-
-        else:
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-
-            unified_attention(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                q_descale=None,  # Not supported
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-            )
+            avg_seqlen_q = attn_metadata.avg_query_len
+            avg_seqlen_k = attn_metadata.avg_seq_len
+
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
+        unified_attention(
+            q=query[:num_actual_tokens],
+            k=key_cache,
+            v=value_cache,
+            out=output[:num_actual_tokens],
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=seqused_k,
+            max_seqlen_k=max_seqlen_k,
+            avg_seqlen_q=avg_seqlen_q,
+            avg_seqlen_k=avg_seqlen_k,
+            softmax_scale=self.scale,
+            causal=True,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            q_descale=None,  # Not supported
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
        )
 
         return output
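
The new avg_query_len and avg_seq_len fields give the tuned 2D kernel a hint about the average per-request workload in a batch. Below is a standalone sketch of the arithmetic the metadata builder above performs, with synthetic numbers; the array names mirror the diff, but in vLLM they come from the GPU model runner:

# Standalone sketch of the average-length arithmetic added in build() above.
# The arrays here are synthetic example data, not real runner state.
import numpy as np

num_reqs = 3
seq_lens_np = np.array([128, 256, 64], dtype=np.int32)        # KV length per request
query_start_loc_np = np.array([0, 4, 9, 12], dtype=np.int32)  # cumulative query offsets

avg_seq_len = int(seq_lens_np[:num_reqs].mean())               # (128 + 256 + 64) / 3 -> 149
avg_query_len = int(query_start_loc_np[num_reqs] / num_reqs)   # 12 queries / 3 requests -> 4

print(avg_seq_len, avg_query_len)  # 149 4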

ibm-triton-lib/ibm_triton_lib/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -67,3 +67,5 @@ def ConfigSpace(
 )
 
 from .triton_unified_attention import unified_attention
+
+from .mamba_ssm import selective_state_update
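
With this export in place, both kernels can be imported from the package root. A short usage sketch based only on what is visible in the diffs above; the listed keyword arguments of unified_attention are taken from the forward() call in triton_attn.py, while tensor shapes and dtypes are left out because they are not documented here:

# Usage sketch: the two kernels re-exported by ibm_triton_lib.kernels after
# this commit (requires the ibm-triton-lib package; running them needs a GPU).
from ibm_triton_lib.kernels import unified_attention, selective_state_update

# Keyword arguments of unified_attention as used in triton_attn.py above:
#   q, k, v, out, cu_seqlens_q, max_seqlen_q, seqused_k, max_seqlen_k,
#   avg_seqlen_q, avg_seqlen_k, softmax_scale, causal, alibi_slopes,
#   window_size, block_table, softcap, q_descale, k_descale, v_descale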
dejavu autotune cache for _selective_scan_update_kernel (new file; one of the dejavu_data/ cache entries listed in the file tree above)

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+{
+    "signature": "JITFunction(ibm_triton_lib.kernels.mamba_ssm:_selective_scan_update_kernel)",
+    "total_bench_time_s": 58.42541313171387,
+    "evaluated_configs": 75,
+    "keys": [
+        "dstate",
+        "BLOCK_SIZE_DSTATE",
+        "dim",
+        "nheads_ngroups_ratio"
+    ],
+    "cache": {
+        "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE_M: 8, num_warps: 2, num_ctas: 1, num_stages: 6, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None"
+    },
+    "timings": {
+        "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": [
+            0.003274054965004325
+        ]
+    },
+    "timings_data": {
+        "labels": [
+            "ms"
+        ],
+        "rep_t_ms": 100,
+        "warmup_t_ms": 25,
+        "cuda_graphs": true
+    }
+}
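
For reference, such a dejavu cache file maps a tuple of key values (here dstate, BLOCK_SIZE_DSTATE, dim, nheads_ngroups_ratio plus argument dtypes) to the best Triton config found during tuning, together with the measured timing. A small sketch of how the file can be inspected; the filename is a placeholder, since the real files live under the hashed dejavu_data/ directories shown in the file tree:

# Sketch: inspecting a dejavu autotune cache entry like the one shown above.
# "cache.json" is a placeholder filename, not the actual path in the repo.
import json

with open("cache.json") as f:
    data = json.load(f)

print(data["signature"])  # which Triton JITFunction was tuned
for key, config in data["cache"].items():
    timing_ms = data["timings"][key][0]
    print(f"{key} -> {config} ({timing_ms:.6f} ms)")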
