From 75a43a85b8e38b5361c11a38c5cb3e8cfbfeccbc Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 15 May 2025 21:00:08 +0000 Subject: [PATCH] Revert to the previous FA triton kernel Signed-off-by: Gregory Shtrasberg --- vllm/attention/backends/rocm_flash_attn.py | 5 +- vllm/attention/ops/triton_flash_attention.py | 1766 +++++++----------- vllm/platforms/rocm.py | 6 + 3 files changed, 694 insertions(+), 1083 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8076c4791d3c..83c83e684a44 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -770,8 +770,9 @@ def forward( and layer._v_scale and layer._prob_scale and self.kv_cache_dtype == "fp8") full_scales = ( - layer._q_scale, layer._k_scale, layer._v_scale, - layer._prob_scale) if use_fp8_scales else None + layer._q_scale.item(), layer._k_scale.item(), + layer._v_scale.item(), + layer._prob_scale.item()) if use_fp8_scales else None self.triton_attn_func( query, key, diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 8940d0b66225..62cfb813d5f9 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -1,236 +1,33 @@ +#!/usr/bin/env python # SPDX-License-Identifier: Apache-2.0 """ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm -See https://tridao.me/publications/flash2/flash2.pdf +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team -Credits: -AMD Triton kernels team -OpenAI kernel team - -Currently only the forward kernel is supported, and contains these features: +Features supported: 1) Fwd with causal masking -2) Arbitrary Q and KV sequence lengths -3) Arbitrary head sizes -4) Multi and grouped query attention -5) Variable sequence lengths -6) ALiBi and matrix bias -7) FP8 support +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. -""" +Not currently supported: -from typing import Optional +1) Non power of two head dims + +""" import torch -from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.platforms.rocm import on_gfx1x from vllm.triton_utils import tl, triton -SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] - -default_eight_bit_dtype_triton = tl.float8e4b8 -default_eight_bit_dtype_torch = current_platform.fp8_dtype() -default_float8_info = torch.finfo(default_eight_bit_dtype_torch) - -FP8_MIN = triton.language.constexpr(default_float8_info.min) - -# According to https://github.com/vllm-project/vllm/blob/main -# /csrc/quantization/utils.cuh#L31, -# need to make the max for the uz datatype be 224.0 for accuracy reasons. 
-FP8_MAX = triton.language.constexpr( - default_float8_info.max if default_eight_bit_dtype_torch != - torch.float8_e4m3fnuz else 224.0) - - -class MetaData: - cu_seqlens_q = None - cu_seqlens_k = None - max_seqlens_q = 0 - max_seqlens_k = 0 - bias = None - alibi_slopes = None - causal = False - num_contexts = 0 - varlen = False - eight_bit = False - layout = None - return_encoded_softmax = False - eight_bit_dtype_triton = default_eight_bit_dtype_triton - eight_bit_dtype_torch = default_eight_bit_dtype_torch - output_dtype = None - - # Note about layouts: - # - # thd - [num_tokens, num_heads, head_size] - # bshd - [batch_size, seq_len, num_heads, head_size] - # bhsd - [batch_size, num_heads, seq_len, head_size] - # - # This is for each tensor, all tensors must have same layout. - # Q can have num_heads and seq_len differ from from K and V, - # however K and V must agree on this. - # - # Notes about varlen and bias: - # Only one or the other is implemented, meaning can't combine - # both varlen and bias right now. - # - # Note about quantization: - # Only 8-bit quantization supported (for now) and specifically fp8. - # Scales must be tensors. - # o_scale: This is 'output scaling', but comes from parameter called - # 'input_scale', this is applied to the output from the kernel. - # o_scale should be None if none of the other quantization parameters - # are used. - # - # NOTE: Object is in a tentatively good state after initialized, however, - # to verify, call check_args(q,k,v,o) where o is the output tensor. - def __init__( - self, - sm_scale=1.0, - layout=None, # layout can be 'bshd', 'bhsd', or 'thd' - output_dtype=None, - max_seqlens_q=0, - max_seqlens_k=0, - # varlen params - cu_seqlens_q=None, # only 'thd' layout supported for varlen - cu_seqlens_k=None, - # quant params - q_descale=None, - k_descale=None, - v_descale=None, - p_scale=None, - o_scale=None, - # bias params - bias=None, # varlen not implemented for bias - seqlen_q=None, - seqlen_k=None, - # alibi params - alibi_slopes=None, - alibi_batch=None, - alibi_nheads=None, - # causal - causal=None, - ): - self.sm_scale = sm_scale - self.output_dtype = output_dtype - self.max_seqlens_q = max_seqlens_q - self.max_seqlens_k = max_seqlens_k - self.layout = layout - if cu_seqlens_q is not None or cu_seqlens_k is not None: - assert cu_seqlens_q is not None and cu_seqlens_k is not None - assert layout is None or layout not in [ - 'bshd', 'bhsd' - ], "Varlen only implemented for thd layout" - self.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale] - if any(x is not None for x in quant_params): - p_descale = 1.0 / p_scale if p_scale is not None else None - self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale, - p_descale, o_scale) - if bias is not None: - self.need_bias(bias, seqlen_q, seqlen_k) - if alibi_slopes is not None: - self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads) - if causal is not None and causal: - self.need_causal() - - def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): - self.varlen = True - self.layout = 'thd' - self.cu_seqlens_q = cu_seqlens_q - self.cu_seqlens_k = cu_seqlens_k - # Without "varlen", there should still be one sequence. 
- assert len(cu_seqlens_q) >= 2 - assert len(cu_seqlens_q) == len(cu_seqlens_k) - self.num_contexts = len(cu_seqlens_q) - 1 - for i in range(0, self.num_contexts): - self.max_seqlens_q = max( - cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), - self.max_seqlens_q) - self.max_seqlens_k = max( - cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), - self.max_seqlens_k) - - def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale, - p_descale, o_scale): - self.eight_bit = True - self.q_descale = q_descale - self.k_descale = k_descale - self.v_descale = v_descale - self.p_scale = p_scale - self.p_descale = p_descale - self.o_scale = o_scale - self.use_p_scale = (p_scale is not None) and ( - p_descale is not None) and (v_descale is not None) - self.eight_bit_kv = ((q_descale is None) and (k_descale is not None) - and (v_descale is not None)) - self.eight_bit_dtype_torch = default_eight_bit_dtype_torch - - def need_bias(self, bias, seqlen_q, seqlen_k): - assert bias is not None - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.shape[0] == 1 - assert bias.shape[2:] == (seqlen_q, seqlen_k) - self.bias = bias - - def need_alibi(self, alibi_slopes, batch, nheads): - assert alibi_slopes.is_cuda - assert alibi_slopes.dim() == 2 - assert alibi_slopes.shape[0] == batch - assert alibi_slopes.shape[1] == nheads - self.alibi_slopes = alibi_slopes - - def need_causal(self): - self.causal = True - - def check_args(self, q, k, v, o): - assert q.dim() == k.dim() and q.dim() == v.dim() - - batch, nheads_q, nheads_k, head_size = get_shape_from_layout( - q, k, self) - if self.varlen: - assert q.dim() == 3 - assert self.cu_seqlens_q is not None - assert self.cu_seqlens_k is not None - assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) - # TODO: Remove once bias is supported with varlen - assert self.bias is None - assert not self.return_encoded_softmax - else: - assert q.dim() == 4 - assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 - assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - if self.eight_bit: - if self.eight_bit_kv: - assert (v.dtype == k.dtype - and k.dtype == self.eight_bit_dtype_torch) - assert q.dtype != k.dtype - assert (self.v_descale is not None) and (self.k_descale - is not None) - else: - assert (q.dtype == k.dtype and q.dtype == v.dtype - and q.dtype == self.eight_bit_dtype_torch) - assert (self.q_descale - is not None) and (self.k_descale - is not None) and (self.v_descale - is not None) - if self.use_p_scale: - assert (self.p_scale is not None) and (self.p_descale - is not None) - else: - assert (q.dtype == k.dtype) and (q.dtype == v.dtype) - assert head_size <= 256 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - assert self.layout is not None - assert self.layout == 'thd' or not self.varlen +torch_dtype: tl.constexpr = torch.float16 @triton.jit @@ -243,85 +40,40 @@ def max_fn(x, y): return tl.math.max(x, y) -# Convenience function to load with optional boundary checks. -# "First" is the major dim, "second" is the minor dim. 
@triton.jit -def masked_load(ptrs, offset_first, offset_second, boundary_first, - boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) - else: - tensor = tl.load(ptrs) - return tensor +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] @triton.jit -def compute_alibi_block(alibi_slope, - seqlen_q, - seqlen_k, - offs_m, - offs_n, - transpose=False): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to - # the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is - # masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that - # spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, - # offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = - # [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]] - # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q - - offs_n[None, :]) - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) -def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, - device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, - device="cuda").unsqueeze(0) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze( - -1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep @triton.jit -def quant_fp8(x, scale): - x *= scale - x = tl.clamp(x, FP8_MIN, FP8_MAX) - return x +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor @triton.jit @@ -330,68 +82,61 @@ def _attn_fwd_inner( l_i, m_i, q, - k_ptrs, - v_ptrs, - bias_ptrs, - stride_kn, - stride_vk, - stride_bn, + K_block_ptr, + V_block_ptr, start_m, actual_seqlen_k, - actual_seqlen_q, + dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, + encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, - alibi_slope, - q_descale, - k_descale, - v_descale, - p_scale, + bias_ptr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, - SHOULD_PRE_LOAD_V: tl.constexpr, - SHOULD_MASK_STEPS: tl.constexpr, - SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_PADDED_HEAD: tl.constexpr, - IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, - QK_SCALE: tl.constexpr, - IS_EIGHT_BIT_GEMM: tl.constexpr, - USE_P_SCALE: tl.constexpr, - IS_EIGHT_BIT_KV: tl.constexpr, - QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, ): - # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k_offs_n = start_n + tl.arange(0, - BLOCK_N) if SHOULD_MASK_STEPS else None - k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, - actual_seqlen_k) - if SHOULD_PRE_LOAD_V: - # We can use the same offsets as k, just with dims transposed. 
- v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, - IS_ACTUAL_BLOCK_DMODEL) + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if SHOULD_MASK_STEPS: # noqa: SIM102 + if MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not - # is_modulo_mn. last step might get wasted but that is okay. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], @@ -404,107 +149,112 @@ def _attn_fwd_inner( causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - if IS_EIGHT_BIT_GEMM: - qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * - QK_SCALE) - else: - if IS_EIGHT_BIT_KV: - k = (k * k_descale).to(q.type.element_ty) - qk += (tl.dot(q, k) * QK_SCALE) - - if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange( - 0, BLOCK_N) if SHOULD_MASK_STEPS else None - bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, - actual_seqlen_k) - # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an - # additional scale factor of log2(e) which we must also multiply - # the bias with. - qk += (bias * 1.44269504089) - - if alibi_slope is not None: - # Compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, - actual_seqlen_k, - global_m_positions, - global_n_positions) - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) - - # softmax + qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. 
+ qk += bias * 1.44269504089 m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) - if SHOULD_RETURN_ENCODED_SOFTMAX: - tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - if not SHOULD_PRE_LOAD_V: - v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, - IS_ACTUAL_BLOCK_DMODEL) + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - if IS_EIGHT_BIT_GEMM: - if USE_P_SCALE: - p = quant_fp8(p, p_scale).to(QUANT_DTYPE) - acc += tl.dot(p, v) - else: - # v is in eight_bit but p is not, we want the gemm in p's type - acc += tl.dot(p, v.to(p.type.element_ty)) - else: - if IS_EIGHT_BIT_KV: - v = (v * v_descale).to(p.type.element_ty) - acc += tl.dot(p.to(v.type.element_ty), v) - - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if bias_ptrs is not None: - bias_ptrs += BLOCK_N * stride_bn - if SHOULD_RETURN_ENCODED_SOFTMAX: - encoded_sm_ptrs += BLOCK_N + if USE_FP8: + p *= p_descale + + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) return acc, l_i, m_i def get_cdna_autotune_configs(): return [ + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), triton.Config( { 'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=4), triton.Config( { - 'BLOCK_M': 128, - 'BLOCK_N': 64, + 'BLOCK_M': 256, + 'BLOCK_N': 128, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4), + num_warps=8), triton.Config( { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 3, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'waves_per_eu': 1, + 'PRE_LOAD_V': False }, num_stages=1, num_warps=4), @@ -512,168 +262,141 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 1, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'waves_per_eu': 3, + 'PRE_LOAD_V': True }, num_stages=1, num_warps=4), triton.Config( { 'BLOCK_M': 128, - 'BLOCK_N': 32, - 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': False }, num_stages=1, num_warps=4), - ], [ - 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', - 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' - ] - - -def get_rdna_autotune_configs(): - 
return [ triton.Config( { - 'BLOCK_M': 32, - 'BLOCK_N': 32, + 'BLOCK_M': 64, + 'BLOCK_N': 64, 'waves_per_eu': 4, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=2), + num_warps=8), triton.Config( { 'BLOCK_M': 32, 'BLOCK_N': 32, - 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=2), + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_rdna_autotune_configs(): + return [ triton.Config( { 'BLOCK_M': 32, - 'BLOCK_N': 16, + 'BLOCK_N': 32, 'waves_per_eu': 4, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), triton.Config( { 'BLOCK_M': 32, - 'BLOCK_N': 16, + 'BLOCK_N': 32, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), triton.Config( { - 'BLOCK_M': 16, + 'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), triton.Config( { - 'BLOCK_M': 16, + 'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config( - { - 'BLOCK_M': 16, - 'BLOCK_N': 16, - 'waves_per_eu': 1, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), - ], [ - 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', - 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' - ] - - -def get_general_autotune_configs(): - return [ - triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 128, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=4), - triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=4), - triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 32, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=4), - ], [ - 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', - 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' - ] - - -def has_cdna_target(): - ROCM_CDNA_TARGETS = ["gfx942", "gfx90a", "gfx908"] - return triton.runtime.driver.active.get_current_target( - ).arch in ROCM_CDNA_TARGETS - - -def is_rocm_cdna(): - return current_platform.is_rocm() and has_cdna_target() + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. 
+ # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] def get_autotune_configs(): - if is_rocm_cdna(): - return get_cdna_autotune_configs() - elif current_platform.is_rocm(): + if on_gfx1x(): return get_rdna_autotune_configs() else: - return get_general_autotune_configs() + return get_cdna_autotune_configs() autotune_configs, autotune_keys = get_autotune_configs() +float8_info = torch.finfo(current_platform.fp8_dtype()) + @triton.autotune( configs=autotune_configs, key=autotune_keys, - use_cuda_graph=True, ) @triton.jit def attn_fwd( @@ -681,7 +404,13 @@ def attn_fwd( K, V, bias, - SM_SCALE: tl.constexpr, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, L, Out, stride_qz: tl.int64, @@ -704,70 +433,44 @@ def attn_fwd( stride_bh: tl.int64, stride_bm: tl.int64, stride_bn: tl.int64, - stride_az: tl.int64, - stride_ah: tl.int64, - q_descale_ptr, - k_descale_ptr, - p_scale_ptr, - p_descale_ptr, - o_descale_ptr, - v_descale_ptr, - q_descale_has_singleton: tl.constexpr, - k_descale_has_singleton: tl.constexpr, - p_descale_has_singleton: tl.constexpr, - v_descale_has_singleton: tl.constexpr, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, - NUM_CU: tl.constexpr, - GRID_CU_MULTIP: tl.constexpr, - B: tl.constexpr, philox_offset_base, encoded_softmax, - alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, - IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, + USE_FP8_OUT: tl.constexpr, BLOCK_N: tl.constexpr, - SHOULD_PRE_LOAD_V: tl.constexpr, - USE_BIAS: tl.constexpr, - SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_ALIBI: tl.constexpr, - IS_EIGHT_BIT: tl.constexpr, - USE_P_SCALE: tl.constexpr, - IS_EIGHT_BIT_KV: tl.constexpr, - QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): - - if o_descale_ptr is not None: - o_descale = tl.load(o_descale_ptr) - - start_m: tl.int64 = tl.program_id(0) - off_h_q: tl.int64 = tl.program_id(1) - off_z: tl.int64 = tl.program_id(2) - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64) - offs_n = tl.arange(0, BLOCK_N).to(tl.int64) - offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64) - - # as we can't have return statements inside while loop in Triton - continue_condition = True - + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be - # too small for all start_m so for those we return early. + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. 
if start_m * BLOCK_M > seqlen_q: - continue_condition = False - # return + return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start @@ -777,598 +480,499 @@ def attn_fwd( seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - if continue_condition: - # Now we compute whether we need to exit early due to causal - # masking. This is because for seqlen_q > seqlen_k, M rows of the - # attn scores are completely masked, resulting in 0s written to the - # output, and inf written to LSE. We don't need to do any GEMMs in - # this case. This block of code determines what N is, and if this - # WG is operating on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if (IS_CAUSAL): - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which - # means the causal mask boundary is bottom right aligned, and - # ends at either the top edge (seqlen_q < seqlen_k) or left - # edge. This captures the decrease in n_blocks if we have a - # rectangular attn matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all - # n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this - # WG is part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + - cu_seqlens_q_start * stride_om) - o_ptrs = (o_offset + offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( - [BLOCK_M, BLOCK_DMODEL]) - # We still need to write 0s to the result - tl.store(o_ptrs, acc, mask=o_ptrs_mask) - # The tensor allocated for L is based on MAX_SEQLENS_Q as - # that is statically known. - l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + - off_h_q * MAX_SEQLENS_Q + offs_m) - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this from qk which makes it -inf, such that - # exp(qk - inf) = 0 for these masked blocks. - l_value = tl.full([BLOCK_M], - value=float("inf"), - dtype=tl.float32) - l_ptrs_mask = offs_m < MAX_SEQLENS_Q - tl.store(l_ptrs, l_value, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be - # handled here too? - continue_condition = False - # return - - if continue_condition: - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL - != BLOCK_DMODEL) - - # Compute pointers for all the tensors used in this kernel. 
- q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - q_ptrs = (q_offset + offs_m[:, None] * stride_qm + - offs_d[None, :] * stride_qk) - k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - k_ptrs = (k_offset + offs_d[:, None] * stride_kk + - offs_n[None, :] * stride_kn) - v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - v_ptrs = (v_offset + offs_n[:, None] * stride_vk + - offs_d[None, :] * stride_vn) - # Compute pointers for all scale tensors used in this kernel. - - IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( - not IS_EIGHT_BIT_KV) - if IS_EIGHT_BIT: - if k_descale_has_singleton: - k_descale_ptrs = k_descale_ptr - else: - k_descale_ptrs = k_descale_ptr + off_h_k - - if v_descale_has_singleton: - v_descale_ptrs = v_descale_ptr - else: - v_descale_ptrs = v_descale_ptr + off_h_k - - if not IS_EIGHT_BIT_KV: - if q_descale_has_singleton: - q_descale_ptrs = q_descale_ptr - else: - q_descale_ptrs = q_descale_ptr + off_h_q - if USE_P_SCALE: - if p_descale_has_singleton: - p_scale_ptrs = p_scale_ptr - p_descale_ptrs = p_descale_ptr - else: - p_scale_ptrs = p_scale_ptr + off_h_q - p_descale_ptrs = p_descale_ptr + off_h_q - - if USE_BIAS: - bias_offset = off_h_q * stride_bh - bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm + - offs_n[None, :] * stride_bn) - else: - bias_ptrs = None - - if USE_ALIBI: - a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(alibi_slopes + a_offset) - else: - alibi_slope = None - - batch_philox_offset = 0 - # We can ask to return the dropout mask without doing any - # dropout. In this case, we return an invalid pointer so - # indicate the mask is not valid. - if SHOULD_RETURN_ENCODED_SOFTMAX: - encoded_sm_base = (encoded_softmax + - off_h_q * seqlen_q * seqlen_k) - encoded_sm_ptrs = (encoded_sm_base + - offs_m[:, None] * seqlen_k + - offs_n[None, :]) - else: - encoded_sm_ptrs = None - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do - # not have native e^x support in HW. - QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q_ptrs_mask = offs_m[:, None] < seqlen_q - if USE_PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] - < IS_ACTUAL_BLOCK_DMODEL) - q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - - if IS_EIGHT_BIT: - k_descale = tl.load(k_descale_ptrs) - v_descale = tl.load(v_descale_ptrs) - q_descale = None if IS_EIGHT_BIT_KV else tl.load( - q_descale_ptrs) - if USE_P_SCALE: - p_scale = tl.load(p_scale_ptrs) - p_descale = tl.load(p_descale_ptrs) - else: - p_scale = None - p_descale = None - else: - q_descale = None - k_descale = None - v_descale = None - p_scale = None - p_descale = None - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked - # blocks. Additionally there might be one more due to - # dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. 
- masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an - # additional block. In this case we might exceed n_blocks so - # pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false - # regardless of its actual value because there is no masking. - # Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - bias_ptrs, - stride_kn, - stride_vk, - stride_bn, - start_m, - seqlen_k, - seqlen_q, - philox_seed, - batch_philox_offset, - encoded_sm_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - alibi_slope, - q_descale, - k_descale, - v_descale, - p_scale, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, SHOULD_MASK_STEPS, ... - SHOULD_PRE_LOAD_V, - False, - SHOULD_RETURN_ENCODED_SOFTMAX, - USE_PADDED_HEAD, - IS_ACTUAL_BLOCK_DMODEL, - QK_SCALE, - IS_EIGHT_BIT_GEMM, - USE_P_SCALE, - IS_EIGHT_BIT_KV, - QUANT_DTYPE) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vk - if USE_BIAS: - bias_ptrs += n_full_blocks * BLOCK_N * stride_bn - if SHOULD_RETURN_ENCODED_SOFTMAX: - encoded_sm_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - bias_ptrs, - stride_kn, - stride_vk, - stride_bn, - start_m, - seqlen_k, - seqlen_q, - philox_seed, - batch_philox_offset, - encoded_sm_ptrs, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - alibi_slope, - q_descale, - k_descale, - v_descale, - p_scale, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, SHOULD_MASK_STEPS, ... - SHOULD_PRE_LOAD_V, - True, - SHOULD_RETURN_ENCODED_SOFTMAX, - USE_PADDED_HEAD, - IS_ACTUAL_BLOCK_DMODEL, - QK_SCALE, - IS_EIGHT_BIT_GEMM, - USE_P_SCALE, - IS_EIGHT_BIT_KV, - QUANT_DTYPE) - - if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: - if USE_P_SCALE: - acc *= p_descale - acc *= v_descale - - # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc - # which is much larger. - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - - # If seqlen_q > seqlen_k but the delta is not a multiple of - # BLOCK_M, then we have one block with a row of all NaNs which - # come from computing softmax over a row of all - # -infs (-inf - inf = NaN). We check for that here and store 0s - # where there are NaNs as these rows should've been zeroed out. 
- end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 - if o_descale_ptr is not None: - acc = quant_fp8(acc, o_descale) - - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if (causal_start_idx > start_m_idx - and causal_start_idx < end_m_idx): - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = tl.zeros((1, ), tl.float32) - acc = tl.where(out_ptrs_mask, acc, - z.to(acc.type.element_ty)) - # write back LSE - l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + - off_h_q * MAX_SEQLENS_Q + offs_m) - # If seqlen_q not multiple of BLOCK_M, we need to mask out the - # last few rows. This is only true for the last M block. - # For others, overflow_size will be -ve - overflow_size = end_m_idx - seqlen_q - if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), - BLOCK_M - overflow_size, - dtype=tl.int32) - l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - else: - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + - cu_seqlens_q_start * stride_om) - o_ptrs = (o_offset + offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) - if overflow_size > 0: - o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if USE_PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] - < IS_ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) - - -def get_shape_from_layout(q, k, metadata): - assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." - - if metadata.layout == 'thd': - nheads_q, nheads_k = q.shape[1], k.shape[1] - head_size = q.shape[-1] - batch = metadata.num_contexts - elif metadata.layout == 'bhsd': - batch, nheads_q, _, head_size = q.shape - nheads_k = k.shape[1] - elif metadata.layout == 'bshd': - batch, _, nheads_q, head_size = q.shape - nheads_k = k.shape[2] - return batch, nheads_q, nheads_k, head_size - - -def get_strides_from_layout(q, k, v, o, metadata): - assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." - - STRIDE_PERMUTATIONS = { - 'thd': (None, 1, 0, 2), - 'bhsd': (0, 1, 2, 3), - 'bshd': (0, 2, 1, 3), - } - - perm = STRIDE_PERMUTATIONS[metadata.layout] - stride = lambda x, p: (0 if p is None else x.stride(p)) - strides = lambda x: (stride(x, p) for p in perm) - - return tuple(strides(x) for x in [q, k, v, o]) + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. 
+ # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
+ if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... 
+ PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + # epilogue + + if USE_FP8: + acc *= acc_scale + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if USE_FP8_OUT: + acc *= o_descale + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. 
+ tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, o, metadata: MetaData): - # NOTE: a large bias tensor leads to overflow during pointer arithmetic - if (metadata.bias is not None): - assert (metadata.bias.numel() < 2**31) + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + fp8_scales=None, + fp8_out_scale=None, + ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale) = fp8_scales + float8 = current_platform.fp8_dtype() + + def check_and_convert(t, scale): + if t.dtype != float8: + descale = 1.0 / scale + ts = (t * descale).clamp(min=float8_info.min, + max=float8_info.max) + return ts.to(float8) + else: + return t - if o is None: - if metadata.eight_bit: - o = torch.empty_like( - q, - dtype=metadata.output_dtype if metadata.output_dtype - is not None else metadata.eight_bit_dtype_torch) - else: - o = torch.empty_like(q, dtype=q.dtype) + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = 1.0 - metadata.check_args(q, k, v, o) + if o is None: + o = torch.empty_like(q, dtype=v.dtype) - batch, nheads_q, nheads_k, head_size = get_shape_from_layout( - q, k, metadata) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout( - q, k, v, o, metadata) + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - # Smallest head_dim supported is 16. If smaller, the tile in the - # kernel is padded - there is no padding in memory for any dims. 
- padded_d_model = max(padded_d_model, 16) - - # encoded_softmax is used to validate dropout behavior vs the - # PyTorch SDPA math backend reference. We zero this out to give a - # consistent starting point and then populate it with the output of - # softmax with the sign bit set according to the dropout mask. - # The resulting return allows this mask to be fed into the reference - # implementation for testing only. This return holds no useful output - # aside from debugging. - if metadata.return_encoded_softmax: - encoded_softmax = torch.zeros( - (q.shape[0], q.shape[1], q.shape[2], k.shape[2]), - device=q.device, - dtype=torch.float32) + unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None else: - encoded_softmax = None + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) - M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), - device=q.device, - dtype=torch.float32) + encoded_softmax = None # Seed the RNG so we get reproducible results for testing. philox_seed = 0x1BF52 philox_offset = 0x1D4B42 - if metadata.bias is not None: - bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), - metadata.bias.stride(2), metadata.bias.stride(3)) + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) else: bias_strides = (0, 0, 0, 0) - if metadata.alibi_slopes is not None: - alibi_strides = (metadata.alibi_slopes.stride(0), - metadata.alibi_slopes.stride(1)) - else: - alibi_strides = (0, 0) + p_descale = 1.0 / p_scale + o_descale = 1.0 / fp8_out_scale.item( + ) if fp8_out_scale is not None else 1.0 - if metadata.eight_bit: - q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( - metadata.q_descale, metadata.k_descale, metadata.p_scale, - metadata.p_descale, metadata.v_descale, metadata.o_scale) - o_descale = 1.0 / o_scale if o_scale is not None else None - else: - q_descale = k_descale = p_scale = None - p_descale = v_descale = o_descale = None - - # number of compute units available - NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count - - grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[ - 'BLOCK_M']), nheads_q, batch) + arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q + arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k attn_fwd[grid]( q, k, v, - metadata.bias, - metadata.sm_scale, - M, + bias, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, + None, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, - *alibi_strides, - q_descale, - k_descale, - p_scale, - p_descale, - o_descale, - v_descale, - q_descale.numel() == 1 if q_descale is not None else False, - k_descale.numel() == 1 if k_descale is not None else False, - p_descale.numel() == 1 if p_descale is not None else False, - v_descale.numel() == 1 if v_descale is not None else False, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, - IS_ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=metadata.max_seqlens_q, - MAX_SEQLENS_K=metadata.max_seqlens_k, - IS_CAUSAL=metadata.causal, - VARLEN=metadata.varlen, + 
ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, BLOCK_DMODEL=padded_d_model, - USE_BIAS=metadata.bias is not None, - USE_ALIBI=metadata.alibi_slopes is not None, - SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, - IS_EIGHT_BIT=metadata.eight_bit, - USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, - IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, - NUM_CU=NUM_CU, - B=batch, - QUANT_DTYPE=metadata.eight_bit_dtype_triton) + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + USE_FP8=use_fp8, + USE_FP8_OUT=fp8_out_scale is not None, + ) ctx.grid = grid - ctx.sm_scale = metadata.sm_scale + ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = head_size - ctx.causal = metadata.causal - ctx.alibi_slopes = metadata.alibi_slopes + ctx.causal = causal + ctx.dropout_p = 0.0 ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = metadata.return_encoded_softmax + ctx.return_encoded_softmax = False return o, encoded_softmax -triton_attention_rocm = _attention.apply - - -def scale_fp8(t, scale=None): - t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]), - scale) - return t_scaled.reshape(t.shape), scale_out - - -def maybe_quantize_fp8(t, scale): - eight_bit_dtype = current_platform.fp8_dtype() - if t.dtype != eight_bit_dtype: - t, _ = scale_fp8(t, scale) - return t - - -def check_and_maybe_quantize_qkv(q, k, v, fp8_scales): - (q_scale, k_scale, v_scale, p_scale) = fp8_scales - - q = maybe_quantize_fp8(q, q_scale) - k = maybe_quantize_fp8(k, k_scale) - v = maybe_quantize_fp8(v, v_scale) - - return q, k, v - - -# query - [num_tokens, num_heads, head_size] -# key - [num_tokens, num_kv_heads, head_size] -# value - [num_tokens, num_kv_heads, head_size -# output - [num_tokens, num_heads, head_size] -def triton_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlens_q: int, - max_seqlens_k: int, - causal: bool = False, - sm_scale: float = 1.0, - bias: Optional[torch.Tensor] = None, - fp8_scales: Optional[tuple[float, ...]] = None, - input_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - if fp8_scales is not None: - q_descale, k_descale, v_descale, p_scale = fp8_scales - else: - q_descale = k_descale = v_descale = p_scale = None - - attn_metadata = MetaData(sm_scale=sm_scale, - max_seqlens_q=max_seqlens_q, - max_seqlens_k=max_seqlens_k, - causal=causal, - bias=bias, - output_dtype=q.dtype, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - p_scale=p_scale, - o_scale=input_scale) - - if fp8_scales is not None: - q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales) - - return triton_attention_rocm(q, k, v, o, attn_metadata) +triton_attention = _attention.apply diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c8b86087578d..51b0a94257a0 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -95,6 +95,12 @@ def wrapper(*args, **kwargs): return wrapper +@cache +def on_gfx1x() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) + + @cache def on_mi250_mi300() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
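
As a usage note, the `on_gfx1x()` helper added to `vllm/platforms/rocm.py` above is what `get_autotune_configs()` in the restored kernel keys off when choosing between the RDNA and CDNA tile configurations. The sketch below only illustrates that gating; it assumes a ROCm build of PyTorch with an AMD GPU visible, and the string return values are placeholders for the `triton.Config` lists defined in the patched file.

    from functools import cache

    import torch


    @cache
    def on_gfx1x() -> bool:
        # gcnArchName is e.g. "gfx1100" on RDNA3 or "gfx942" on MI300-class GPUs.
        gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
        return any(arch in gpu_arch for arch in ["gfx11", "gfx12"])


    def get_autotune_configs():
        # The real kernel returns (configs, autotune_keys); strings stand in
        # here to keep the sketch self-contained.
        if on_gfx1x():
            return "rdna_configs"  # small 32x32 / 32x16 tiles, num_warps=2
        return "cdna_configs"      # larger 128/256-wide tiles, num_warps=4 or 8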