From b6b00d7032e1d1e83de3128e9f1f1132dd496502 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 5 Feb 2025 20:42:23 +0000 Subject: [PATCH 01/52] init Signed-off-by: Sage Moore --- csrc/attention/paged_attention_v1.cu | 3 + csrc/attention/paged_attention_v2.cu | 3 + examples/offline_inference/basic.py | 2 +- vllm/platforms/rocm.py | 29 +++--- vllm/v1/attention/backends/flash_attn.py | 122 ++++++++++++++++------- 5 files changed, 105 insertions(+), 54 deletions(-) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 9b3a5c4b1014..96316319af99 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -161,6 +161,9 @@ void paged_attention_v1_launcher( case 32: \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ + case 128: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 128, KV_DTYPE); \ + break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9935359e02fb..13768b5377c0 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -168,6 +168,9 @@ void paged_attention_v2_launcher( case 32: \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ + case 128: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 128, KV_DTYPE); \ + break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index 23cc6e853943..d68c13ddbab6 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="meta-llama/meta-llama-3.1-8b-instruct") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5ef56406e193..1e7e09ad3ca9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -27,12 +27,6 @@ except ImportError as e: logger.warning("Failed to import from vllm._rocm_C with %r", e) -if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: - logger.warning("`fork` method is not supported by ROCm. " - "VLLM_WORKER_MULTIPROC_METHOD is overridden to" - " `spawn` instead.") - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS: List[str] = [] @@ -78,6 +72,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1) -> str: selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) + if "VLLM_USE_V1" in os.environ: + logger.info("Using Flash Attention backend on V1 engine.") + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" if selected_backend == _Backend.ROCM_FLASH: if not cls.has_device_capability(90): # not Instinct series GPUs. 
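Note: the V1 branch added above returns the attention backend as a dotted class path string rather than an enum value. A minimal sketch of how such a path can be resolved into a class at runtime (generic importlib logic shown purely for illustration; vLLM's own resolution helper may differ):

import importlib

def resolve_backend_cls(qualname: str):
    # Split "package.module.ClassName" into the module path and the attribute,
    # import the module, then look up the class on it.
    module_name, _, cls_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, cls_name)

# e.g. resolve_backend_cls(
#     "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend")
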
@@ -122,16 +119,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_worker.MultiStepWorker" - elif vllm_config.speculative_config: - parallel_config.worker_cls = \ - "vllm.spec_decode.spec_decode_worker.create_spec_worker" - parallel_config.sd_worker_cls = \ - "vllm.worker.worker.Worker" - else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + # if scheduler_config.is_multi_step: + # parallel_config.worker_cls = \ + # "vllm.worker.multi_step_worker.MultiStepWorker" + # elif vllm_config.speculative_config: + # parallel_config.worker_cls = \ + # "vllm.spec_decode.spec_decode_worker.create_spec_worker" + # parallel_config.sd_worker_cls = \ + # "vllm.worker.worker.Worker" + # else: + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" @classmethod def verify_model_arch(cls, model_arch: str) -> None: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ce83b1fac6c0..5eb509676594 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -13,10 +13,12 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - is_fa_version_supported) - +# from vllm.vllm_flash_attn import (fa_version_unsupported_reason, +# flash_attn_varlen_func, +# is_fa_version_supported) +from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.attention.ops.paged_attn import PagedAttention +# from flash_attn import flash_attn_varlen_func logger = init_logger(__name__) @@ -139,7 +141,8 @@ def __init__( # TODO(lucas): profile FA3 on ampere to see if it makes sense to # use FA3 as default for both if current_platform.get_device_capability()[0] >= 9: - self.fa_version = 3 if is_fa_version_supported(3) else 2 + # self.fa_version = 3 if is_fa_version_supported(3) else 2 + self.fa_version = 2 else: self.fa_version = 2 @@ -147,12 +150,12 @@ def __init__( assert VLLM_FLASH_ATTN_VERSION in [2, 3] self.fa_version = VLLM_FLASH_ATTN_VERSION - if not is_fa_version_supported(self.fa_version): - logger.error("Cannot use FA version %d is not supported due to %s", - self.fa_version, - fa_version_unsupported_reason(self.fa_version)) + # if not is_fa_version_supported(self.fa_version): + # logger.error("Cannot use FA version %d is not supported due to %s", + # self.fa_version, + # fa_version_unsupported_reason(self.fa_version)) - assert is_fa_version_supported(self.fa_version) + # assert is_fa_version_supported(self.fa_version) def forward( self, @@ -196,40 +199,85 @@ def forward( # not padded. However, we don't need to do key[:num_actual_tokens] and # value[:num_actual_tokens] because the reshape_and_cache_flash op uses # the slot_mapping's shape to determine the number of actual tokens. 
- key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + # key_cache, value_cache = kv_cache.unbind(0) + # torch.ops._C_cache_ops.reshape_and_cache_flash( + # key, + # value, + # key_cache, + # value_cache, + # attn_metadata.slot_mapping, + # self.kv_cache_dtype, + # layer._k_scale, + # layer._v_scale, + # ) + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) # Compute attention and update output up to `num_actual_tokens`. if not attn_metadata.use_cascade: # Regular attention (common case). - flash_attn_varlen_func( + # flash_attn_varlen_func( + # q=query[:num_actual_tokens], + # k=key_cache, + # v=value_cache, + # out=output[:num_actual_tokens], + # cu_seqlens_q=attn_metadata.query_start_loc, + # max_seqlen_q=attn_metadata.max_query_len, + # seqused_k=attn_metadata.seq_lens, + # max_seqlen_k=attn_metadata.max_seq_len, + # softmax_scale=self.scale, + # causal=True, + # alibi_slopes=self.alibi_slopes, + # window_size=self.sliding_window, + # block_table=attn_metadata.block_table, + # softcap=self.logits_soft_cap, + # fa_version=self.fa_version, + # ) + context_lens = torch.empty_like(attn_metadata.seq_lens) + batch_size = len(attn_metadata.query_start_loc) - 1 + assert len(context_lens) == batch_size + for i in range(batch_size): + query_start = attn_metadata.query_start_loc[i] + query_end = attn_metadata.query_start_loc[i + 1] + context_lens[i] = attn_metadata.seq_lens[i] - (query_end - query_start) + + # print(f"context: {context_lens} seqs: {attn_metadata.seq_lens} query: {attn_metadata.query_start_loc}") + + context_attention_fwd( q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - seqused_k=attn_metadata.seq_lens, - max_seqlen_k=attn_metadata.max_seq_len, - softmax_scale=self.scale, - causal=True, + k=key[:num_actual_tokens], + v=value[:num_actual_tokens], + o=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=attn_metadata.block_table, + b_start_loc=attn_metadata.query_start_loc, + b_seq_len=attn_metadata.seq_lens, + b_ctx_len=context_lens, + max_input_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=attn_metadata.block_table, - softcap=self.logits_soft_cap, - fa_version=self.fa_version, + sliding_window=self.sliding_window[0] ) return output + assert False # Cascade attention (rare case). 
cascade_attention( output[:num_actual_tokens], @@ -243,9 +291,9 @@ def forward( suffix_kv_lens=attn_metadata.suffix_kv_lens, max_kv_len=attn_metadata.max_seq_len, softmax_scale=self.scale, + causal=True, alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window, - logits_soft_cap=self.logits_soft_cap, + window_size=self.sliding_window, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, fa_version=self.fa_version, From fa5226824822dab056a34031436c7fee0e92a7bf Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 5 Feb 2025 21:27:42 +0000 Subject: [PATCH 02/52] temporarily remove torch from requirements-build Signed-off-by: Sage Moore --- requirements-build.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-build.txt b/requirements-build.txt index fec01caaf25e..b1188240e90d 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -4,6 +4,6 @@ ninja packaging setuptools>=61 setuptools-scm>=8 -torch==2.5.1 +# torch==2.5.1 wheel jinja2 From f563276362f8ca7f1e2a0a4c957aa33852be91b1 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 6 Feb 2025 15:24:19 +0000 Subject: [PATCH 03/52] move rocm logic to its own attention backend Signed-off-by: Sage Moore --- requirements-build.txt | 1 + vllm/platforms/rocm.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 122 +++++++---------------- 3 files changed, 39 insertions(+), 86 deletions(-) diff --git a/requirements-build.txt b/requirements-build.txt index b1188240e90d..296eb8ca6bd2 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -7,3 +7,4 @@ setuptools-scm>=8 # torch==2.5.1 wheel jinja2 + diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1e7e09ad3ca9..90a50fbc0b5d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -74,7 +74,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, == _Backend.FLASH_ATTN else selected_backend) if "VLLM_USE_V1" in os.environ: logger.info("Using Flash Attention backend on V1 engine.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend" if selected_backend == _Backend.ROCM_FLASH: if not cls.has_device_capability(90): # not Instinct series GPUs. 
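Note: the ROCmAttentionBackend that this dotted path now points at is added in the next patch of the series. It reuses the prefix-prefill Triton kernel (context_attention_fwd), which at this point in the series expects per-sequence context lengths (b_ctx_len) rather than the cumulative query offsets carried in the V1 metadata. A minimal sketch of that derivation, vectorized here for illustration (the backend itself loops over the batch), assuming query_start_loc has shape [batch_size + 1] and seq_lens has shape [batch_size]:

import torch

def compute_context_lens(query_start_loc: torch.Tensor,
                         seq_lens: torch.Tensor) -> torch.Tensor:
    # Number of new (query) tokens per sequence from the cumulative offsets.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Context length = total sequence length minus the new query tokens.
    return seq_lens - query_lens

# Example: two sequences with total lengths 7 and 5 and 2 / 1 new tokens:
# compute_context_lens(torch.tensor([0, 2, 3]), torch.tensor([7, 5]))
# -> tensor([5, 4])
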
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5eb509676594..ce83b1fac6c0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -13,12 +13,10 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -# from vllm.vllm_flash_attn import (fa_version_unsupported_reason, -# flash_attn_varlen_func, -# is_fa_version_supported) -from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.attention.ops.paged_attn import PagedAttention -# from flash_attn import flash_attn_varlen_func +from vllm.vllm_flash_attn import (fa_version_unsupported_reason, + flash_attn_varlen_func, + is_fa_version_supported) + logger = init_logger(__name__) @@ -141,8 +139,7 @@ def __init__( # TODO(lucas): profile FA3 on ampere to see if it makes sense to # use FA3 as default for both if current_platform.get_device_capability()[0] >= 9: - # self.fa_version = 3 if is_fa_version_supported(3) else 2 - self.fa_version = 2 + self.fa_version = 3 if is_fa_version_supported(3) else 2 else: self.fa_version = 2 @@ -150,12 +147,12 @@ def __init__( assert VLLM_FLASH_ATTN_VERSION in [2, 3] self.fa_version = VLLM_FLASH_ATTN_VERSION - # if not is_fa_version_supported(self.fa_version): - # logger.error("Cannot use FA version %d is not supported due to %s", - # self.fa_version, - # fa_version_unsupported_reason(self.fa_version)) + if not is_fa_version_supported(self.fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + self.fa_version, + fa_version_unsupported_reason(self.fa_version)) - # assert is_fa_version_supported(self.fa_version) + assert is_fa_version_supported(self.fa_version) def forward( self, @@ -199,85 +196,40 @@ def forward( # not padded. However, we don't need to do key[:num_actual_tokens] and # value[:num_actual_tokens] because the reshape_and_cache_flash op uses # the slot_mapping's shape to determine the number of actual tokens. - # key_cache, value_cache = kv_cache.unbind(0) - # torch.ops._C_cache_ops.reshape_and_cache_flash( - # key, - # value, - # key_cache, - # value_cache, - # attn_metadata.slot_mapping, - # self.kv_cache_dtype, - # layer._k_scale, - # layer._v_scale, - # ) - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + key_cache, value_cache = kv_cache.unbind(0) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) # Compute attention and update output up to `num_actual_tokens`. if not attn_metadata.use_cascade: # Regular attention (common case). 
- # flash_attn_varlen_func( - # q=query[:num_actual_tokens], - # k=key_cache, - # v=value_cache, - # out=output[:num_actual_tokens], - # cu_seqlens_q=attn_metadata.query_start_loc, - # max_seqlen_q=attn_metadata.max_query_len, - # seqused_k=attn_metadata.seq_lens, - # max_seqlen_k=attn_metadata.max_seq_len, - # softmax_scale=self.scale, - # causal=True, - # alibi_slopes=self.alibi_slopes, - # window_size=self.sliding_window, - # block_table=attn_metadata.block_table, - # softcap=self.logits_soft_cap, - # fa_version=self.fa_version, - # ) - context_lens = torch.empty_like(attn_metadata.seq_lens) - batch_size = len(attn_metadata.query_start_loc) - 1 - assert len(context_lens) == batch_size - for i in range(batch_size): - query_start = attn_metadata.query_start_loc[i] - query_end = attn_metadata.query_start_loc[i + 1] - context_lens[i] = attn_metadata.seq_lens[i] - (query_end - query_start) - - # print(f"context: {context_lens} seqs: {attn_metadata.seq_lens} query: {attn_metadata.query_start_loc}") - - context_attention_fwd( + flash_attn_varlen_func( q=query[:num_actual_tokens], - k=key[:num_actual_tokens], - v=value[:num_actual_tokens], - o=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=attn_metadata.block_table, - b_start_loc=attn_metadata.query_start_loc, - b_seq_len=attn_metadata.seq_lens, - b_ctx_len=context_lens, - max_input_len=attn_metadata.max_query_len, - k_scale=layer._k_scale, - v_scale=layer._v_scale, + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + seqused_k=attn_metadata.seq_lens, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0] + window_size=self.sliding_window, + block_table=attn_metadata.block_table, + softcap=self.logits_soft_cap, + fa_version=self.fa_version, ) return output - assert False # Cascade attention (rare case). 
cascade_attention( output[:num_actual_tokens], @@ -291,9 +243,9 @@ def forward( suffix_kv_lens=attn_metadata.suffix_kv_lens, max_kv_len=attn_metadata.max_seq_len, softmax_scale=self.scale, - causal=True, alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, fa_version=self.fa_version, From 2a03b92f609275946bb44d0d5b01f127d704bded Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 6 Feb 2025 19:10:25 +0000 Subject: [PATCH 04/52] actually add backend Signed-off-by: Sage Moore --- vllm/v1/attention/backends/flash_attn.py | 6 +- vllm/v1/attention/backends/rocm_attn.py | 229 +++++++++++++++++++++++ 2 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 vllm/v1/attention/backends/rocm_attn.py diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ce83b1fac6c0..1768f99790bb 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -13,9 +13,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - is_fa_version_supported) +# from vllm.vllm_flash_attn import (fa_version_unsupported_reason, +# flash_attn_varlen_func, +# is_fa_version_supported) logger = init_logger(__name__) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py new file mode 100644 index 000000000000..b441d0b7bc6f --- /dev/null +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -0,0 +1,229 @@ +"""Attention layer with PagedAttention on rocm""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import torch +import triton +import triton.language as tl + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.envs import VLLM_FLASH_ATTN_VERSION +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.attention.ops.paged_attn import PagedAttention +logger = init_logger(__name__) + + +class ROCmAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + +@dataclass +class FlashAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + +class FlashAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + + # if hopper default to FA3, otherwise stick to FA2 for now + # TODO(lucas): profile FA3 on ampere to see if it makes sense to + # use FA3 as default for both + if current_platform.get_device_capability()[0] >= 9: + # self.fa_version = 3 if is_fa_version_supported(3) else 2 + self.fa_version = 2 + else: + self.fa_version = 2 + + if VLLM_FLASH_ATTN_VERSION is not None: + assert VLLM_FLASH_ATTN_VERSION in [2, 3] + self.fa_version = VLLM_FLASH_ATTN_VERSION + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # TODO(sage): Refactor the context_attention_fwd kernel so that this + # overhead can be removed + context_lens = torch.empty_like(attn_metadata.seq_lens) + batch_size = len(attn_metadata.query_start_loc) - 1 + assert len(context_lens) == batch_size + for i in range(batch_size): + query_start = attn_metadata.query_start_loc[i] + query_end = attn_metadata.query_start_loc[i + 1] + context_lens[i] = attn_metadata.seq_lens[i] - (query_end - query_start) + + # Compute attention and update output up to `num_actual_tokens`. + context_attention_fwd( + q=query[:num_actual_tokens], + k=key[:num_actual_tokens], + v=value[:num_actual_tokens], + o=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=attn_metadata.block_table, + b_start_loc=attn_metadata.query_start_loc, + b_seq_len=attn_metadata.seq_lens, + b_ctx_len=context_lens, + max_input_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0] + ) + return output \ No newline at end of file From 4bdf7de3d32cec0226827a0d7f973b55ffd0f76d Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 00:49:06 +0000 Subject: [PATCH 05/52] more rocm refactoring Signed-off-by: Sage Moore --- vllm/platforms/rocm.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 90a50fbc0b5d..57fb06afa2ee 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -119,16 +119,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - # if scheduler_config.is_multi_step: - # parallel_config.worker_cls = \ - # "vllm.worker.multi_step_worker.MultiStepWorker" - # elif vllm_config.speculative_config: - # parallel_config.worker_cls = \ - # "vllm.spec_decode.spec_decode_worker.create_spec_worker" - # parallel_config.sd_worker_cls = \ - # "vllm.worker.worker.Worker" - # else: - parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + if 
envs.VLLM_USE_V1: + raise NotImplementedError + else: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.gpu_worker.Worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" @classmethod def verify_model_arch(cls, model_arch: str) -> None: From e507e308cac3ae7316cf337796b46e08d96446e3 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 01:07:07 +0000 Subject: [PATCH 06/52] more rocm refactoring Signed-off-by: Sage Moore --- vllm/v1/attention/backends/rocm_attn.py | 65 ++++--------------------- 1 file changed, 10 insertions(+), 55 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index b441d0b7bc6f..42e48faf543d 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -1,17 +1,12 @@ """Attention layer with PagedAttention on rocm""" -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type -import numpy as np import torch -import triton -import triton.language as tl from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.envs import VLLM_FLASH_ATTN_VERSION +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.paged_attn import PagedAttention logger = init_logger(__name__) @@ -27,11 +22,11 @@ def get_supported_head_sizes() -> List[int]: @staticmethod def get_name() -> str: - return "FLASH_ATTN_VLLM_V1" + return "ROCM_ATTN_VLLM_V1" @staticmethod - def get_impl_cls() -> Type["FlashAttentionImpl"]: - return FlashAttentionImpl + def get_impl_cls() -> Type["ROCmAttentionImpl"]: + return ROCmAttentionImpl @staticmethod def get_metadata_cls() -> Type["AttentionMetadata"]: @@ -52,37 +47,7 @@ def get_kv_cache_shape( def use_cascade_attention(*args, **kwargs) -> bool: return False - -@dataclass -class FlashAttentionMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - - # For cascade attention. - use_cascade: bool - common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] - - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. 
- - -class FlashAttentionImpl(AttentionImpl): +class ROCmAttentionImpl(AttentionImpl): def __init__( self, @@ -99,7 +64,7 @@ def __init__( ) -> None: if blocksparse_params is not None: raise ValueError( - "FlashAttention does not support block-sparse attention.") + "ROCmAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -123,27 +88,15 @@ def __init__( support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes() if head_size not in support_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " + f"Head size {head_size} is not supported by ROCmAttention. " f"Supported head sizes are: {support_head_sizes}.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " - "FlashAttentionImpl") - - # if hopper default to FA3, otherwise stick to FA2 for now - # TODO(lucas): profile FA3 on ampere to see if it makes sense to - # use FA3 as default for both - if current_platform.get_device_capability()[0] >= 9: - # self.fa_version = 3 if is_fa_version_supported(3) else 2 - self.fa_version = 2 - else: - self.fa_version = 2 + "ROCmAttentionImpl") - if VLLM_FLASH_ATTN_VERSION is not None: - assert VLLM_FLASH_ATTN_VERSION in [2, 3] - self.fa_version = VLLM_FLASH_ATTN_VERSION def forward( self, @@ -171,6 +124,8 @@ def forward( if attn_metadata is None: # Profiling run. return output + + assert attn_metadata.use_cascade is False # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in From b9ce2590bb45a4277cbde1b6bffefbadfd3193e2 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 01:23:43 +0000 Subject: [PATCH 07/52] hack to fix the multiprocessing isssue Signed-off-by: Sage Moore --- vllm/platforms/rocm.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 8e399f46a39e..d681ab9d2a42 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -99,7 +99,7 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - return torch.cuda.get_device_name(device_id) + return "AMD" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f2cf432acf7d..7707e216017d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -13,9 +13,9 @@ from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION from vllm.logger import init_logger from vllm.utils import cdiv -# from vllm.vllm_flash_attn import (fa_version_unsupported_reason, -# flash_attn_varlen_func, -# is_fa_version_supported) +from vllm.platforms import current_platform +if current_platform.is_cuda(): + from vllm.vllm_flash_attn import flash_attn_varlen_func logger = init_logger(__name__) From f2cc5e3e42fecbb76f8a2c57bda4f9a0163bdd7c Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 01:28:07 +0000 Subject: [PATCH 08/52] minor print fix Signed-off-by: Sage Moore --- vllm/platforms/rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d681ab9d2a42..9894084f2817 100644 --- a/vllm/platforms/rocm.py +++ 
b/vllm/platforms/rocm.py @@ -79,7 +79,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if "VLLM_USE_V1" in os.environ: - logger.info("Using Flash Attention backend on V1 engine.") + logger.info("Using ROCm Attention backend on V1 engine.") return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend" if selected_backend == _Backend.ROCM_FLASH: if not cls.has_device_capability(90): From d6f6c5ccd4c9cad0e1dedb7300b2b37661be3acb Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 01:29:39 +0000 Subject: [PATCH 09/52] remove cruft Signed-off-by: Sage Moore --- csrc/attention/paged_attention_v1.cu | 3 --- csrc/attention/paged_attention_v2.cu | 3 --- 2 files changed, 6 deletions(-) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 96316319af99..9b3a5c4b1014 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -161,9 +161,6 @@ void paged_attention_v1_launcher( case 32: \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ - case 128: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 128, KV_DTYPE); \ - break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 13768b5377c0..9935359e02fb 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -168,9 +168,6 @@ void paged_attention_v2_launcher( case 32: \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ - case 128: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 128, KV_DTYPE); \ - break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ From 2bf214a55c461e66a7e53cf35ff6a184a1a71250 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 01:36:53 +0000 Subject: [PATCH 10/52] format Signed-off-by: Sage Moore --- vllm/v1/attention/backends/flash_attn.py | 3 +- vllm/v1/attention/backends/rocm_attn.py | 73 ++++++++++++------------ 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 7707e216017d..77b363679311 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -12,8 +12,9 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION from vllm.logger import init_logger -from vllm.utils import cdiv from vllm.platforms import current_platform +from vllm.utils import cdiv + if current_platform.is_cuda(): from vllm.vllm_flash_attn import flash_attn_varlen_func diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 42e48faf543d..223d73f547fa 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 """Attention layer with PagedAttention on rocm""" from typing import Any, Dict, List, Optional, Tuple, Type @@ -5,10 +6,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.logger import init_logger -from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.paged_attn import PagedAttention +from 
vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata + logger = init_logger(__name__) @@ -47,6 +49,7 @@ def get_kv_cache_shape( def use_cascade_attention(*args, **kwargs) -> bool: return False + class ROCmAttentionImpl(AttentionImpl): def __init__( @@ -97,7 +100,6 @@ def __init__( "are not implemented for " "ROCmAttentionImpl") - def forward( self, layer: torch.nn.Module, @@ -124,7 +126,7 @@ def forward( if attn_metadata is None: # Profiling run. return output - + assert attn_metadata.use_cascade is False # IMPORTANT! @@ -138,19 +140,19 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size) # Reshape the input keys and values and store them in the cache. PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) # TODO(sage): Refactor the context_attention_fwd kernel so that this # overhead can be removed @@ -158,27 +160,26 @@ def forward( batch_size = len(attn_metadata.query_start_loc) - 1 assert len(context_lens) == batch_size for i in range(batch_size): - query_start = attn_metadata.query_start_loc[i] + query_start = attn_metadata.query_start_loc[i] query_end = attn_metadata.query_start_loc[i + 1] - context_lens[i] = attn_metadata.seq_lens[i] - (query_end - query_start) + context_lens[i] = attn_metadata.seq_lens[i] - (query_end - + query_start) # Compute attention and update output up to `num_actual_tokens`. 
- context_attention_fwd( - q=query[:num_actual_tokens], - k=key[:num_actual_tokens], - v=value[:num_actual_tokens], - o=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=attn_metadata.block_table, - b_start_loc=attn_metadata.query_start_loc, - b_seq_len=attn_metadata.seq_lens, - b_ctx_len=context_lens, - max_input_len=attn_metadata.max_query_len, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0] - ) - return output \ No newline at end of file + context_attention_fwd(q=query[:num_actual_tokens], + k=key[:num_actual_tokens], + v=value[:num_actual_tokens], + o=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=attn_metadata.block_table, + b_start_loc=attn_metadata.query_start_loc, + b_seq_len=attn_metadata.seq_lens, + b_ctx_len=context_lens, + max_input_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0]) + return output From 11411cb08a8a7f24ab962ae0941214ce797b3895 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 01:58:37 +0000 Subject: [PATCH 11/52] modify requirements files Signed-off-by: Sage Moore --- requirements-build.txt | 2 -- requirements-rocm.txt | 6 ++++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/requirements-build.txt b/requirements-build.txt index 296eb8ca6bd2..34c75a6a6488 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -4,7 +4,5 @@ ninja packaging setuptools>=61 setuptools-scm>=8 -# torch==2.5.1 wheel jinja2 - diff --git a/requirements-rocm.txt b/requirements-rocm.txt index ccc906234177..41bdb897c9de 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -10,3 +10,9 @@ ray >= 2.10.0 peft pytest-asyncio tensorizer>=2.9.0 + +--extra-index-url https://download.pytorch.org/whl/rocm6.2 +torch==2.5.1 +torchvision==0.20.1 +torchaudio==2.5.1 +amdsmi==6.2.4 From c2499bfa41cff5ddd71db451b6c982cbca176964 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 02:02:30 +0000 Subject: [PATCH 12/52] remove basic.py changes Signed-off-by: Sage Moore --- examples/offline_inference/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index c01cd3e22c51..a6e96c0bb433 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -13,7 +13,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="meta-llama/meta-llama-3.1-8b-instruct") +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) From cf6f6917397811edec76e9a096455cfc1b279ef7 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 13:39:18 +0000 Subject: [PATCH 13/52] cleanup Signed-off-by: Sage Moore --- vllm/platforms/rocm.py | 4 ++++ vllm/v1/attention/backends/rocm_attn.py | 4 ---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9894084f2817..b624ceb6574a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -99,6 +99,10 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: + # NOTE: When using V1 this function is called when overriding the + # engine args. Calling torch.cuda.get_device_name(device_id) here + # will result in the ROCm context being initialized before other + # processes can be created. return "AMD" @classmethod diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 223d73f547fa..5899a5920922 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -80,10 +80,6 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads From 4505f53cce7de8469b418f5ed0eec5bd4f7aebb9 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 7 Feb 2025 14:04:11 +0000 Subject: [PATCH 14/52] add support for passing in softmax scales to the context_attn_fwd Signed-off-by: Sage Moore --- vllm/attention/ops/prefix_prefill.py | 6 ++++-- vllm/v1/attention/backends/rocm_attn.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 5fca1639363e..362c46a95f32 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -718,7 +718,8 @@ def context_attention_fwd(q, k_scale: torch.Tensor, v_scale: torch.Tensor, alibi_slopes=None, - sliding_window=None): + sliding_window=None, + sm_scale=None): q_dtype_is_f32 = q.dtype is torch.float32 # need to reduce num. 
blocks when using fp32 @@ -759,7 +760,8 @@ def context_attention_fwd(q, # round up Lk to a power of 2 - this is required for Triton block size Lk_padded = triton.next_power_of_2(Lk) - sm_scale = 1.0 / (Lq**0.5) + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] num_queries_per_kv = q.shape[1] // k.shape[1] diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5899a5920922..5f3eb37514d8 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -177,5 +177,6 @@ def forward( k_scale=layer._k_scale, v_scale=layer._v_scale, alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0]) + sliding_window=self.sliding_window[0], + sm_scale=self.scale) return output From ef9ae863f57980fd528c1608164f96172b87c866 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Feb 2025 14:48:49 +0000 Subject: [PATCH 15/52] added requirements-rocm-build Signed-off-by: Sage Moore --- requirements-rocm-build.txt | 9 +++++++++ requirements-rocm.txt | 6 ------ 2 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 requirements-rocm-build.txt diff --git a/requirements-rocm-build.txt b/requirements-rocm-build.txt new file mode 100644 index 000000000000..4903c3857fb0 --- /dev/null +++ b/requirements-rocm-build.txt @@ -0,0 +1,9 @@ +# Common dependencies +-r requirements-common.txt +-r requirements-build.txt + +--extra-index-url https://download.pytorch.org/whl/rocm6.2 +torch==2.5.1 +torchvision==0.20.1 +torchaudio==2.5.1 +amdsmi==6.2.4 diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 41bdb897c9de..ccc906234177 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -10,9 +10,3 @@ ray >= 2.10.0 peft pytest-asyncio tensorizer>=2.9.0 - ---extra-index-url https://download.pytorch.org/whl/rocm6.2 -torch==2.5.1 -torchvision==0.20.1 -torchaudio==2.5.1 -amdsmi==6.2.4 From a00a2d92aed38068bdc7232d25a03c5ba0aa712b Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Feb 2025 15:20:35 +0000 Subject: [PATCH 16/52] minor setup.py fix Signed-off-by: Sage Moore --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3e2adadf6704..10523b07dc57 100755 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def load_module_from_path(module_name, path): "Building on %s, " "so vLLM may not be able to run correctly", sys.platform) VLLM_TARGET_DEVICE = "empty" -elif (sys.platform.startswith("linux") and torch.version.cuda is None +elif (sys.platform.startswith("linux") and not torch.cuda.is_available() and os.getenv("VLLM_TARGET_DEVICE") is None): # if cuda is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu From afb15f554f3d0ff87e2be2a85e58e26d57b1dacf Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Feb 2025 18:57:00 +0000 Subject: [PATCH 17/52] add batch size back in Signed-off-by: Sage Moore --- requirements-build.txt | 1 + requirements-rocm-build.txt | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/requirements-build.txt b/requirements-build.txt index 34c75a6a6488..fec01caaf25e 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -4,5 +4,6 @@ ninja packaging setuptools>=61 setuptools-scm>=8 +torch==2.5.1 wheel jinja2 diff --git a/requirements-rocm-build.txt b/requirements-rocm-build.txt index 4903c3857fb0..00ae0340fc52 100644 --- a/requirements-rocm-build.txt +++ b/requirements-rocm-build.txt @@ -1,9 +1,16 @@ # Common dependencies -r 
requirements-common.txt --r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/rocm6.2 torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 + +cmake>=3.26 +ninja +packaging +setuptools>=61 +setuptools-scm>=8 +wheel +jinja2 amdsmi==6.2.4 From 08a25b79decabf635147388e42906de6455a00b3 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Feb 2025 19:22:03 +0000 Subject: [PATCH 18/52] revert setup.py change Signed-off-by: Sage Moore --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 10523b07dc57..3e2adadf6704 100755 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def load_module_from_path(module_name, path): "Building on %s, " "so vLLM may not be able to run correctly", sys.platform) VLLM_TARGET_DEVICE = "empty" -elif (sys.platform.startswith("linux") and not torch.cuda.is_available() +elif (sys.platform.startswith("linux") and torch.version.cuda is None and os.getenv("VLLM_TARGET_DEVICE") is None): # if cuda is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu From 55eb0366dd48bdfcca41d1b15d55d4165a690864 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Feb 2025 19:44:55 +0000 Subject: [PATCH 19/52] update setup.py Signed-off-by: Sage Moore --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 3e2adadf6704..96e04883a819 100755 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ def load_module_from_path(module_name, path): "so vLLM may not be able to run correctly", sys.platform) VLLM_TARGET_DEVICE = "empty" elif (sys.platform.startswith("linux") and torch.version.cuda is None + and torch.version.hip is None and os.getenv("VLLM_TARGET_DEVICE") is None): # if cuda is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu From 95df5712a6f6cc6b7b18b161f1a494a030f1279f Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Feb 2025 23:24:21 +0000 Subject: [PATCH 20/52] init Signed-off-by: Sage Moore --- tests/kernels/test_prefix_prefill.py | 4 ++-- vllm/attention/ops/prefix_prefill.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 2184c98525fe..78b08d20bd02 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -100,7 +100,7 @@ def test_contexted_kv_attention( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -333,7 +333,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 5fca1639363e..2e346eca1229 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -72,10 +72,12 @@ def _fwd_kernel( cur_kv_head = cur_head // num_queries_per_kv - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + 
cur_batch) - cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len # start position inside of the query # generally, N goes over kv, while M goes over query_len @@ -511,9 +513,12 @@ def _fwd_kernel_alibi( # cur_batch_seq_len: the length of prompts # cur_batch_ctx_len: the length of prefix # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len block_start_loc = BLOCK_M * start_m @@ -763,6 +768,7 @@ def context_attention_fwd(q, batch, head = b_seq_len.shape[0], q.shape[1] num_queries_per_kv = q.shape[1] // k.shape[1] + assert batch + 1 == len(b_start_loc) grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, # 0 means "disable" From 0bfe435c098a4a102b669e6dc0872eab7c54b929 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 11 Feb 2025 15:10:51 +0000 Subject: [PATCH 21/52] init Signed-off-by: Sage Moore --- tests/kernels/test_prefix_prefill.py | 4 ---- vllm/attention/backends/rocm_flash_attn.py | 1 - vllm/attention/backends/xformers.py | 1 - vllm/attention/ops/paged_attn.py | 2 -- vllm/attention/ops/prefix_prefill.py | 5 ----- 5 files changed, 13 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 78b08d20bd02..c3ac6a37e717 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -154,7 +154,6 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -171,7 +170,6 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -387,7 +385,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -404,7 +401,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02bff57a62b7..f279ce3ea987 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -755,7 +755,6 @@ def forward( prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0b3..ec8e1f2ee5a6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -580,7 +580,6 @@ def forward( prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2c60bd0c38d6..5093fd735034 100644 --- 
a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -202,7 +202,6 @@ def forward_prefix( block_tables: torch.Tensor, query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], @@ -222,7 +221,6 @@ def forward_prefix( # query_start_loc is (batch_size + 1,) query_start_loc[:-1], seq_lens_tensor, - context_lens, max_query_len, k_scale, v_scale, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 2e346eca1229..1e6aab3e2ffe 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -31,7 +31,6 @@ def _fwd_kernel( v_scale, B_Start_Loc, B_Seqlen, - B_Ctxlen, block_size, x, Out, @@ -468,7 +467,6 @@ def _fwd_kernel_alibi( v_scale, B_Start_Loc, B_Seqlen, - B_Ctxlen, Alibi_slopes, block_size, x, @@ -718,7 +716,6 @@ def context_attention_fwd(q, b_loc, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -788,7 +785,6 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, - b_ctx_len, alibi_slopes, v_cache.shape[3], k_cache.shape[4], @@ -842,7 +838,6 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, - b_ctx_len, v_cache.shape[3], k_cache.shape[4], o, From d2f3c8586e4074540bf9cfebc3fa439f6fe2cb02 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 11 Feb 2025 19:12:14 +0000 Subject: [PATCH 22/52] minor fix Signed-off-by: Sage Moore --- vllm/attention/ops/paged_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 5093fd735034..fd703413db90 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -219,7 +219,7 @@ def forward_prefix( value_cache, block_tables, # query_start_loc is (batch_size + 1,) - query_start_loc[:-1], + query_start_loc, seq_lens_tensor, max_query_len, k_scale, From 9472636ef9625bed936ec194c61a2d9fa1475b1f Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 12 Feb 2025 20:21:45 +0000 Subject: [PATCH 23/52] minor fix Signed-off-by: Sage Moore --- vllm/platforms/rocm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9b54f5bd694d..bbfc324be2df 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import os from functools import lru_cache from typing import TYPE_CHECKING, Dict, List, Optional @@ -78,7 +77,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, return "vllm.attention.backends.triton_mla.TritonMLABackend" selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) - if "VLLM_USE_V1" in os.environ: + if envs.VLLM_USE_V1: logger.info("Using ROCm Attention backend on V1 engine.") return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend" if selected_backend == _Backend.ROCM_FLASH: From 21d8d6aa110c208a27855978d72b99683ad024f8 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 12 Feb 2025 21:57:03 +0000 Subject: [PATCH 24/52] update error messages Signed-off-by: Sage Moore --- vllm/platforms/rocm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index bbfc324be2df..d57cce4231dc 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -130,13 +130,18 @@ def 
check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config.worker_cls == "auto": if scheduler_config.is_multi_step: if envs.VLLM_USE_V1: - raise NotImplementedError + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on VLLM V1. Please launch without " + "--num-scheduler-steps.") else: parallel_config.worker_cls = \ "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: if envs.VLLM_USE_V1: - raise NotImplementedError + raise NotImplementedError( + "Speculative decoding is not yet supported on VLLM V1." + ) else: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" From a1cac3dca17fb44a5b28b76d0d94dde4cfb7e8a5 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 13 Feb 2025 18:31:25 +0000 Subject: [PATCH 25/52] init Signed-off-by: Sage Moore --- csrc/ops.h | 8 ++++---- csrc/torch_bindings.cpp | 13 +++++++------ vllm/_custom_ops.py | 1 + 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 70e864cc6a87..60e58b38e5fb 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -178,6 +178,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a); + +void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, + torch::Tensor& output_scale, + torch::Tensor const& input_scale); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, @@ -195,10 +199,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); -void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_scale, - torch::Tensor const& input_scale); - void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 784ded26299e..d3a746dcafb1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -387,6 +387,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool silu_activation," "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); + + // Compute NVFP4 block quantized tensor. + ops.def( + "scaled_fp4_quant(Tensor! output, Tensor input," + " Tensor! output_scale, Tensor input_scale) -> ()"); + ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + #endif // Quantized GEMM for GPTQ. @@ -423,12 +430,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); - // Compute NVFP4 block quantized tensor. - ops.def( - "scaled_fp4_quant(Tensor! output, Tensor input," - " Tensor! output_scale, Tensor input_scale) -> ()"); - ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); - // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 67843c177403..3a3f3e00e602 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -787,6 +787,7 @@ def scaled_fp4_quant( two values are packed into a uint8 and float8_e4m3 scaling factors in the sizzled layout. 
""" + assert not current_platform.is_rocm() assert input.ndim >= 1, ( f'input.ndim needs to be >= 1, but got {input.ndim}.') other_dims = 1 if input.ndim == 1 else -1 From c02b1e6ced36092343abc4ee17efd4fde876911e Mon Sep 17 00:00:00 2001 From: root Date: Fri, 14 Feb 2025 01:12:48 +0000 Subject: [PATCH 26/52] new prefix_prefill Signed-off-by: root --- vllm/attention/ops/prefix_prefill.py | 141 +++++++++++------------- vllm/v1/attention/backends/rocm_attn.py | 51 ++++----- 2 files changed, 92 insertions(+), 100 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 103c408ebbf4..572724996dcb 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -17,7 +17,16 @@ IS_TURING = current_platform.get_device_capability() == (7, 5) if triton.__version__ >= "2.1.0": - + @triton.autotune( + configs=[ + triton.Config({ 'BLOCK_M': block_m, 'BLOCK_N': block_n, + "kpack": 2 }, num_stages=num_stages, \ + num_warps=num_warps) \ + for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ + for num_warps in [4, 8] for num_stages in [1, 2] + ], + key=["BLOCK_SIZE"], + ) @triton.jit def _fwd_kernel( Q, @@ -31,7 +40,6 @@ def _fwd_kernel( v_scale, B_Start_Loc, B_Seqlen, - block_size, x, Out, stride_b_loc_b, @@ -57,11 +65,12 @@ def _fwd_kernel( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, - num_queries_per_kv: int, + num_queries_per_kv: tl.constexpr, IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, ): @@ -83,6 +92,8 @@ def _fwd_kernel( block_start_loc = BLOCK_M * start_m # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) # [N]; starts at 0 offs_n = tl.arange(0, BLOCK_N) # [D]; starts at 0 @@ -104,51 +115,50 @@ def _fwd_kernel( other=0.0) # [M,D] # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) + for start_n in range(0, cur_batch_ctx_len, BLOCK_SIZE): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) # [N] - # [D,N] + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] off_k = (bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) - # [N,D] - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + # off_k = (bn[None, :] * stride_k_cache_bs + + # cur_kv_head * stride_k_cache_h + + # offs_d[:, 
None] * stride_k_cache_d + + # offs_bs_n[None, :]) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k) if k_load.dtype.is_fp8(): k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, + qk, float("-inf")) qk *= sm_scale if SLIDING_WINDOW > 0: # (cur_batch_ctx_len + offs_m[:, None]) are the positions of # Q entries in sequence - # (start_n + offs_n[None, :]) are the positions of + # (start_n + offs_bs_n[None, :]) are the positions of # KV entries in sequence # So the condition makes sure each entry in Q only attends # to KV entries not more than SLIDING_WINDOW away. @@ -158,31 +168,19 @@ def _fwd_kernel( # This then makes m_ij contain -inf, which causes NaNs in # exp(). qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] + v_load = tl.load(V_cache + off_v) if v_load.dtype.is_fp8(): v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: @@ -191,8 +189,8 @@ def _fwd_kernel( acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new + l_i = l_i * alpha + l_ij + m_i = m_ij off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) @@ -226,21 +224,12 @@ def _fwd_kernel( < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, @@ -250,9 +239,12 @@ def _fwd_kernel( p = p.to(v.dtype) acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i - l_i = l_i_new - m_i = m_i_new + l_i = l_i * alpha + 
l_ij + m_i = m_ij + + acc = acc / l_i[:, None] # initialize pointers to output off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + @@ -724,10 +716,6 @@ def context_attention_fwd(q, sm_scale=None): q_dtype_is_f32 = q.dtype is torch.float32 - # need to reduce num. blocks when using fp32 - # due to increased use of GPU shared memory - # if q.dtype is torch.float32: - BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK # Turing does have tensor core for float32 multiplication # use ieee as fallback for triton kernels work. There is also @@ -768,13 +756,18 @@ def context_attention_fwd(q, num_queries_per_kv = q.shape[1] // k.shape[1] assert batch + 1 == len(b_start_loc) - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, # 0 means "disable" if sliding_window is None or sliding_window <= 0: sliding_window = 0 if alibi_slopes is not None: + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) _fwd_kernel_alibi[grid]( q, k, @@ -828,6 +821,8 @@ def context_attention_fwd(q, ) return + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) _fwd_kernel[grid]( q, k, @@ -840,7 +835,6 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, - v_cache.shape[3], k_cache.shape[4], o, b_loc.stride(0), @@ -868,14 +862,11 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, - BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, SLIDING_WINDOW=sliding_window, - num_warps=NUM_WARPS, - num_stages=1, ) return diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5f3eb37514d8..01b81f45ff84 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -152,31 +152,32 @@ def forward( # TODO(sage): Refactor the context_attention_fwd kernel so that this # overhead can be removed - context_lens = torch.empty_like(attn_metadata.seq_lens) - batch_size = len(attn_metadata.query_start_loc) - 1 - assert len(context_lens) == batch_size - for i in range(batch_size): - query_start = attn_metadata.query_start_loc[i] - query_end = attn_metadata.query_start_loc[i + 1] - context_lens[i] = attn_metadata.seq_lens[i] - (query_end - - query_start) + # context_lens = torch.empty_like(attn_metadata.seq_lens) + # batch_size = len(attn_metadata.query_start_loc) - 1 + # assert len(context_lens) == batch_size + # for i in range(batch_size): + # query_start = attn_metadata.query_start_loc[i] + # query_end = attn_metadata.query_start_loc[i + 1] + # context_lens[i] = attn_metadata.seq_lens[i] - (query_end - + # query_start) # Compute attention and update output up to `num_actual_tokens`. 
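As background for the change above: the kernel now recovers the per-request context length from B_Start_Loc and B_Seqlen, which is why both the b_ctx_len argument and the host-side loop can go away. A vectorized host-side equivalent of that commented-out loop might look like the sketch below (illustrative only, not part of this patch; it assumes the same attn_metadata fields used in this file):

import torch

def derive_context_lens(query_start_loc: torch.Tensor,
                        seq_lens: torch.Tensor) -> torch.Tensor:
    # query_start_loc: (batch_size + 1,) cumulative query-token offsets
    # seq_lens:        (batch_size,)     total tokens per request
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Context (prefix) length = tokens already resident in the KV cache.
    return seq_lens - query_lens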
- context_attention_fwd(q=query[:num_actual_tokens], - k=key[:num_actual_tokens], - v=value[:num_actual_tokens], - o=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=attn_metadata.block_table, - b_start_loc=attn_metadata.query_start_loc, - b_seq_len=attn_metadata.seq_lens, - b_ctx_len=context_lens, - max_input_len=attn_metadata.max_query_len, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) + context_attention_fwd( + q=query[:num_actual_tokens], + k=key[:num_actual_tokens], + v=value[:num_actual_tokens], + o=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=attn_metadata.block_table, + b_start_loc=attn_metadata.query_start_loc, + b_seq_len=attn_metadata.seq_lens, + # b_ctx_len=context_lens, + max_input_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale) return output From 540b2869e18b26e1572cd6e87d19504fcf906bbe Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 20 Feb 2025 02:01:39 +0000 Subject: [PATCH 27/52] dwordx4 for k and v from cache Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 572724996dcb..f374d1a3b813 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,7 +16,7 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -if triton.__version__ >= "2.1.0": +if triton.__version__ >= "3.2.0": @triton.autotune( configs=[ triton.Config({ 'BLOCK_M': block_m, 'BLOCK_N': block_n, @@ -25,7 +25,8 @@ for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ for num_warps in [4, 8] for num_stages in [1, 2] ], - key=["BLOCK_SIZE"], + key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", + "BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"] ) @triton.jit def _fwd_kernel( @@ -40,7 +41,7 @@ def _fwd_kernel( v_scale, B_Start_Loc, B_Seqlen, - x, + x: tl.constexpr, Out, stride_b_loc_b, stride_b_loc_s, @@ -59,7 +60,7 @@ def _fwd_kernel( stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d, - stride_k_cache_bl, + stride_k_cache_bl: tl.constexpr, stride_k_cache_x, stride_v_cache_bs, stride_v_cache_h, From 1d9eb5044e26fba6960c043a14145acc296015b3 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 20 Feb 2025 02:13:57 +0000 Subject: [PATCH 28/52] follow up merge with main Signed-off-by: Aleksandr Malyshev --- vllm/platforms/rocm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 426c790ba669..a4f18cbfc587 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -141,7 +141,9 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - return torch.cuda.get_device_name(device_id) + physical_device_id = device_id_to_physical_device_id(device_id) + handle = amdsmi_get_processor_handles()[physical_device_id] + return amdsmi_get_gpu_asic_info(handle)["market_name"] @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: From 9eb55667d2b43e49ae2db52bb0aa41176dc8ae58 Mon Sep 17 
00:00:00 2001 From: Aleksandr Malyshev Date: Sat, 22 Feb 2025 02:12:54 +0000 Subject: [PATCH 29/52] different stages for different loops Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 111 +++++++++++++----------- vllm/v1/attention/backends/rocm_attn.py | 17 ++-- 2 files changed, 67 insertions(+), 61 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index f374d1a3b813..6d948f44be92 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -19,62 +19,66 @@ if triton.__version__ >= "3.2.0": @triton.autotune( configs=[ - triton.Config({ 'BLOCK_M': block_m, 'BLOCK_N': block_n, - "kpack": 2 }, num_stages=num_stages, \ - num_warps=num_warps) \ + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, "kpack": 2, \ + "num_stages_cache": num_stages_cache, \ + "num_stages_request": num_stages_request }, \ + num_warps=num_warps) \ for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ - for num_warps in [4, 8] for num_stages in [1, 2] + for num_warps in [4, 8] for num_stages_cache in [1, 2] \ + for num_stages_request in [1, 2] ], - key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", + key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", \ "BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"] ) @triton.jit def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - ): + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl. 
+ constexpr, # head size padded to a power of 2 + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_stages_cache: tl.constexpr, + num_stages_request: tl.constexpr): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -122,7 +126,8 @@ def _fwd_kernel( dtype=tl.float32) # [M,D] # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_SIZE): + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + num_stages=num_stages_cache): start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + @@ -204,7 +209,9 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) # compute query against itself (with causal mask) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + num_stages=num_stages_request): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5f3eb37514d8..409145f4c302 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -152,14 +152,14 @@ def forward( # TODO(sage): Refactor the context_attention_fwd kernel so that this # overhead can be removed - context_lens = torch.empty_like(attn_metadata.seq_lens) - batch_size = len(attn_metadata.query_start_loc) - 1 - assert len(context_lens) == batch_size - for i in range(batch_size): - query_start = attn_metadata.query_start_loc[i] - query_end = attn_metadata.query_start_loc[i + 1] - context_lens[i] = attn_metadata.seq_lens[i] - (query_end - - query_start) + # context_lens = torch.empty_like(attn_metadata.seq_lens) + # batch_size = len(attn_metadata.query_start_loc) - 1 + # assert len(context_lens) == batch_size + # for i in range(batch_size): + # query_start = attn_metadata.query_start_loc[i] + # query_end = attn_metadata.query_start_loc[i + 1] + # context_lens[i] = attn_metadata.seq_lens[i] - (query_end - + # query_start) # Compute attention and update output up to `num_actual_tokens`. 
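One note on the staged loops added to _fwd_kernel in this patch: on Triton versions new enough to support it (the patch gates on >= 3.2), tl.range accepts a per-loop num_stages, so the cached-context loop and the in-request loop can be pipelined to different depths. A minimal, self-contained sketch of that pattern follows (toy kernel, not from vLLM; names, shapes, and stage counts are invented for illustration):

import torch
import triton
import triton.language as tl

@triton.jit
def _two_stage_sums(x_ptr, out_ptr, N1, N2, BLOCK: tl.constexpr,
                    STAGES_CACHE: tl.constexpr, STAGES_REQ: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    # Deeper software pipelining for the first (long, cache-like) loop ...
    for i in tl.range(0, N1, BLOCK, num_stages=STAGES_CACHE):
        acc += tl.load(x_ptr + i + offs, mask=i + offs < N1, other=0.0)
    # ... and a different depth for the second (short, request-like) loop.
    for i in tl.range(0, N2, BLOCK, num_stages=STAGES_REQ):
        acc += tl.load(x_ptr + N1 + i + offs, mask=i + offs < N2, other=0.0)
    tl.store(out_ptr + offs, acc)

x = torch.randn(4096, device="cuda")
out = torch.empty(128, device="cuda")
_two_stage_sums[(1, )](x, out, 2048, 2048, BLOCK=128,
                       STAGES_CACHE=2, STAGES_REQ=1)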
context_attention_fwd(q=query[:num_actual_tokens], @@ -172,7 +172,6 @@ def forward( b_loc=attn_metadata.block_table, b_start_loc=attn_metadata.query_start_loc, b_seq_len=attn_metadata.seq_lens, - b_ctx_len=context_lens, max_input_len=attn_metadata.max_query_len, k_scale=layer._k_scale, v_scale=layer._v_scale, From 9bc9217625589cc1510578dfbd2e3bf1f603e957 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 25 Feb 2025 00:11:52 +0000 Subject: [PATCH 30/52] unroll factors tunning Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 6d948f44be92..3780f0e7e1dc 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -20,12 +20,12 @@ @triton.autotune( configs=[ triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, "kpack": 2, \ - "num_stages_cache": num_stages_cache, \ - "num_stages_request": num_stages_request }, \ + "num_unroll_cache": num_unroll_cache, \ + "num_unroll_request": num_unroll_request }, \ num_warps=num_warps) \ for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ - for num_warps in [4, 8] for num_stages_cache in [1, 2] \ - for num_stages_request in [1, 2] + for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ + for num_unroll_request in [1, 2] ], key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", \ "BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"] @@ -72,13 +72,14 @@ def _fwd_kernel( IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl. - constexpr, # head size padded to a power of 2 + # head size padded to a power of 2 + BLOCK_DMODEL_PADDED: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, - num_stages_cache: tl.constexpr, - num_stages_request: tl.constexpr): + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr): + cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -127,7 +128,7 @@ def _fwd_kernel( # compute query against context (no causal mask here) for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - num_stages=num_stages_cache): + loop_unroll_factor=num_unroll_cache): start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + @@ -211,7 +212,7 @@ def _fwd_kernel( # compute query against itself (with causal mask) for start_n in tl.range(0, \ block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - num_stages=num_stages_request): + loop_unroll_factor=num_unroll_request): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + From 2b8444857ed815c6b807ed75beb4bc76172358b2 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 25 Feb 2025 00:20:52 +0000 Subject: [PATCH 31/52] linter Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 3780f0e7e1dc..f7df1263738c 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -72,8 +72,8 @@ def _fwd_kernel( IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size - # head size padded to a power of 2 - BLOCK_DMODEL_PADDED: tl.constexpr, + # head size padded to a power of 2, + BLOCK_DMODEL_PADDED: tl.constexpr, 
BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, From 1067508cf1a76074a7e898b04442468a071293f4 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 25 Feb 2025 20:21:37 +0000 Subject: [PATCH 32/52] default prefix_prefill for triton lower than 3.2, NV case Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 506 +++++++++++++++++++++------ 1 file changed, 392 insertions(+), 114 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index f7df1263738c..25b6cff2a666 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,70 +16,55 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -if triton.__version__ >= "3.2.0": - @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, "kpack": 2, \ - "num_unroll_cache": num_unroll_cache, \ - "num_unroll_request": num_unroll_request }, \ - num_warps=num_warps) \ - for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ - for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ - for num_unroll_request in [1, 2] - ], - key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", \ - "BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"] - ) +if triton.__version__ >= "2.1.0": + @triton.jit def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - # head size padded to a power of 2, - BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr): - + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -98,8 +83,6 @@ def _fwd_kernel( block_start_loc = BLOCK_M * start_m # initialize offsets - # [BLOCK_SIZE]; starts at 0 - offs_bs_n = tl.arange(0, BLOCK_SIZE) # [N]; starts at 0 offs_n = tl.arange(0, BLOCK_N) # [D]; starts at 0 @@ -121,51 +104,51 @@ def _fwd_kernel( other=0.0) # [M,D] # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + m_i = tl.zeros([BLOCK_M], 
dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] # compute query against context (no causal mask here) - for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - loop_unroll_factor=num_unroll_cache): - start_n = tl.multiple_of(start_n, BLOCK_SIZE) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) - # [D,BLOCK_SIZE] + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] off_k = (bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) - # off_k = (bn[None, :] * stride_k_cache_bs + - # cur_kv_head * stride_k_cache_h + - # offs_d[:, None] * stride_k_cache_d + - # offs_bs_n[None, :]) - - # [BLOCK_SIZE,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - offs_bs_n[:, None] * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k) + # [N,D] + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] if k_load.dtype.is_fp8(): k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load - qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, - qk, float("-inf")) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) qk *= sm_scale if SLIDING_WINDOW > 0: # (cur_batch_ctx_len + offs_m[:, None]) are the positions of # Q entries in sequence - # (start_n + offs_bs_n[None, :]) are the positions of + # (start_n + offs_n[None, :]) are the positions of # KV entries in sequence # So the condition makes sure each entry in Q only attends # to KV entries not more than SLIDING_WINDOW away. @@ -175,19 +158,31 @@ def _fwd_kernel( # This then makes m_ij contain -inf, which causes NaNs in # exp(). 
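The NaN hazard described in the comment above is easy to reproduce on the host: once every score in a row is -inf, the running maximum is -inf as well, and -inf minus -inf is NaN, which exp() then propagates. A few lines of plain Python, for illustration only:

import math

m_ij = float("-inf")                   # row max when the whole row is masked with -inf
print(float("-inf") - m_ij)            # nan: -inf minus -inf is undefined
print(math.exp(float("-inf") - m_ij))  # nan propagates through exp()
m_ij = -10000.0                        # row max with the finite sentinel instead
print(math.exp(-10000.0 - m_ij))       # 1.0: everything stays finite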
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, - qk, -10000) + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) # -- compute m_ij, p, l_ij - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.exp2(qk - m_ij[:, None]) - - l_ij = tl.sum(p, 1) - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] # update acc - v_load = tl.load(V_cache + off_v) + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] if v_load.dtype.is_fp8(): v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: @@ -196,8 +191,8 @@ def _fwd_kernel( acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij + l_i = l_i_new + m_i = m_i_new off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) @@ -210,9 +205,7 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) # compute query against itself (with causal mask) - for start_n in tl.range(0, \ - block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - loop_unroll_factor=num_unroll_request): + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + @@ -233,12 +226,21 @@ def _fwd_kernel( < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.exp2(qk - m_ij[:, None]) - + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) l_ij = tl.sum(p, 1) - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, @@ -248,12 +250,9 @@ def _fwd_kernel( p = p.to(v.dtype) acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - acc = acc / l_i[:, None] + l_i = l_i_new + m_i = m_i_new # initialize pointers to output off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + @@ -265,6 +264,225 @@ def _fwd_kernel( (offs_m[:, None] < cur_batch_query_len)) return + if triton.__version__ >= "3.2.0": + @triton.autotune( + configs=[ + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, \ + "kpack": 2, \ + "num_unroll_cache": num_unroll_cache, \ + "num_unroll_request": num_unroll_request }, \ + num_warps=num_warps) \ + for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ + for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ + for num_unroll_request in [1, 2] + ], + key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", \ + "BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"] + ) + @triton.jit + 
def _fwd_kernel( + Q, K, V, K_cache, V_cache, B_Loc, sm_scale, k_scale, v_scale, + B_Start_Loc, B_Seqlen, x: tl.constexpr, Out, stride_b_loc_b, + stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, + stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, stride_k_cache_bs, + stride_k_cache_h, stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, stride_k_cache_x, + stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, + stride_v_cache_bl, num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # off_k = (bn[None, :] * stride_k_cache_bs + + # cur_kv_head * stride_k_cache_h + + # offs_d[:, None] * stride_k_cache_d + + # offs_bs_n[None, :]) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) + < 
cur_batch_ctx_len, qk, float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) + < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v_load = tl.load(V_cache + off_v) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), + qk, float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) + < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update acc + v = tl.load( + v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return + @triton.jit def _fwd_kernel_flash_attn_v2( Q, @@ -830,8 +1048,64 @@ def context_attention_fwd(q, ) return - grid = lambda META: (batch, head, - triton.cdiv(max_input_len, META["BLOCK_M"])) + if triton.__version__ >= "3.2.0": + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + 
k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + ) + return + + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) _fwd_kernel[grid]( q, k, @@ -844,6 +1118,7 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, + v_cache.shape[3], k_cache.shape[4], o, b_loc.stride(0), @@ -871,11 +1146,14 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] - BLOCK_SIZE=v_cache.shape[3], num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, SLIDING_WINDOW=sliding_window, + num_warps=NUM_WARPS, + num_stages=1, ) return From 506b0c49ea5afb62215b41012b22267b20f6ead6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Mar 2025 17:51:01 +0000 Subject: [PATCH 33/52] original softmax restored to get back accuracy Signed-off-by: root --- vllm/attention/ops/prefix_prefill.py | 57 +++++++++++++++---------- vllm/v1/attention/backends/rocm_attn.py | 7 ++- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 25b6cff2a666..c585645fcdc7 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -272,7 +272,7 @@ def _fwd_kernel( "num_unroll_cache": num_unroll_cache, \ "num_unroll_request": num_unroll_request }, \ num_warps=num_warps) \ - for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ + for block_m in [32, 64] for block_n in [32, 64, 128] \ for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ for num_unroll_request in [1, 2] ], @@ -356,10 +356,6 @@ def _fwd_kernel( ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) - # off_k = (bn[None, :] * stride_k_cache_bs + - # cur_kv_head * stride_k_cache_h + - # offs_d[:, None] * stride_k_cache_d + - # offs_bs_n[None, :]) # [BLOCK_SIZE,D] off_v = (bn[:, None] * stride_v_cache_bs + @@ -395,12 +391,22 @@ def _fwd_kernel( < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.exp2(qk - m_ij[:, None]) - - l_ij = tl.sum(p, 1) - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output 
accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] # update acc v_load = tl.load(V_cache + off_v) @@ -412,8 +418,8 @@ def _fwd_kernel( acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij + l_i = l_i_new + m_i = m_i_new off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) @@ -450,12 +456,21 @@ def _fwd_kernel( < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.exp2(qk - m_ij[:, None]) - + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) l_ij = tl.sum(p, 1) - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] # update acc v = tl.load( v_ptrs + @@ -466,12 +481,10 @@ def _fwd_kernel( p = p.to(v.dtype) acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij + l_i = l_i_new + m_i = m_i_new - acc = acc / l_i[:, None] # initialize pointers to output off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 0f3fabf05fc2..5c7d759b1812 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -9,7 +9,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.logger import init_logger -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata, FlashAttentionMetadataBuilder) logger = init_logger(__name__) @@ -49,6 +50,10 @@ def get_kv_cache_shape( def use_cascade_attention(*args, **kwargs) -> bool: return False + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + class ROCmAttentionImpl(AttentionImpl): From 05c3d3b10ffcb938c147940f3461a52535ded039 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Mon, 10 Mar 2025 23:49:30 +0000 Subject: [PATCH 34/52] adaptation to ibm kernel Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 516ae934f610..df61113aa193 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -299,7 +299,7 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_PADDED: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr): + num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -314,6 +314,9 @@ def _fwd_kernel( cur_batch_in_all_start_index) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + if SKIP_DECODE and cur_batch_query_len == 1: + return + # start position inside of the query # 
generally, N goes over kv, while M goes over query_len block_start_loc = BLOCK_M * start_m @@ -1121,6 +1124,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, ) return From e76f27f4e82ceaeddffcfa496b8da8cdfba8e114 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 11 Mar 2025 06:05:54 +0000 Subject: [PATCH 35/52] softmax computation correction Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 56 ++++++++++------------------ 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index df61113aa193..b01e105a8013 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -277,7 +277,7 @@ def _fwd_kernel( "num_unroll_cache": num_unroll_cache, \ "num_unroll_request": num_unroll_request }, \ num_warps=num_warps) \ - for block_m in [32, 64] for block_n in [32, 64, 128] \ + for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ for num_unroll_request in [1, 2] ], @@ -398,23 +398,12 @@ def _fwd_kernel( (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, -10000) - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] - - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] # update acc v_load = tl.load(V_cache + off_v) @@ -426,8 +415,8 @@ def _fwd_kernel( acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new + l_i = l_i * alpha + l_ij + m_i = m_ij off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) @@ -463,22 +452,13 @@ def _fwd_kernel( offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + # update acc v = tl.load( v_ptrs + @@ -490,8 +470,10 @@ def _fwd_kernel( acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # update m_i and l_i - l_i = l_i_new - m_i = m_i_new + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] # initialize pointers to output off_o = ( From da80a03d8150ee89ad9ac176d4d15eb102136134 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 13 Mar 2025 22:25:43 +0000 Subject: [PATCH 36/52] a comment for triton version Signed-off-by: Aleksandr Malyshev --- 
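For readers comparing the running-maximum softmax restored in the previous patch ("softmax computation correction") with the earlier rescaled form, both are instances of the same online softmax; the running-max/running-sum recurrence with a single final normalization can be checked against the direct computation in a few lines of NumPy (illustrative sketch, independent of the Triton kernel; it omits sm_scale, masking, and the causal part):

import numpy as np

def online_softmax_attention(q, k, v, block=64):
    """Single-query attention with blockwise (online) softmax accumulation."""
    m = -np.inf                    # running row maximum (m_i)
    l = 0.0                        # running sum of exponentials (l_i)
    acc = np.zeros(v.shape[1])     # unnormalized output accumulator
    for s in range(0, k.shape[0], block):
        qk = k[s:s + block] @ q    # scores for this K/V block
        m_new = max(m, qk.max())   # m_ij
        p = np.exp(qk - m_new)
        alpha = np.exp(m - m_new)  # rescales the state from earlier blocks
        acc = acc * alpha + p @ v[s:s + block]
        l = l * alpha + p.sum()
        m = m_new
    return acc / l                 # single normalization at the end

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((256, 64))
v = rng.standard_normal((256, 64))
scores = k @ q
p_ref = np.exp(scores - scores.max())
assert np.allclose(online_softmax_attention(q, k, v), (p_ref / p_ref.sum()) @ v)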
vllm/attention/ops/prefix_prefill.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index b01e105a8013..f6fee13085ee 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -269,6 +269,10 @@ def _fwd_kernel( (offs_m[:, None] < cur_batch_query_len)) return + # On triton versions lower 3.2 the assertion: + # Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && + # "mma -> mma layout conversion is only supported on Ampere"' failed. + # is observed if triton.__version__ >= "3.2.0": @triton.autotune( configs=[ From 81277c8cb64c2c44c223b12089e2ffaa7f8051a1 Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Wed, 19 Mar 2025 15:43:17 +0000 Subject: [PATCH 37/52] kpack is not supported on NVidia triton Signed-off-by: maleksan85 --- vllm/attention/ops/prefix_prefill.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index f6fee13085ee..a35c5f40deec 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -277,9 +277,10 @@ def _fwd_kernel( @triton.autotune( configs=[ triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, \ - "kpack": 2, \ "num_unroll_cache": num_unroll_cache, \ - "num_unroll_request": num_unroll_request }, \ + "num_unroll_request": num_unroll_request } + + {"kpack": 2} + if current_platform.is_rocm() else {}, \ num_warps=num_warps) \ for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ From a4000dfed0c551e69c5a18246ba65c5fddcc7a89 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Wed, 19 Mar 2025 17:34:04 +0000 Subject: [PATCH 38/52] kpack is not supported on NVidia triton Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index a35c5f40deec..cb1820fd6446 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -278,9 +278,9 @@ def _fwd_kernel( configs=[ triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, \ "num_unroll_cache": num_unroll_cache, \ - "num_unroll_request": num_unroll_request } - + {"kpack": 2} - if current_platform.is_rocm() else {}, \ + "num_unroll_request": num_unroll_request } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ num_warps=num_warps) \ for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ From a027e5cee22f8dffdf3903ce969debba6c4cc2f8 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 27 Mar 2025 17:48:56 +0000 Subject: [PATCH 39/52] reduced space of autotuning Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 208 +++++++++++++++++++++++---- 1 file changed, 181 insertions(+), 27 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index cb1820fd6446..85a875725990 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -274,37 +274,174 @@ def _fwd_kernel( # "mma -> mma layout conversion is only supported on Ampere"' failed. 
# is observed if triton.__version__ >= "3.2.0": + + def my_bench(kernel, quantiles): + timings = triton.testing.do_bench(kernel, quantiles=quantiles, \ + warmup=100, rep=10000) + # print(f"{timings=}") + return timings + @triton.autotune( configs=[ - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, \ - "num_unroll_cache": num_unroll_cache, \ - "num_unroll_request": num_unroll_request } | \ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 2 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 16, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 2 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=8, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=2), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, \ + "num_unroll_cache": 2, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, \ + "num_unroll_cache": 2, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 2, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 1 } | \ ({"kpack": 2} \ if current_platform.is_rocm() else {}), \ - num_warps=num_warps) \ - for block_m in [32, 64, 128] for block_n in [32, 64, 128] \ - for num_warps in [4, 8] for num_unroll_cache in [1, 2] \ - for num_unroll_request in [1, 2] + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 16, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 2 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 2, \ + "num_unroll_request": 2 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 2 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, \ + "num_unroll_cache": 1, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, \ + "num_unroll_cache": 2, \ + "num_unroll_request": 4 } | \ + ({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=8, \ + num_stages=1), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 4, \ + "num_unroll_request": 1 } | \ + 
({"kpack": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=8, \ + num_stages=1), ], - key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", \ - "BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"] + do_bench=lambda kernel, quantiles: my_bench(kernel, quantiles), + key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] ) @triton.jit - def _fwd_kernel( - Q, K, V, K_cache, V_cache, B_Loc, sm_scale, k_scale, v_scale, - B_Start_Loc, B_Seqlen, x: tl.constexpr, Out, stride_b_loc_b, - stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, - stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, stride_k_cache_bs, - stride_k_cache_h, stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, stride_k_cache_x, - stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, - stride_v_cache_bl, num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr): + def _my_fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -375,7 +512,15 @@ def _fwd_kernel( cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + offs_bs_n[:, None] * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) if k_load.dtype.is_fp8(): k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) @@ -411,7 +556,15 @@ def _fwd_kernel( acc = acc * alpha[:, None] # update acc - v_load = tl.load(V_cache + off_v) + if start_n + BLOCK_SIZE > cur_batch_ctx_len: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + if v_load.dtype.is_fp8(): v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: @@ -1065,7 +1218,7 @@ def context_attention_fwd(q, if triton.__version__ >= "3.2.0": grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) - _fwd_kernel[grid]( + _my_fwd_kernel[grid]( q, k, v, @@ -1112,6 +1265,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, + MAX_Q_LEN=triton.next_power_of_2(max_input_len), ) return From 81c2739ddbea4179fd1532697b8d0beb0879c1eb 
Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 8 Apr 2025 01:57:27 +0000 Subject: [PATCH 40/52] giving up on autotune and selecting one config Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 109 ++------------------------- 1 file changed, 6 insertions(+), 103 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 6c3f23b347da..9ecc73ffa9f1 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -277,117 +277,20 @@ def _fwd_kernel( def my_bench(kernel, quantiles): timings = triton.testing.do_bench(kernel, quantiles=quantiles, \ - warmup=100, rep=10000) - # print(f"{timings=}") + warmup=100, rep=1000) + print(f"{timings=}") return timings @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 2 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 16, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 2 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=8, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=2), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, \ - "num_unroll_cache": 2, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, \ - "num_unroll_cache": 2, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=2), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=2), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 2, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=2), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 16, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 2 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 2, \ - "num_unroll_request": 2 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 2 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, \ - "num_unroll_cache": 1, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, \ - "num_unroll_cache": 2, \ - 
"num_unroll_request": 4 } | \ - ({"kpack": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=8, \ - num_stages=1), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "waves_per_eu": 2, \ "num_unroll_cache": 4, \ "num_unroll_request": 1 } | \ ({"kpack": 2} \ if current_platform.is_rocm() else {}), \ - num_warps=8, \ - num_stages=1), + num_warps=4, \ + num_stages=1) ], do_bench=lambda kernel, quantiles: my_bench(kernel, quantiles), key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] @@ -1267,7 +1170,7 @@ def context_attention_fwd(q, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, MAX_Q_LEN=triton.next_power_of_2(max_input_len), - ) + MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) return # need to reduce num. blocks when using fp32 From 5a17950690aabb96ba7984002b55f1e4fcb79b7e Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 8 Apr 2025 18:07:15 +0000 Subject: [PATCH 41/52] fixing test with only to ROCM waves per eu and max_seq_len None Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index f117ee25c153..0fd3dcb9b787 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -284,10 +284,9 @@ def my_bench(kernel, quantiles): @triton.autotune( configs=[ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "waves_per_eu": 2, \ "num_unroll_cache": 4, \ "num_unroll_request": 1 } | \ - ({"kpack": 2} \ + ({"kpack": 2, "waves_per_eu": 2} \ if current_platform.is_rocm() else {}), \ num_warps=4, \ num_stages=1) @@ -1120,6 +1119,7 @@ def context_attention_fwd(q, return if triton.__version__ >= "3.2.0": + max_seq_len = 0 if max_seq_len is None else max_seq_len grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) _my_fwd_kernel[grid]( From 5d9a929eed9ab9c8ae955dbfa31d617e5e916430 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 03:54:32 +0000 Subject: [PATCH 42/52] renaming kernel Signed-off-by: root Signed-off-by: <> --- vllm/attention/ops/prefix_prefill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 0fd3dcb9b787..95a3e4d41d34 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -295,7 +295,7 @@ def my_bench(kernel, quantiles): key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] ) @triton.jit - def _my_fwd_kernel(Q, + def _fwd_kernel_v2(Q, K, V, K_cache, @@ -1122,7 +1122,7 @@ def context_attention_fwd(q, max_seq_len = 0 if max_seq_len is None else max_seq_len grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) - _my_fwd_kernel[grid]( + _fwd_kernel_v2[grid]( q, k, v, From 27f044b1485d5fd4bf1cbc90776e791a96d7bc18 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 10 Apr 2025 22:43:35 +0000 Subject: [PATCH 43/52] clean up and fix for failed kernel tests Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 1938 +++++++++++--------------- 1 file changed, 797 insertions(+), 1141 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 95a3e4d41d34..eb113c98a606 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,1170 +16,774 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -if triton.__version__ >= 
"2.1.0": - - @triton.jit - def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], - dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) # [N] - # [D,N] - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - # [N,D] - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + 
offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] - - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) - < SLIDING_WINDOW, qk, -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index 
+ offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len)) +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 4, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2, "waves_per_eu": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1) + ], + key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +) +@triton.jit +def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - # On triton versions lower 3.2 the assertion: - # Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && - # "mma -> mma layout conversion is only supported on Ampere"' failed. 
- # is observed - if triton.__version__ >= "3.2.0": - - def my_bench(kernel, quantiles): - timings = triton.testing.do_bench(kernel, quantiles=quantiles, \ - warmup=100, rep=1000) - print(f"{timings=}") - return timings - - @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 4, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2, "waves_per_eu": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1) - ], - do_bench=lambda kernel, quantiles: my_bench(kernel, quantiles), - key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] - ) - @triton.jit - def _fwd_kernel_v2(Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr, - SKIP_DECODE: tl.constexpr, - MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): - - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [BLOCK_SIZE]; starts at 0 - offs_bs_n = tl.arange(0, BLOCK_SIZE) - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], - dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - loop_unroll_factor=num_unroll_cache): - start_n = tl.multiple_of(start_n, BLOCK_SIZE) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) - # [D,BLOCK_SIZE] - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * 
stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - - # [BLOCK_SIZE,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - offs_bs_n[:, None] * stride_v_cache_bl) - - if start_n + BLOCK_SIZE > cur_batch_ctx_len: - k_load = tl.load( - K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - else: - k_load = tl.load(K_cache + off_k) - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_bs_n[None, :]) - < cur_batch_ctx_len, qk, float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_bs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_bs_n[None, :]) - < SLIDING_WINDOW, qk, -10000) - - # compute running maximum - m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.exp(m_i - m_ij) - acc = acc * alpha[:, None] - - # update acc - if start_n + BLOCK_SIZE > cur_batch_ctx_len: - v_load = tl.load( - V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - else: - v_load = tl.load(V_cache + off_v) - - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in tl.range(0, \ - block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - loop_unroll_factor=num_unroll_request): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) 
< BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
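+            # Concretely: if the window masked an entire row with -inf,
+            # m_ij would also be -inf and exp(qk - m_ij) would compute
+            # (-inf) - (-inf) == NaN. With a large finite value such as
+            # -10000 the arithmetic stays finite, and the masked entries'
+            # relative weight exp(-10000 - m_ij) underflows to zero as
+            # soon as any in-window key raises the running maximum.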
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_query_len), other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), - qk, float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) - < SLIDING_WINDOW, qk, -10000) - - # compute running maximum - m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.exp(m_i - m_ij) - acc = acc * alpha[:, None] - - # update acc - v = tl.load( - v_ptrs + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_query_len), other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len)) - return - - @triton.jit - def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - 
B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] - < cur_batch_seq_len - cur_batch_ctx_len, + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + 
((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = 
tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) - @triton.jit - def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - 
cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # 
scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + 
K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - @torch.inference_mode() - def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False): - - q_dtype_is_f32 = q.dtype is torch.float32 - - # Turing does have tensor core for float32 multiplication - # use ieee as fallback for triton kernels work. There is also - # warning on vllm/config.py to inform users this fallback - # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert (k_cache.dtype == torch.uint8) - assert (v_cache.dtype == torch.uint8) - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = current_platform.fp8_dtype() - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) - - k_cache = k_cache.view(target_dtype) - v_cache = v_cache.view(target_dtype) - - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - if sm_scale is None: - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - assert batch + 1 == len(b_start_loc) - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - if alibi_slopes is not None: - # need to reduce num. 
blocks when using fp32 - # due to increased use of GPU shared memory - # if q.dtype is torch.float32: - BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK - # batch, head, - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - _fwd_kernel_alibi[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - k_scale, - v_scale, - b_start_loc, - b_seq_len, - alibi_slopes, - v_cache.shape[3], - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - IN_PRECISION=IN_PRECISION, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, - SKIP_DECODE=skip_decode, - num_warps=NUM_WARPS, - num_stages=1, - ) - return - - if triton.__version__ >= "3.2.0": - max_seq_len = 0 if max_seq_len is None else max_seq_len - grid = lambda META: (batch, head, - triton.cdiv(max_input_len, META["BLOCK_M"])) - _fwd_kernel_v2[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - k_scale, - v_scale, - b_start_loc, - b_seq_len, - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - BLOCK_SIZE=v_cache.shape[3], - num_queries_per_kv=num_queries_per_kv, - IN_PRECISION=IN_PRECISION, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - SLIDING_WINDOW=sliding_window, - SKIP_DECODE=skip_decode, - MAX_Q_LEN=triton.next_power_of_2(max_input_len), - MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) - return - + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < 
cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = 
tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: # need to reduce num. 
blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK # batch, head, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - _fwd_kernel[grid]( + _fwd_kernel_alibi[grid]( q, k, v, @@ -1191,6 +795,7 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, + alibi_slopes, v_cache.shape[3], k_cache.shape[4], o, @@ -1225,9 +830,60 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + MAX_Q_LEN=triton.next_power_of_2(max_input_len), + MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) + return From cfd60c9e5a0f4167ba5f29176d52de7c560d6b89 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 10 Apr 2025 23:12:57 +0000 Subject: [PATCH 44/52] clean up and fix for failed kernel tests Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/prefix_prefill.py | 1652 +++++++++++++------------- 1 file changed, 835 insertions(+), 817 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index eb113c98a606..d223d097b7e6 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,774 +16,847 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -@triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 4, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2, "waves_per_eu": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1) - ], - key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] -) -@triton.jit -def _fwd_kernel(Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: 
tl.constexpr, - num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr, - SKIP_DECODE: tl.constexpr, - MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): - - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [BLOCK_SIZE]; starts at 0 - offs_bs_n = tl.arange(0, BLOCK_SIZE) - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - loop_unroll_factor=num_unroll_cache): - start_n = tl.multiple_of(start_n, BLOCK_SIZE) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) - # [D,BLOCK_SIZE] - off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - - # [BLOCK_SIZE,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - offs_bs_n[:, None] * stride_v_cache_bl) - - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: - k_load = tl.load( - K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - else: - k_load = tl.load(K_cache + off_k) - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_bs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. 
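
The m_i / l_i / acc bookkeeping in this kernel is the standard streaming-softmax (flash-attention style) recurrence: keep a running row maximum and row sum, and rescale the accumulator by exp(m_old - m_new) whenever a new tile raises the maximum, so that acc / l_i at the end equals softmax(qk) @ v. A plain PyTorch sketch of the same recurrence follows; it is illustrative only (the function name and block size are invented) and is not part of the patch.

import torch

def streamed_attention(q, k, v, block=64):
    # q: [M, D], k/v: [N, D]; equivalent to softmax(q @ k.T) @ v, one tile at a time
    m = torch.full((q.shape[0],), float("-inf"))   # running row max
    l = torch.zeros(q.shape[0])                    # running row sum of exp()
    acc = torch.zeros(q.shape[0], v.shape[1])      # unnormalized output
    for s in range(0, k.shape[0], block):
        qk = q @ k[s:s + block].T                  # scores for this tile
        m_new = torch.maximum(m, qk.max(dim=1).values)
        p = torch.exp(qk - m_new[:, None])
        alpha = torch.exp(m - m_new)               # rescale factor for the old state
        l = l * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v[s:s + block]
        m = m_new
    return acc / l[:, None]

# sanity check against the one-shot formula:
# q, k, v = torch.randn(5, 8), torch.randn(20, 8), torch.randn(20, 8)
# torch.allclose(streamed_attention(q, k, v, block=7),
#                torch.softmax(q @ k.T, dim=-1) @ v, atol=1e-5)
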
- # This then makes m_ij contain -inf, which causes NaNs in - # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) - - # compute running maximum - m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.exp(m_i - m_ij) - acc = acc * alpha[:, None] - - # update acc - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: - v_load = tl.load( - V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - else: - v_load = tl.load(V_cache + off_v) - - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in tl.range(0, \ - block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - loop_unroll_factor=num_unroll_request): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, +if triton.__version__ >= "2.1.0": + @triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 4, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2, "waves_per_eu": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1) + ], + key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] + ) + @triton.jit + def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + # start position inside of the query + # generally, N 
goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, - qk, -10000) - - # compute running maximum - m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.exp(m_i - m_ij) - acc = acc * alpha[:, None] - - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, + qk, float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. 
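
The comment above is the crux of the sliding-window masking: once the window can exclude every key in a tile, an all -inf row drives the running maximum to -inf and exp(qk - m) becomes exp(-inf - (-inf)) = NaN, so the kernel masks with a large finite negative value (-10000) instead. A tiny PyTorch illustration of that failure mode (not part of the patch):

import torch

qk = torch.full((1, 4), float("-inf"))       # every key in the row masked out
m = qk.max(dim=-1, keepdim=True).values      # -inf
print(torch.exp(qk - m))                     # tensor([[nan, nan, nan, nan]])

qk = torch.full((1, 4), -10000.0)            # finite sentinel instead of -inf
m = qk.max(dim=-1, keepdim=True).values
print(torch.exp(qk - m))                     # tensor([[1., 1., 1., 1.]]), no NaN
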
+ # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) - return - - -@triton.jit -def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], 
dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) + < SLIDING_WINDOW, qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * 
BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- 
compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id 
of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), other=0.0) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - -@triton.jit -def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SKIP_DECODE: tl.constexpr, -): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, 
:] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - 
m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) return - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = 
alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) - return - - -@torch.inference_mode() -def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False): - - q_dtype_is_f32 = q.dtype is torch.float32 - - # Turing does have tensor core 
for float32 multiplication - # use ieee as fallback for triton kernels work. There is also - # warning on vllm/config.py to inform users this fallback - # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert (k_cache.dtype == torch.uint8) - assert (v_cache.dtype == torch.uint8) - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = current_platform.fp8_dtype() - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) - - k_cache = k_cache.view(target_dtype) - v_cache = v_cache.view(target_dtype) - - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - if sm_scale is None: - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - assert batch + 1 == len(b_start_loc) - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - if alibi_slopes is not None: - # need to reduce num. blocks when using fp32 - # due to increased use of GPU shared memory - # if q.dtype is torch.float32: - BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK - # batch, head, - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - _fwd_kernel_alibi[grid]( + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. 
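
Just below, the wrapper handles fp8 KV caches by reinterpreting the uint8 storage as an fp8 dtype (a zero-copy view, not a conversion) so the Triton kernel sees real fp8 elements and can dequantize them with the per-tensor k_scale / v_scale. A rough standalone sketch of that idea follows; torch.float8_e4m3fn availability depends on the PyTorch build, so treat the dtype choice here as an assumption rather than the exact code path.

import torch

storage = torch.randint(0, 256, (16,), dtype=torch.uint8)    # raw cache bytes
fp8_view = storage.view(torch.float8_e4m3fn)                 # reinterpret bits, no copy
k_scale = torch.tensor(0.5)
dequant = fp8_view.to(torch.float32) * k_scale               # what the kernel does per tile
print(dequant.dtype, dequant.shape)                          # torch.float32 torch.Size([16])
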
There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SKIP_DECODE=skip_decode, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( q, k, v, @@ -795,8 +868,6 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, - alibi_slopes, - v_cache.shape[3], k_cache.shape[4], o, b_loc.stride(0), @@ -824,66 +895,13 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, - BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, - num_warps=NUM_WARPS, - num_stages=1, - ) + MAX_Q_LEN=triton.next_power_of_2(max_input_len), + MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) return - - max_seq_len = 0 if max_seq_len is None else 
max_seq_len - grid = lambda META: (batch, head, - triton.cdiv(max_input_len, META["BLOCK_M"])) - _fwd_kernel[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - k_scale, - v_scale, - b_start_loc, - b_seq_len, - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] - BLOCK_SIZE=v_cache.shape[3], - num_queries_per_kv=num_queries_per_kv, - IN_PRECISION=IN_PRECISION, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - SLIDING_WINDOW=sliding_window, - SKIP_DECODE=skip_decode, - MAX_Q_LEN=triton.next_power_of_2(max_input_len), - MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) - return From 0a26697caaa6f29105f0faf74a38990404febee0 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 10 Apr 2025 23:40:12 +0000 Subject: [PATCH 45/52] clean up and fix for failed kernel tests Signed-off-by: Aleksandr Malyshev --- vllm/attention/ops/paged_attn.py | 44 +- vllm/attention/ops/prefix_prefill.py | 1652 +++++++++++++------------- 2 files changed, 839 insertions(+), 857 deletions(-) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 827c3041a682..dc91d4762b5b 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -8,9 +8,6 @@ from vllm import _custom_ops as ops from vllm.triton_utils import HAS_TRITON -if HAS_TRITON: - from vllm.attention.ops.prefix_prefill import context_attention_fwd - # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
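
The paged_attn.py hunk that follows drops the module-level Triton import and performs it lazily inside forward_prefix, gated on HAS_TRITON, so importing the module no longer requires a working Triton install. A minimal sketch of the same pattern (the wrapper name and fallback behaviour here are placeholders, not the actual vLLM code):

from vllm.triton_utils import HAS_TRITON

def run_prefix_attention(*args, **kwargs):
    if HAS_TRITON:
        # imported at call time so Triton-less builds can still import this module
        from vllm.attention.ops.prefix_prefill import context_attention_fwd
        return context_attention_fwd(*args, **kwargs)
    # the real forward_prefix simply skips the call and returns its
    # preallocated output tensor when Triton is unavailable
    return None
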
_PARTITION_SIZE = 512 @@ -210,25 +207,28 @@ def forward_prefix( ) -> torch.Tensor: output = torch.empty_like(query) max_seq_len = None - context_attention_fwd( - query, - key, - value, - output, - kv_cache_dtype, - key_cache, - value_cache, - block_tables, - # query_start_loc is (batch_size + 1,) - query_start_loc, - seq_lens_tensor, - max_seq_len, - max_query_len, - k_scale, - v_scale, - alibi_slopes, - sliding_window, - ) + + if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc, + seq_lens_tensor, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) return output @staticmethod diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index d223d097b7e6..eb113c98a606 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,847 +16,774 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -if triton.__version__ >= "2.1.0": - @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 4, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2, "waves_per_eu": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1) - ], - key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] - ) - @triton.jit - def _fwd_kernel(Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr, - SKIP_DECODE: tl.constexpr, - MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): - - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [BLOCK_SIZE]; starts at 0 - offs_bs_n = tl.arange(0, BLOCK_SIZE) - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - 
cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], - dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - loop_unroll_factor=num_unroll_cache): - start_n = tl.multiple_of(start_n, BLOCK_SIZE) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) - # [D,BLOCK_SIZE] - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - - # [BLOCK_SIZE,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - offs_bs_n[:, None] * stride_v_cache_bl) - - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: - k_load = tl.load( - K_cache + off_k, +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ + "num_unroll_cache": 4, \ + "num_unroll_request": 1 } | \ + ({"kpack": 2, "waves_per_eu": 2} \ + if current_platform.is_rocm() else {}), \ + num_warps=4, \ + num_stages=1) + ], + key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +) +@triton.jit +def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, 
BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
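+            # For example, if every key in a row fell outside the window and
+            # were masked with float("-inf"), m_ij would itself become -inf
+            # and exp(qk - m_ij) would evaluate exp(-inf - (-inf)) = NaN;
+            # with a large finite penalty such as -10000 the masked entries
+            # simply underflow to ~0 in exp() while m_ij stays finite.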
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, mask=dim_mask[:, None] & - ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - else: - k_load = tl.load(K_cache + off_k) - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, - qk, float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_bs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). 
- qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, - qk, -10000) - - # compute running maximum - m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.exp(m_i - m_ij) - acc = acc * alpha[:, None] - - # update acc - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: - v_load = tl.load( - V_cache + off_v, + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=dim_mask[None, :] & - ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - else: - v_load = tl.load(V_cache + off_v) - - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in tl.range(0, \ - block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - loop_unroll_factor=num_unroll_request): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) - < SLIDING_WINDOW, qk, -10000) - - # compute running maximum - m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, axis=1) - alpha = tl.exp(m_i - m_ij) - acc = acc * alpha[:, None] - - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - 
(offs_m[:, None] < cur_batch_query_len)) - return + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = 
acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) - @triton.jit - def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - 
l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - @triton.jit - def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + 
cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- 
update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + 
acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - @torch.inference_mode() - def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False): - - q_dtype_is_f32 = q.dtype is torch.float32 - - # Turing does have tensor core for float32 multiplication - # use ieee as fallback for triton kernels work. There is also - # warning on vllm/config.py to inform users this fallback - # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert (k_cache.dtype == torch.uint8) - assert (v_cache.dtype == torch.uint8) - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = current_platform.fp8_dtype() - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) - - k_cache = k_cache.view(target_dtype) - v_cache = v_cache.view(target_dtype) - - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - if sm_scale is None: - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - assert batch + 1 == len(b_start_loc) - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - if alibi_slopes is not None: - # need to reduce num. 
blocks when using fp32 - # due to increased use of GPU shared memory - # if q.dtype is torch.float32: - BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK - # batch, head, - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - _fwd_kernel_alibi[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - k_scale, - v_scale, - b_start_loc, - b_seq_len, - alibi_slopes, - v_cache.shape[3], - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - IN_PRECISION=IN_PRECISION, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, - SKIP_DECODE=skip_decode, - num_warps=NUM_WARPS, - num_stages=1, - ) - return - - max_seq_len = 0 if max_seq_len is None else max_seq_len - grid = lambda META: (batch, head, - triton.cdiv(max_input_len, META["BLOCK_M"])) - _fwd_kernel[grid]( + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope 
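+        # For the positive slopes used by ALiBi, this bias is
+        # slope * (key_position - query_position): <= 0 for keys at or before
+        # the query, and > 0 (masked to -inf below) for keys past it.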
+ alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + 
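+# A rough reference for the streaming-softmax update that the kernels above
+# implement per K/V block (NumPy-style pseudocode; qk is the [M, N] score
+# tile for the block and v_blk the matching [N, D] value tile):
+#
+#     m_new = maximum(m_i, qk.max(axis=1))        # running row-wise max
+#     p     = exp(qk - m_new[:, None])            # probabilities for this block
+#     alpha = exp(m_i - m_new)                    # rescales the old state
+#     acc   = acc * alpha[:, None] + p @ v_blk    # rescaled accumulator
+#     l_i   = l_i * alpha + p.sum(axis=1)         # running softmax denominator
+#     m_i   = m_new
+#
+# The final attention output for a row is acc / l_i[:, None].
+
+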
+@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + # need to reduce num. 
blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + _fwd_kernel_alibi[grid]( q, k, v, @@ -868,6 +795,8 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, + alibi_slopes, + v_cache.shape[3], k_cache.shape[4], o, b_loc.stride(0), @@ -895,13 +824,66 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] - BLOCK_SIZE=v_cache.shape[3], num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, - SLIDING_WINDOW=sliding_window, + BLOCK_N=BLOCK, SKIP_DECODE=skip_decode, - MAX_Q_LEN=triton.next_power_of_2(max_input_len), - MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) + num_warps=NUM_WARPS, + num_stages=1, + ) return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + MAX_Q_LEN=triton.next_power_of_2(max_input_len), + MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) + return From 35a6e4902aefe52c14b18639450e301bee977471 Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Fri, 11 Apr 2025 22:52:55 +0000 Subject: [PATCH 46/52] got rid of autotuner and get stable runs right from the first iteration Signed-off-by: maleksan85 --- vllm/attention/ops/prefix_prefill.py | 42 ++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index eb113c98a606..84ef5e0f43b5 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,18 +16,23 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -@triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ - "num_unroll_cache": 4, \ - "num_unroll_request": 1 } | \ - ({"kpack": 2, "waves_per_eu": 2} \ - if current_platform.is_rocm() else {}), \ - num_warps=4, \ - num_stages=1) - ], - key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] -) + +# Autotune is great idea but not yet performant nor stable. With Autotune, first +# run is slower than others. Probably because graph capture doesn't run kernel, +# probably by some other reason. So want to leave "template" of what could be +# autotuned in this kernel to get even better perf. 
+# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ +# "num_unroll_cache": 4, \ +# "num_unroll_request": 1 } | \ +# ({"kpack": 2, "waves_per_eu": 2} \ +# if current_platform.is_rocm() else {}), \ +# num_warps=4, \ +# num_stages=1) +# ], +# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +# ) @triton.jit def _fwd_kernel(Q, K, @@ -836,6 +841,10 @@ def context_attention_fwd(q, ) return + extra_kargs = {} + if current_platform.is_rocm(): + extra_kargs = {"kpack": 2, "waves_per_eu": 2} + max_seq_len = 0 if max_seq_len is None else max_seq_len grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) @@ -884,6 +893,11 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, - MAX_Q_LEN=triton.next_power_of_2(max_input_len), - MAX_CTX_LEN=triton.next_power_of_2(max_seq_len)) + BLOCK_M=128, + BLOCK_N=64, + num_unroll_cache=4, + num_unroll_request=1, + num_warps=4, + num_stages=1, + **extra_kargs) return From 6d5b3f2e7ee1ab5b31ee63b800f392862cac8648 Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Sat, 12 Apr 2025 15:59:46 +0000 Subject: [PATCH 47/52] restoring paged attn as there is no autotuning anymore and that will no be error during start Signed-off-by: maleksan85 --- vllm/attention/ops/paged_attn.py | 44 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index dc91d4762b5b..827c3041a682 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -8,6 +8,9 @@ from vllm import _custom_ops as ops from vllm.triton_utils import HAS_TRITON +if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd + # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
_PARTITION_SIZE = 512 @@ -207,28 +210,25 @@ def forward_prefix( ) -> torch.Tensor: output = torch.empty_like(query) max_seq_len = None - - if HAS_TRITON: - from vllm.attention.ops.prefix_prefill import context_attention_fwd - context_attention_fwd( - query, - key, - value, - output, - kv_cache_dtype, - key_cache, - value_cache, - block_tables, - # query_start_loc is (batch_size + 1,) - query_start_loc, - seq_lens_tensor, - max_seq_len, - max_query_len, - k_scale, - v_scale, - alibi_slopes, - sliding_window, - ) + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc, + seq_lens_tensor, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) return output @staticmethod From 7140d1aa79d7b9a8c0dc8eb01e5cbe89c07ec017 Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Sun, 13 Apr 2025 03:05:39 +0000 Subject: [PATCH 48/52] poking test rerun as one failed and seems not because of this change Signed-off-by: maleksan85 --- vllm/attention/ops/prefix_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 84ef5e0f43b5..5fbd24256399 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -841,11 +841,11 @@ def context_attention_fwd(q, ) return + max_seq_len = 0 if max_seq_len is None else max_seq_len extra_kargs = {} if current_platform.is_rocm(): extra_kargs = {"kpack": 2, "waves_per_eu": 2} - max_seq_len = 0 if max_seq_len is None else max_seq_len grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) _fwd_kernel[grid]( From ba078b6034ff7a911aa0a6de9a80ea2be5699495 Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Mon, 14 Apr 2025 22:43:21 +0000 Subject: [PATCH 49/52] comment correction Signed-off-by: maleksan85 --- vllm/attention/ops/prefix_prefill.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 5fbd24256399..a8c8d8409620 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -17,10 +17,9 @@ IS_TURING = current_platform.get_device_capability() == (7, 5) -# Autotune is great idea but not yet performant nor stable. With Autotune, first -# run is slower than others. Probably because graph capture doesn't run kernel, -# probably by some other reason. So want to leave "template" of what could be -# autotuned in this kernel to get even better perf. +# Here's an example autotuner config for this kernel. This config does provide +# a performance improvement, but dramatically increases first call latency in +# triton 3.2. Because of this tradeoff, it's currently commented out. 
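+# The equivalent launch parameters are instead passed explicitly at the
+# _fwd_kernel call site below (BLOCK_M=128, BLOCK_N=64, num_unroll_cache=4,
+# num_unroll_request=1, num_warps=4, num_stages=1, plus the kpack and
+# waves_per_eu extra kwargs on ROCm).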
# @triton.autotune( # configs=[ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ From 617ef08f816e28cb60db97f92210b1fd967cefad Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Tue, 15 Apr 2025 16:04:44 +0000 Subject: [PATCH 50/52] dot operation in triton doesn't support k to be 8 so increasing block size to most commonly used Signed-off-by: maleksan85 --- tests/core/block/e2e/test_correctness.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index e9b537ed5150..9e8e315d87b1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -195,15 +195,15 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 2, "max_num_seqs": 2, }, { - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 3, "max_num_seqs": 2, }, { - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 256, "max_num_seqs": 10, }]) From 771ad9ed9a6f5eb607a8edd3812e832a111a3c0e Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Tue, 15 Apr 2025 17:41:58 +0000 Subject: [PATCH 51/52] to kick CIs again Async Engine, Inputs, Utils, Worker Test seems flaky Signed-off-by: maleksan85 --- vllm/attention/ops/prefix_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index a8c8d8409620..d2629988ca2f 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -840,11 +840,11 @@ def context_attention_fwd(q, ) return - max_seq_len = 0 if max_seq_len is None else max_seq_len extra_kargs = {} if current_platform.is_rocm(): extra_kargs = {"kpack": 2, "waves_per_eu": 2} + max_seq_len = 0 if max_seq_len is None else max_seq_len grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) _fwd_kernel[grid]( From b6bf36505a38c0e34d79d7b2561c5b521aca907a Mon Sep 17 00:00:00 2001 From: maleksan85 Date: Tue, 15 Apr 2025 18:37:04 +0000 Subject: [PATCH 52/52] to kick CIs again Signed-off-by: maleksan85 --- vllm/attention/ops/prefix_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index d2629988ca2f..a8c8d8409620 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -840,11 +840,11 @@ def context_attention_fwd(q, ) return + max_seq_len = 0 if max_seq_len is None else max_seq_len extra_kargs = {} if current_platform.is_rocm(): extra_kargs = {"kpack": 2, "waves_per_eu": 2} - max_seq_len = 0 if max_seq_len is None else max_seq_len grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) _fwd_kernel[grid](