
Commit b130f85

fsx950223 and charlifu committed
Character ai (#587)
* integrate aiter
* add env variable
* rename function
* optimize kernels with small query lens
* change condition
* add rocm aiter backend
* new fa impl
* update api
* optimize performance
* remove try catch
* clean code
* remove type cast
* use on_gfx9 instead of on_mi250_mi300
* add fp8 support
* revert layernorm
* enable aiter pa
* fix bug
* fix bug
* fix upstream issue
* change condition
* support head size 256
* enable fp8 aiter pa in vllm v1
* fix workspace buffer
* fix fa crash issue
* add namespace aiter

---------

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
1 parent e28533a

4 files changed: +64, -50 lines

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 1 addition & 3 deletions
@@ -913,8 +913,7 @@ def forward(
             )
             max_logits = torch.empty_like(exp_sums)

-            query_start_loc = None
-            ops.paged_attention_rocm(
+            torch.ops.aiter.paged_attention_rocm(
                 output[num_prefill_tokens:],
                 exp_sums,
                 max_logits,
@@ -930,7 +929,6 @@ def forward(
                 decode_meta.seq_lens_tensor
                 if self.attn_type != AttentionType.ENCODER_DECODER else
                 decode_meta.encoder_seq_lens_tensor,
-                query_start_loc,
                 block_size,
                 max_seq_len,
                 self.alibi_slopes,

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 3 additions & 5 deletions
@@ -9,7 +9,6 @@

 import torch

-from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
@@ -305,7 +304,7 @@ def chunked_prefill_paged_decode(
         )
         max_logits = torch.empty_like(exp_sums)

-        ops.paged_attention_rocm(
+        torch.ops.aiter.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
@@ -316,10 +315,9 @@ def chunked_prefill_paged_decode(
            num_kv_heads,
            scale=sm_scale,
            block_tables=block_table,
-           seq_lens=seq_lens,
-           query_start_loc=query_start_loc,
+           context_lens=seq_lens,
            block_size=block_size,
-           max_seq_len=max_seq_len,
+           max_context_len=max_seq_len,
            alibi_slopes=alibi_slopes,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,

vllm/platforms/rocm.py

Lines changed: 2 additions & 2 deletions
@@ -138,9 +138,9 @@ def use_rocm_custom_paged_attention(
     return ((not envs.VLLM_USE_V1 or sliding_window == 0
              or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
-            and (head_size == 64 or head_size == 128)
+            and (head_size in [64, 128, 256])
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16)
+            and (gqa_ratio >= 1 and gqa_ratio <= 32)
             and max_seq_len <= 128 * 1024
             and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
             and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
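For reference, a minimal standalone re-statement of the relaxed gate (a sketch, not vLLM's API; the real use_rocm_custom_paged_attention also checks sliding_window, the V1 flag, and the VLLM_ROCM_* environment variables shown above):

import torch

def custom_paged_attn_ok(qtype, head_size, block_size, gqa_ratio, max_seq_len):
    # Mirrors only the dtype/shape conditions from the diff above.
    return (qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128, 256)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 32
            and max_seq_len <= 128 * 1024)

# head_size 256 and GQA ratios 17-32 now qualify; both were rejected before this change.
print(custom_paged_attn_ok(torch.bfloat16, 256, 16, 32, 32 * 1024))  # True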

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 58 additions & 40 deletions
@@ -5,10 +5,8 @@

 import torch

-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType,
-                                              is_quantized_kv_cache)
+                                              AttentionMetadata, AttentionType)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
@@ -17,6 +15,8 @@
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

+_PARTITION_SIZE_ROCM = 256
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -38,6 +38,9 @@ def _vllm_layout_trans_kernel(
     b_seq_lens_loc,
     block_table,
     block_table_stride_0,
+    k_scale,
+    v_scale,
+    output_dtype: tl.constexpr,
     E_DIM: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -59,16 +62,27 @@
                   tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len

     kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
-                     block_idx)
+                     block_idx).to(tl.int64)

     kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
         0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
     k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                      mask=block_mask,
                      other=0.0)
+    if k_vals.dtype.is_fp8():
+        k_vals = (k_vals.to(tl.float32) *
+                  tl.load(k_scale)).to(output_dtype)
+    else:
+        k_vals = k_vals.to(output_dtype)
+
     v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                      mask=block_mask,
                      other=0.0)
+    if v_vals.dtype.is_fp8():
+        v_vals = (v_vals.to(tl.float32) *
+                  tl.load(v_scale)).to(output_dtype)
+    else:
+        v_vals = v_vals.to(output_dtype)

     kv_values_off = batch_token_start * E_DIM + \
         block_idx * BLOCK_SIZE * E_DIM + \
@@ -78,21 +92,28 @@
     tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)


 def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                      k_buffer, v_buffer, max_seq_len, total_tokens):
+                      k_buffer, v_buffer, max_seq_len, total_tokens,
+                      k_scale, v_scale, output_dtype):
     H_KV = v_buffer.shape[2]
     D = v_buffer.shape[3]
     BLOCK_SIZE = v_buffer.shape[1]
-    dtype = k_buffer.dtype
     k_values = torch.empty((total_tokens, H_KV, D),
-                           dtype=dtype,
+                           dtype=output_dtype,
                            device="cuda")
     v_values = torch.empty((total_tokens, H_KV, D),
-                           dtype=dtype,
+                           dtype=output_dtype,
                            device="cuda")

     grid = (block_table.shape[0],
             (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

+    if output_dtype == torch.float16:
+        output_dtype = tl.float16
+    elif output_dtype == torch.bfloat16:
+        output_dtype = tl.bfloat16
+    else:
+        raise ValueError(f"Unsupported output dtype: {output_dtype}")
+
     _vllm_layout_trans_kernel[grid](k_buffer,
                                     v_buffer,
                                     k_values,
@@ -101,6 +122,9 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
                                     b_seq_lens_loc,
                                     block_table,
                                     block_table.stride(0),
+                                    k_scale,
+                                    v_scale,
+                                    output_dtype=output_dtype,
                                     E_DIM=H_KV * D,
                                     BLOCK_SIZE=BLOCK_SIZE)
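The kernel changes above dequantize fp8 KV-cache blocks in flight: values are upcast to fp32, multiplied by the per-tensor k/v scale, and cast to the query dtype before being written to the contiguous k/v buffers. A torch-level equivalent of that arithmetic, with made-up shapes and scale, purely for illustration (assumes a PyTorch build that ships the fnuz fp8 dtypes):

import torch

q_dtype = torch.bfloat16
k_scale = torch.tensor(0.0421)  # per-tensor scale, as read via tl.load(k_scale) in the kernel
# One BLOCK_SIZE x H_KV x D cache block stored in fp8 (e4m3fnuz on ROCm).
k_block_fp8 = torch.randn(16, 8, 128).to(torch.float8_e4m3fnuz)

# fp8 -> fp32, apply scale, then cast to the dtype the flash-attention kernel expects.
k_dequant = (k_block_fp8.to(torch.float32) * k_scale).to(q_dtype)
print(k_dequant.dtype, tuple(k_dequant.shape))  # torch.bfloat16 (16, 8, 128)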

@@ -120,9 +144,12 @@ def flash_attn_varlen_func_impl(
     window_size: Optional[list[int]],  # -1 means infinite context window
     alibi_slopes: Optional[list[float]],
     block_table: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
 ) -> torch.Tensor:
     k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
-                             k_cache, v_cache, max_seqlen_k, total_tokens)
+                             k_cache, v_cache, max_seqlen_k, total_tokens,
+                             k_scale, v_scale, q.dtype)
     output = aiter.flash_attn_varlen_func(
         q=q,
         k=k,
@@ -154,6 +181,8 @@ def flash_attn_varlen_func_fake(
     window_size: Optional[list[int]],  # -1 means infinite context window
     alibi_slopes: Optional[list[float]],
     block_table: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
 ) -> torch.Tensor:
     return torch.empty(q.shape[0],
                        q.shape[1],
@@ -184,7 +213,6 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
@@ -281,6 +309,18 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             prefix_kv_lens = None
             suffix_kv_lens = None

+        nbyes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
+        max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
+                              1) // _PARTITION_SIZE_ROCM
+
+        workspace_buffer = torch.empty(
+            (num_reqs * self.num_heads_q * max_num_partitions * self.headdim) *
+            nbyes_per_qo_elem + 2 *
+            (num_reqs * self.num_heads_q * max_num_partitions) * 4,
+            dtype=torch.uint8,
+            device=self.runner.device,
+        )
+
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
@@ -292,6 +332,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
+            workspace_buffer=workspace_buffer,
             common_prefix_len=common_prefix_len,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
@@ -315,7 +356,7 @@ class AiterFlashAttentionBackend(AttentionBackend):

     @staticmethod
     def get_supported_head_sizes() -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [64, 128, 256]

     @staticmethod
     def get_name() -> str:
@@ -364,6 +405,7 @@ class AiterFlashAttentionMetadata:
     total_tokens: int
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
+    workspace_buffer: torch.Tensor

     # For cascade attention.
     use_cascade: bool
@@ -442,10 +484,6 @@ def __init__(
                 "are not implemented for "
                 "FlashAttentionImpl")
         self.use_irope = use_irope
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "AiterFlashAttention does not support fp8 kv-cache on this "
-                "device.")

     def forward(
         self,
@@ -516,12 +554,6 @@ def forward(
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
             value_cache = value_cache.view(torch.float8_e4m3fnuz)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))

         # Compute attention and update output up to `num_actual_tokens`.
         use_local_attn = \
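With the query-quantization block removed, the fp8 path only reinterprets the KV cache: .view(torch.float8_e4m3fnuz) changes the element type of the existing storage without copying, and the query stays in fp16/bf16. A small sketch with made-up shapes (assumes a PyTorch build that ships float8_e4m3fnuz):

import torch

raw_cache = torch.zeros(4, 16, 8, 128, dtype=torch.uint8)  # cache storage as raw bytes
fp8_cache = raw_cache.view(torch.float8_e4m3fnuz)          # same memory, fp8 element type

assert fp8_cache.data_ptr() == raw_cache.data_ptr()        # no copy was made
assert fp8_cache.dtype == torch.float8_e4m3fnuz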
@@ -559,28 +591,14 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                              local_metadata.local_cu_seq_lens),
+                cu_seqlens_k=cu_seq_lens,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
             )

-            _, num_heads, head_size = query.shape
-            _PARTITION_SIZE_ROCM = 256
-            num_seqs = seqused_k.shape[0]
-            nbyes_per_qo_elem = torch.finfo(output.dtype).bits // 8
-            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
-                                  1) // _PARTITION_SIZE_ROCM
-
-            workspace_buffer = torch.empty(
-                (num_seqs * num_heads * max_num_partitions * head_size) *
-                nbyes_per_qo_elem + 2 *
-                (num_seqs * num_heads * max_num_partitions) * 4,
-                dtype=torch.uint8,
-                device=output.device,
-            )
-
-            aiter.paged_attention_v1(
+            torch.ops.aiter.paged_attention_v1(
                 output[:num_actual_tokens],
-                workspace_buffer,
+                attn_metadata.workspace_buffer,
                 query[:num_actual_tokens],
                 key_cache,
                 value_cache,
