diff --git a/Makefile b/Makefile
index 41446b65a..24db77e4b 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@ MAX_JOBS := 64
 
 SHELL := /bin/bash
 
-.PHONY: all build clean format dev rocm rocm-upstream pyupdate nightly bm-rocm
+.PHONY: all build clean format dev rocm rocm-upstream pyupdate nightly bm-rocm spelling
 
 all: build
 
@@ -65,3 +65,7 @@ else
 format:
 	python -m black --check --verbose scripts ibm-triton-lib third_party
 endif
+
+spelling:
+	codespell ./ibm-triton-lib ./triton-dejavu ./scripts
+
diff --git a/ibm-triton-lib/ibm_triton_lib/backend/__init__.py b/ibm-triton-lib/ibm_triton_lib/backend/__init__.py
index 12880f27e..cb36f5f4c 100644
--- a/ibm-triton-lib/ibm_triton_lib/backend/__init__.py
+++ b/ibm-triton-lib/ibm_triton_lib/backend/__init__.py
@@ -19,11 +19,4 @@
 
 def register():
     """Register the triton attention platform."""
-
-    VLLM_USE_V1 = int(os.environ.get("VLLM_USE_V1", "0"))
-
-    # backend only works with v0 currently
-    if VLLM_USE_V1:
-        return None
-    else:
-        return "ibm_triton_lib.backend.platform.TritonPlatform"
+    return "ibm_triton_lib.backend.platform.TritonPlatform"
diff --git a/ibm-triton-lib/ibm_triton_lib/backend/platform.py b/ibm-triton-lib/ibm_triton_lib/backend/platform.py
index b50c6190d..6f807aa37 100644
--- a/ibm-triton-lib/ibm_triton_lib/backend/platform.py
+++ b/ibm-triton-lib/ibm_triton_lib/backend/platform.py
@@ -61,4 +61,6 @@ def get_attn_backend_cls(
         use_v1,
         use_mla,
     ) -> str:
+        if not envs.VLLM_USE_V1:
+            raise RuntimeError("vllm-triton-backend plugin only supports vLLM V1")
         return "ibm_triton_lib.backend.triton_attn.TritonAttentionBackend"
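Note on the inverted `register()` contract above: the platform path is now returned unconditionally, and the V0 case fails fast inside `get_attn_backend_cls` instead. This relies on vLLM's entry-point-based plugin discovery. A minimal sketch of how such a plugin is typically wired up — the `vllm.platform_plugins` group follows vLLM's documented platform-plugin mechanism, while the packaging values are illustrative and not taken from this PR:

    # setup.py -- illustrative sketch only, not part of this patch
    from setuptools import find_packages, setup

    setup(
        name="ibm-triton-lib",
        packages=find_packages(),
        entry_points={
            # vLLM calls each function in this group at startup; returning a
            # dotted platform path (as register() now does unconditionally)
            # activates the plugin, returning None deactivates it.
            "vllm.platform_plugins": [
                "ibm_triton_lib = ibm_triton_lib.backend:register",
            ],
        },
    )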
diff --git a/ibm-triton-lib/ibm_triton_lib/backend/triton_attn.py b/ibm-triton-lib/ibm_triton_lib/backend/triton_attn.py
index 5ac66b411..63fd74838 100644
--- a/ibm-triton-lib/ibm_triton_lib/backend/triton_attn.py
+++ b/ibm-triton-lib/ibm_triton_lib/backend/triton_attn.py
@@ -26,503 +26,244 @@
 """
 
+import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 import torch
-import os
-import vllm.envs as envs
-
-# TODO: currently needed for advance_step and reshape_and_cache
 from vllm import _custom_ops as ops
-from vllm.triton_utils import HAS_TRITON
-
-# TODO: better idea?
-assert HAS_TRITON
-
+from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
     AttentionMetadata,
     AttentionType,
-    AttentionLayer,
 )
-from vllm.attention.backends.utils import CommonAttentionState, CommonMetadataBuilder
+from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
+from vllm.attention.ops.paged_attn import PagedAttention
+from ibm_triton_lib.kernels import unified_attention
 from vllm.logger import init_logger
-from vllm.config import VllmConfig
 from vllm.platforms import current_platform
-
-if TYPE_CHECKING:
-    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
-
-logger = init_logger(__name__)
-
-
-# TODO: use triton reshape and cache
-# from vllm.attention.ops.reshape_and_cache import (
-#     reshape_and_cache as triton_reshape_and_cache,
-# )
-from ibm_triton_lib.kernels import (
-    paged_attention_2d,
-    paged_attention_3d,
-    prefill_flash_attention,
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder,
+    CommonAttentionMetadata,
+    make_local_attention_virtual_batches,
 )
-from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
 
+if TYPE_CHECKING:
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
-class TritonAttentionBackend(AttentionBackend):
-
-    @staticmethod
-    def get_name() -> str:
-        return "zrl-triton-attn"
-
-    @staticmethod
-    def get_impl_cls() -> Type["TritonAttentionImpl"]:
-        return TritonAttentionImpl
-
-    @staticmethod
-    def get_metadata_cls() -> Type["AttentionMetadata"]:
-        return TritonAttentionMetadata
-
-    @staticmethod
-    def get_builder_cls() -> Type["TritonAttentionMetadataBuilder"]:
-        return TritonAttentionMetadataBuilder
-
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
-    @staticmethod
-    def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,
-        head_size: int,
-    ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size * num_kv_heads * head_size)
+logger = logging.getLogger(__name__)
 
 
 @dataclass
-class TritonAttentionMetadata(AttentionMetadata):
-    """Metadata for FlashAttentionBackend.
-
-    NOTE: Any python object stored here is not updated when it is
-    cuda-graph replayed. If you have values that need to be changed
-    dynamically, it should be stored in tensor. The tensor has to be
-    updated from `CUDAGraphRunner.forward` API.
-    """
-
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
-    seq_lens_tensor: Optional[torch.Tensor]
-    # Maximum sequence length among prefill batch. 0 if there are decoding
-    # requests only.
-    max_prefill_seq_len: int
-    # Maximum sequence length among decode batch. 0 if there are prefill
-    # requests only.
-    max_decode_seq_len: int
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
-    # in the kv cache. Each block can contain up to block_size tokens.
-    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
-    # captured.
-    block_tables: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
-
+class TritonAttentionMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
     # |- tokenA -|......................|-- newTokens ---|
     # |---------- context_len ----------|
-    # |-------------------- seq_len ----------------------|
+    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
 
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    query_start_loc: Optional[torch.Tensor] = None
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor] = None
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor] = None
-
-    # Max number of query tokens among request in the batch.
-    max_decode_query_len: Optional[int] = None
-
-    _cached_prefill_metadata: Optional["TritonAttentionMetadata"] = None
-    _cached_decode_metadata: Optional["TritonAttentionMetadata"] = None
-
-    # Begin encoder attn & enc/dec cross-attn fields...
-
-    # Encoder sequence lengths representation
-    encoder_seq_lens: Optional[List[int]] = None
-    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
-
-    # Maximum sequence length among encoder sequences
-    max_encoder_seq_len: Optional[int] = None
-
-    # Number of tokens input to encoder
-    num_encoder_tokens: Optional[int] = None
-
-    # Cross-attention memory-mapping data structures: slot mapping
-    # and block tables
-    cross_slot_mapping: Optional[torch.Tensor] = None
-    cross_block_tables: Optional[torch.Tensor] = None
-
-    @property
-    def prefill_metadata(self) -> Optional["TritonAttentionMetadata"]:
-        if self.num_prefills == 0:
-            return None
-
-        if self._cached_prefill_metadata is not None:
-            return self._cached_prefill_metadata
-
-        assert self.seq_lens is not None
-        assert self.seq_lens_tensor is not None
-        assert self.block_tables is not None
-
-        self._cached_prefill_metadata = TritonAttentionMetadata(
-            num_prefills=self.num_prefills,
-            num_prefill_tokens=self.num_prefill_tokens,
-            num_decode_tokens=0,
-            slot_mapping=self.slot_mapping[: self.num_prefill_tokens],
-            multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
-            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
-            seq_lens=self.seq_lens[: self.num_prefills],
-            seq_lens_tensor=self.seq_lens_tensor[: self.num_prefills],
-            max_query_len=self.max_query_len,
-            max_prefill_seq_len=self.max_prefill_seq_len,
-            max_decode_seq_len=0,
-            query_start_loc=(
-                None
-                if self.query_start_loc is None
-                else self.query_start_loc[: self.num_prefills + 1]
-            ),
-            seq_start_loc=(
-                None
-                if self.seq_start_loc is None
-                else self.seq_start_loc[: self.num_prefills + 1]
-            ),
-            context_lens_tensor=(
-                None
-                if self.context_lens_tensor is None
-                else self.context_lens_tensor[: self.num_prefills]
-            ),
-            block_tables=self.block_tables[: self.num_prefills],
-            use_cuda_graph=False,
-            # Begin encoder & cross attn fields below...
-            encoder_seq_lens=self.encoder_seq_lens,
-            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
-            max_encoder_seq_len=self.max_encoder_seq_len,
-            cross_slot_mapping=self.cross_slot_mapping,
-            cross_block_tables=self.cross_block_tables,
-        )
-        return self._cached_prefill_metadata
-
-    @property
-    def decode_metadata(self) -> Optional["TritonAttentionMetadata"]:
-        if self.num_decode_tokens == 0:
-            return None
-
-        if self._cached_decode_metadata is not None:
-            return self._cached_decode_metadata
-        assert self.block_tables is not None
-        assert self.seq_lens_tensor is not None
-
-        self._cached_decode_metadata = TritonAttentionMetadata(
-            num_prefills=0,
-            num_prefill_tokens=0,
-            num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=self.slot_mapping[self.num_prefill_tokens :],
-            multi_modal_placeholder_index_maps=None,
-            enable_kv_scales_calculation=True,
-            seq_lens=None,
-            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills :],
-            max_query_len=None,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=self.block_tables[self.num_prefills :],
-            use_cuda_graph=self.use_cuda_graph,
-            # Begin encoder & cross attn fields below...
-            encoder_seq_lens=self.encoder_seq_lens,
-            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
-            max_encoder_seq_len=self.max_encoder_seq_len,
-            cross_slot_mapping=self.cross_slot_mapping,
-            cross_block_tables=self.cross_block_tables,
-        )
-        # Batch may be composed of prefill|decodes, adjust query start indices
-        # to refer to the start of decodes when the two are split apart.
-        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
-        if self._cached_decode_metadata.query_start_loc is not None:
-            qs = self._cached_decode_metadata.query_start_loc
-            self._cached_decode_metadata.query_start_loc = qs - qs[0]
-        return self._cached_decode_metadata
-
-    def advance_step(
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
+    # For cascade attention.
+    use_cascade: bool
+    common_prefix_len: int
+    cu_prefix_query_lens: Optional[torch.Tensor]
+    prefix_kv_lens: Optional[torch.Tensor]
+    suffix_kv_lens: Optional[torch.Tensor]
+
+    # Optional aot scheduling
+    scheduler_metadata: Optional[torch.Tensor] = None
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+
+    # for local attention
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor
+        local_seqused_k: torch.Tensor
+        local_block_table: torch.Tensor
+        local_max_query_len: int
+        local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True
+
+    def __init__(
         self,
-        model_input: "ModelInputForGPUWithSamplingMetadata",
-        sampled_token_ids: Optional[torch.Tensor],
-        block_size: int,
-        num_seqs: int,
-        num_queries: int,
-        turn_prefills_into_decodes: bool = False,
+        runner: "GPUModelRunner",
+        kv_cache_spec: AttentionSpec,
+        block_table: BlockTable,
     ):
-        """
-        Update metadata in-place to advance one decode step.
-        """
-
-        assert not turn_prefills_into_decodes, (
-            "Multi Step Chunked prefill is not supported with triton_attn yet."
-            "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
-            "specific parameter."
+        self.runner = runner
+        self.block_size = kv_cache_spec.block_size
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table
+
+    def build_for_cudagraph_capture(
+        self, common_attn_metadata: CommonAttentionMetadata
+    ) -> TritonAttentionMetadata:
+        attn_metadata = self.build(0, common_attn_metadata)
+        # When doing full graph capture, setting seq_lens to
+        # max_model_len will cause graph capture to be extremely
+        # slow, so here we set it to 1.
+        attn_metadata.seq_lens.fill_(1)
+        return attn_metadata
+
+    def build(
+        self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
+    ) -> TritonAttentionMetadata:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True
         )
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+
+        # for local attention
+        local_attn_metadata = None
+        if self.runner.attention_chunk_size is not None:
+            (
+                seqlens_q_local_np,
+                virt_q_cu_seqlens_np,
+                virt_k_seqlens_np,
+                virt_block_table_tensor,
+            ) = make_local_attention_virtual_batches(
+                self.runner.attention_chunk_size,
+                self.runner.query_start_loc_np[: num_reqs + 1],
+                self.runner.seq_lens_np[:num_reqs],
+                block_table_tensor,
+                self.block_size,
+            )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True
+            )
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True
+            )
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+
+            local_attn_metadata = TritonAttentionMetadata.LocalAttentionMetadata(
+                local_query_start_loc=local_query_start_loc,
+                local_seqused_k=local_seqused_k,
+                local_block_table=virt_block_table_tensor,
+                local_max_query_len=local_max_query_len,
+                local_max_seq_len=local_max_seq_len,
+                local_scheduler_metadata=None,
+            )
 
-        # When using cudagraph, the num_seqs is padded to the next captured
-        # batch sized, but num_queries tracks the actual number of requests in
-        # the batch. For --enforce-eager mode, num_seqs == num_queries
-        if num_seqs != num_queries:
-            assert num_seqs > num_queries
-            assert self.use_cuda_graph
-
-        assert self.num_prefills == 0
-        assert self.num_prefill_tokens == 0
-        assert self.num_decode_tokens == num_seqs
-        assert self.slot_mapping.shape == (num_seqs,)
-
-        assert self.seq_lens is not None
-        assert len(self.seq_lens) == num_seqs
-        assert self.seq_lens_tensor is not None
-        assert self.seq_lens_tensor.shape == (num_seqs,)
-        assert self.max_query_len == 1
-        assert self.max_prefill_seq_len == 0
-        assert self.max_decode_seq_len == max(self.seq_lens)
-
-        assert self.query_start_loc is not None
-        assert self.query_start_loc.shape == (num_queries + 1,)
-        assert self.seq_start_loc is not None
-        assert self.seq_start_loc.shape == (num_seqs + 1,)
-
-        assert self.context_lens_tensor is not None
-        assert self.context_lens_tensor.shape == (num_queries,)
-
-        assert self.block_tables is not None
-        assert self.block_tables.shape[0] == num_seqs
-
-        # Update query lengths. Note that we update only queries and not seqs,
-        # since tensors may be padded due to captured cuda graph batch size
-        for i in range(num_queries):
-            self.seq_lens[i] += 1
-        self.max_decode_seq_len = max(self.seq_lens)
-
-        # TODO: add triton implementation
-        ops.advance_step_flashattn(
-            num_seqs=num_seqs,
-            num_queries=num_queries,
-            block_size=block_size,
-            input_tokens=model_input.input_tokens,
-            sampled_token_ids=sampled_token_ids,
-            input_positions=model_input.input_positions,
-            seq_lens=self.seq_lens_tensor,
-            slot_mapping=self.slot_mapping,
-            block_tables=self.block_tables,
+        use_cascade = common_prefix_len > 0
+
+        if use_cascade:
+            cu_prefix_query_lens = torch.tensor(
+                [0, num_actual_tokens], dtype=torch.int32, device=self.runner.device
+            )
+            prefix_kv_lens = torch.tensor(
+                [common_prefix_len], dtype=torch.int32, device=self.runner.device
+            )
+            suffix_kv_lens = self.runner.seq_lens_np[:num_reqs] - common_prefix_len
+            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.runner.device)
+        else:
+            cu_prefix_query_lens = None
+            prefix_kv_lens = None
+            suffix_kv_lens = None
+            prefix_scheduler_metadata = None
+
+        attn_metadata = TritonAttentionMetadata(
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=max_query_len,
+            query_start_loc=query_start_loc,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table=block_table_tensor,
+            slot_mapping=slot_mapping,
+            use_cascade=use_cascade,
+            common_prefix_len=common_prefix_len,
+            cu_prefix_query_lens=cu_prefix_query_lens,
+            prefix_kv_lens=prefix_kv_lens,
+            suffix_kv_lens=suffix_kv_lens,
+            local_attn_metadata=local_attn_metadata,
+            prefix_scheduler_metadata=prefix_scheduler_metadata,
         )
+        return attn_metadata
 
+    def can_run_in_cudagraph(
+        self, common_attn_metadata: CommonAttentionMetadata
+    ) -> bool:
+        # Full CUDA Graph always supported
+        return True
 
-class TritonAttentionMetadataBuilder(CommonMetadataBuilder[TritonAttentionMetadata]):
-
-    _metadata_cls = TritonAttentionMetadata
-
-
-def _make_alibi_bias(
-    alibi_slopes: torch.Tensor,
-    dtype: torch.dtype,
-    seq_lens: Optional[List[int]],
-    make_attn_mask: bool = True,
-) -> List[torch.Tensor]:
-    attn_biases = []
-    if seq_lens:
-        for seq_len in seq_lens:
-            bias = torch.arange(seq_len, dtype=dtype)
-            # NOTE(zhuohan): HF uses
-            #     `bias = bias[None, :].repeat(seq_len, 1)`
-            # here. We find that both biases give the same results, but
-            # the bias below more accurately follows the original ALiBi
-            # paper.
-            bias = bias[None, :] - bias[:, None]
-
-            num_heads = alibi_slopes.shape[0]
-            bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device)
-            bias.mul_(alibi_slopes[:, None, None])
-            if make_attn_mask:
-                inf_mask = (
-                    torch.empty((1, seq_len, seq_len), dtype=bias.dtype)
-                    .fill_(-torch.inf)
-                    .triu_(diagonal=1)
-                    .to(alibi_slopes.device)
-                )
-                attn_biases.append((bias + inf_mask).to(dtype))
-            else:
-                attn_biases.append(bias.to(dtype))
-
-    return attn_biases
-
-
-def _get_seq_len_block_table_args(
-    attn_metadata: TritonAttentionMetadata,
-    attn_type: str,
-) -> tuple:
-    """
-    The particular choice of sequence-length
-    attributes which should be extracted from attn_metadata is dependent
-    on the type of attention operation.
-
-    Decoder attn -> select entirely decoder self-attention-related fields
-    Encoder/decoder cross-attn -> select encoder sequence lengths
-    Encoder attn -> select encoder sequence lengths fields
-
-    Arguments:
-
-    * attn_metadata: Attention metadata structure associated with attention op
-    * attn_type: encoder attention, decoder self-attention,
-                 encoder/decoder cross-attention
-
-    Returns:
-
-    * Appropriate sequence-lengths tensors for query and key
-    * Appropriate max sequence-length scalar
-    """
-
-    partial_prefix_sum = 0
-    if attn_type == AttentionType.ENCODER:
-        assert attn_metadata.encoder_seq_lens is not None
-        assert attn_metadata.encoder_seq_lens_tensor is not None
-        query_seq_start_loc = torch.tensor(
-            [0]
-            + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
-            device=attn_metadata.encoder_seq_lens_tensor.device,
-            dtype=attn_metadata.encoder_seq_lens_tensor.dtype,
-        )
-        causal_mask = False
-
-        # No block tables associated with encoder attention
-        return (
-            query_seq_start_loc,
-            attn_metadata.max_encoder_seq_len,
-            query_seq_start_loc,
-            attn_metadata.max_encoder_seq_len,
-            attn_metadata.encoder_seq_lens,
-            causal_mask,
-        )
-    elif attn_type == AttentionType.DECODER:
-        # Decoder self-attention
-        # Choose max_seq_len based on whether we are in prompt_run
-        assert attn_metadata.seq_lens is not None
-        assert attn_metadata.seq_lens_tensor is not None
-        query_seq_start_loc = torch.tensor(
-            [0]
-            + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
-            device=attn_metadata.seq_lens_tensor.device,
-            dtype=attn_metadata.seq_lens_tensor.dtype,
-        )
-        max_seq_len = attn_metadata.max_prefill_seq_len
-        causal_mask = True
-
-        return (
-            query_seq_start_loc,
-            max_seq_len,
-            query_seq_start_loc,
-            max_seq_len,
-            attn_metadata.seq_lens,
-            causal_mask,
-        )
-    elif attn_type == AttentionType.ENCODER_DECODER:
-        assert attn_metadata.seq_lens is not None
-        assert attn_metadata.encoder_seq_lens_tensor is not None
-        query_start_loc = torch.tensor(
-            [0]
-            + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
-            device=attn_metadata.encoder_seq_lens_tensor.device,
-            dtype=attn_metadata.encoder_seq_lens_tensor.dtype,
-        )
-        partial_prefix_sum = 0
-        assert attn_metadata.encoder_seq_lens is not None
-        assert attn_metadata.seq_lens_tensor is not None
-        key_seq_start_loc = torch.tensor(
-            [0]
-            + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
-            device=attn_metadata.seq_lens_tensor.device,
-            dtype=attn_metadata.seq_lens_tensor.dtype,
-        )
-        causal_mask = False
-
-        # Enc/dec cross-attention KVs match encoder sequence length;
-        # cross-attention utilizes special "cross" block tables
-        return (
-            query_start_loc,
-            attn_metadata.max_prefill_seq_len,
-            key_seq_start_loc,
-            attn_metadata.max_encoder_seq_len,
-            attn_metadata.seq_lens,
-            causal_mask,
-        )
-    else:
-        raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
+class TritonAttentionBackend(AttentionBackend):
 
-class TritonAttentionImpl(AttentionImpl):
-    """
-    If the input tensors contain prompt tokens, the layout is as follows:
-    |<--------------- num_prompt_tokens -------------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+    accept_output_buffer: bool = True
 
-    Otherwise, the layout is as follows:
-    |<------------------ num_generation_tokens (M) ----------------->|
-    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+    @staticmethod
+    def get_supported_head_sizes() -> list[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
 
-    Generation tokens can contain padding when cuda-graph is used.
-    Currently, prompt tokens don't contain any padding.
+    @staticmethod
+    def get_name() -> str:
+        return "TRITON_ATTN_VLLM_V1"
 
-    The prompts might have different lengths, while the generation tokens
-    always have length 1.
+    @staticmethod
+    def get_impl_cls() -> type["TritonAttentionImpl"]:
+        return TritonAttentionImpl
 
-    If chunked prefill is enabled, prefill tokens and decode tokens can be
-    batched together in a flattened 1D query.
+    @staticmethod
+    def get_metadata_cls() -> type["AttentionMetadata"]:
+        return TritonAttentionMetadata
 
-    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
-    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> tuple[int, ...]:
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
-    Currently, cuda graph is disabled for chunked prefill, meaning there's no
-    padding between prefill and decode tokens.
-    """
+    @staticmethod
+    def use_cascade_attention(*args, **kwargs) -> bool:
+        return False
+
+    @staticmethod
+    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
+        return TritonAttentionMetadataBuilder
+
+
+class TritonAttentionImpl(AttentionImpl):
 
     def __init__(
         self,
@@ -530,22 +271,17 @@ def __init__(
         head_size: int,
         scale: float,
         num_kv_heads: int,
-        alibi_slopes: Optional[List[float]],
+        alibi_slopes: Optional[list[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        blocksparse_params: Optional[Dict[str, Any]] = None,
+        blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
-        attn_type: str = AttentionType.DECODER,
+        attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
-            raise ValueError("TritonAttention does not support blocksparse attention.")
-
-        if logits_soft_cap is None:
-            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
-            self.logits_soft_cap = 0.0
-        else:
-            self.logits_soft_cap = logits_soft_cap
-        self.attn_type = attn_type
+            raise ValueError("TritonAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -553,466 +289,200 @@ def __init__(
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        self.sliding_window = (
-            (sliding_window, sliding_window) if sliding_window is not None else (-1, -1)
-        )
+        if sliding_window is None:
+            self.sliding_window = (-1, -1)
+        else:
+            self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+
+        self.use_irope = use_irope
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        supported_head_sizes = self.get_supported_head_sizes()
-        if head_size not in supported_head_sizes:
+        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {supported_head_sizes}."
+                f"Head size {head_size} is not supported by TritonAttention. "
+                f"Supported head sizes are: {support_head_sizes}."
             )
 
-        # self.attn_func = triton_wrapper_forward
-        logger.debug("Using Triton Prefill attention")
-        # TODO
-        if self.sliding_window != (-1, -1):
-            logger.warning(
-                "Triton FA does not currently support " "sliding window attention. "
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError(
+                "Encoder self-attention and "
+                "encoder/decoder cross-attention "
+                "are not implemented for "
+                "TritonAttentionImpl"
             )
-        logger.debug("Using Triton Paged attention")
-
-    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
-        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
-        tokens, n_kv_heads, head_dim = x.shape
-        return (
-            x[:, :, None, :]
-            .expand(tokens, n_kv_heads, n_rep, head_dim)
-            .reshape(tokens, n_kv_heads * n_rep, head_dim)
-        )
-
-    @staticmethod
-    def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 120, 128, 192, 256]
-
-    def split_kv_cache(
-        self,
-        kv_cache: torch.Tensor,
-        num_kv_heads: int,
-        head_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        num_blocks = kv_cache.shape[1]
-
-        key_cache = kv_cache[0]
-        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size, -1, 1)
-        # TODO: maybe enable cache alignment as in cuda?
-        # x = 16 // kv_cache.element_size()
-        # key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
-        value_cache = kv_cache[1]
-        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
-        return key_cache, value_cache
-
-    def write_to_paged_cache(
-        self,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        slot_mapping: torch.Tensor,
-        kv_cache_dtype: str,
-        k_scale: float,
-        v_scale: float,
-    ) -> None:
-        # TODO: integrate triton reshape and cache
-        # triton_reshape_and_cache(
-        #     key, value, key_cache, value_cache, slot_mapping.flatten()
-        # )
-        ops.reshape_and_cache(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            slot_mapping.flatten(),
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
-        )
-
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        # used by cache_engine
-        # TODO: add triton implementation and make non-static method
-        src_key_cache = src_kv_cache[0]
-        dst_key_cache = dst_kv_cache[0]
-        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-
-        src_value_cache = src_kv_cache[1]
-        dst_value_cache = dst_kv_cache[1]
-        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        # used by cache_engine
-        # TODO: add triton implementation and make non-static method
-        key_caches = [kv_cache[0] for kv_cache in kv_caches]
-        value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        self.fp8_dtype = current_platform.fp8_dtype()
+        self.force_prefill_decode_attn = False
+        # TODO: logger.info in a plugin is suppressed?
+        logger.warning_once("Using vllm-triton-backend attention PLUGIN V1.")
 
     def forward(
         self,
-        layer: AttentionLayer,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: TritonAttentionMetadata,
+        attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention and PagedAttention.
-
-        For decoder-only models: query, key and value must be non-None.
-
-        For encoder/decoder models:
-        * TritonAttentionImpl.forward() may be invoked for both self- and
-          cross-attention layers.
-        * For self-attention: query, key and value must be non-None.
-        * For cross-attention:
-            * Query must be non-None
-            * During prefill, key and value must be non-None; key and value
-              get cached for use during decode.
-            * During decode, key and value may be None, since:
-              (1) key and value tensors were cached during prefill, and
-              (2) cross-attention key and value tensors do not grow during
-                  decode
-
-        A note on how the attn_type (attention type enum) argument impacts
-        attention forward() behavior:
-
-            * DECODER: normal decoder-only behavior;
-                use decoder self-attention block table
-            * ENCODER: no KV caching; pass encoder sequence
-                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
-                max_encoder_seq_len) to kernel, in lieu of decoder
-                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
-            * ENCODER_DECODER: cross-attention behavior;
-                use cross-attention block table for caching KVs derived
-                from encoder hidden states; since KV sequence lengths
-                will match encoder sequence lengths, pass encoder sequence
-                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
-                max_encoder_seq_len)
+        """Forward pass with FlashAttention.
 
         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
-                NOTE: kv_cache will be an empty tensor with shape [0]
-                for profiling run.
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
-            attn_type: Select attention type, between encoder attention,
-                       decoder self-attention, or encoder/decoder cross-
-                       attention. Defaults to decoder self-attention,
-                       which is the vLLM default generally
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
 
-        query = query.view(-1, self.num_heads, self.head_size)
-        if key is not None:
-            assert value is not None
-            key = key.view(-1, self.num_kv_heads, self.head_size)
-            value = value.view(-1, self.num_kv_heads, self.head_size)
-        else:
-            assert value is None
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for TritonAttentionImpl"
+            )
 
-        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
-            key_cache, value_cache = self.split_kv_cache(
+        if attn_metadata is None:
+            # Profiling run.
+            return output
+
+        assert attn_metadata.use_cascade is False
+
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
+
+        use_prefill_decode_attn = self.force_prefill_decode_attn
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        if use_prefill_decode_attn:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size
             )
+        else:
+            key_cache, value_cache = kv_cache.unbind(0)
 
-            if key is not None and value is not None:
-                # Reshape the input keys and values and store them in the
-                # cache. If kv_cache is not provided, the new key and value
-                # tensors are not cached. This happens during the initial
-                # memory profiling run.
-                self.write_to_paged_cache(
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if use_prefill_decode_attn:
+                PagedAttention.write_to_paged_cache(
                     key,
                     value,
                     key_cache,
                     value_cache,
-                    (
-                        attn_metadata.slot_mapping
-                        if self.attn_type != AttentionType.ENCODER_DECODER
-                        else attn_metadata.cross_slot_mapping
-                    ),
+                    attn_metadata.slot_mapping,
                     self.kv_cache_dtype,
                     layer._k_scale,
                     layer._v_scale,
                 )
-
-        if self.attn_type != AttentionType.ENCODER:
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-        else:
-            assert attn_metadata.num_encoder_tokens is not None
-            num_prefill_tokens = attn_metadata.num_encoder_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-
-        if (
-            key is not None
-            and value is not None
-            and self.attn_type != AttentionType.ENCODER_DECODER
-        ):
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            # normal attention and DECODER
-            if self.attn_type == AttentionType.DECODER and (
-                kv_cache.numel() == 0
-                or prefill_meta.block_tables is None
-                or prefill_meta.block_tables.numel() == 0
-            ):
-                (
-                    query_seq_start_loc,
-                    query_max_seq_len,
-                    key_seq_start_loc,
-                    key_max_seq_len,
-                    seq_lens,
-                    causal_mask,
-                ) = (
-                    prefill_meta.seq_start_loc,
-                    prefill_meta.max_prefill_seq_len,
-                    prefill_meta.seq_start_loc,
-                    prefill_meta.max_prefill_seq_len,
-                    attn_metadata.seq_lens,
-                    True,
-                )
-            # prefix-enabled attention and ENCODER/ENCODER_DECODER
-            else:
-                (
-                    query_seq_start_loc,
-                    query_max_seq_len,
-                    key_seq_start_loc,
-                    key_max_seq_len,
-                    seq_lens,
-                    causal_mask,
-                ) = _get_seq_len_block_table_args(prefill_meta, self.attn_type)
-            # Prompt run.
-            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
-                # triton attention
-                # When block_tables are not filled, it means q and k are the
-                # prompt, and they have the same length.
-                attn_masks = None
-                if self.alibi_slopes is not None:
-                    # FIXME
-                    attn_masks = _make_alibi_bias(
-                        self.alibi_slopes,
-                        query.dtype,
-                        attn_metadata.seq_lens,
-                        make_attn_mask=False,
-                    )  # type: ignore
-
-                out = prefill_flash_attention(
-                    query,
-                    key,
-                    value,
-                    sm_scale=self.scale,
-                    causal=True,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                )
-
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
             else:
-                # prefix-enabled attention
-                output[:num_prefill_tokens] = self.forward_prefix(
-                    query,
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
                     key,
                     value,
-                    self.kv_cache_dtype,
                     key_cache,
                     value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.query_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                    self.sliding_window[0],
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
                     layer._k_scale,
                     layer._v_scale,
                 )
 
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            decode_output = self.forward_decode(
-                decode_query,
-                key_cache,
-                value_cache,
-                (
-                    decode_meta.block_tables
-                    if self.attn_type != AttentionType.ENCODER_DECODER
-                    else decode_meta.cross_block_tables
-                ),
-                (
-                    decode_meta.seq_lens_tensor
-                    if self.attn_type != AttentionType.ENCODER_DECODER
-                    else decode_meta.encoder_seq_lens_tensor
-                ),
-                (
-                    decode_meta.max_decode_seq_len
-                    if self.attn_type != AttentionType.ENCODER_DECODER
-                    else decode_meta.max_encoder_seq_len
-                ),
-                self.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                layer._k_scale,
-                layer._v_scale,
-            )
-            # print(decode_output)
-            output[num_prefill_tokens:] = decode_output
-
-        # Reshape the output tensor.
-        return output.view(-1, self.num_heads * self.head_size)
+        if self.kv_cache_dtype.startswith("fp8"):
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
+            num_tokens, num_heads, head_size = query.shape
+            assert (
+                layer._q_scale == 1.0
+            ), "A non 1.0 q_scale is not currently supported."
+            if not current_platform.is_rocm():
+                # Skip Q quantization on ROCm, since dequantizing back to
+                # f32 in the attention kernel is not supported.
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale,
+                )
+                query = query.reshape((num_tokens, num_heads, head_size))
 
-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        kv_cache_dtype: str,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        query_start_loc: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_query_len: int,
-        alibi_slopes: Optional[torch.Tensor],
-        sliding_window: Optional[int],
-        k_scale: float,
-        v_scale: float,
-    ) -> torch.Tensor:
-        output = torch.empty_like(query)
-        context_attention_fwd(
-            query,
-            key,
-            value,
-            output,
-            kv_cache_dtype,
-            key_cache,
-            value_cache,
-            block_tables,
-            # query_start_loc is (batch_size + 1,)
-            query_start_loc[:-1],
-            seq_lens_tensor,
-            context_lens,
-            max_query_len,
-            k_scale,
-            v_scale,
-            alibi_slopes,
-            sliding_window,
+        use_local_attn = (
+            self.use_irope and attn_metadata.local_attn_metadata is not None
         )
-        return output
 
-    def forward_decode(
-        self,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        seq_lens: torch.Tensor,
-        max_seq_len: int,
-        kv_cache_dtype: str,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        k_scale: float,
-        v_scale: float,
-        tp_rank: int = 0,
-        blocksparse_local_blocks: int = 0,
-        blocksparse_vert_stride: int = 0,
-        blocksparse_block_size: int = 64,
-        blocksparse_head_sliding_step: int = 0,
-    ) -> torch.Tensor:
-        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
-            # use blocksparse paged attention
-            block_size = value_cache.size(-1)
-            assert (
-                blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
-            ), (
-                f"{blocksparse_block_size=} needs to be a multiple of"
-                f"{block_size=} used in block_tables."
-            )
-
-        output = torch.empty_like(query)
-        block_size = value_cache.shape[3]
-        num_seqs, num_heads, head_size = query.shape
-
-        batch_size = num_seqs
-        num_query_heads = num_heads
-        num_queries_per_kv = num_query_heads // num_kv_heads
-
-        # TODO
-        # NOTE(ngl): Since vLLM uses cuda graph for decode as default,
-        # and the example workloads for the cuda graph capture are all >128,
-        # it will always use 3d, currently.
-        use_3d = max_seq_len > 128
-        # print(f"use 3d triton paged attention: {use_3d}")
-
-        if not use_3d:
-            paged_attention_2d(
-                output,
-                query,
-                key_cache,
-                value_cache,
-                scale,
-                k_scale,
-                v_scale,
-                kv_cache_dtype,
-                block_tables,
-                seq_lens,
-                alibi_slopes,
-                block_size,
-                batch_size,
-                num_query_heads,
-                num_queries_per_kv,
-                head_size,
+        if use_local_attn:
+            assert attn_metadata.local_attn_metadata is not None
+            local_metadata = attn_metadata.local_attn_metadata
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            seqused_k = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+            block_table = local_metadata.local_block_table
+        else:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
+
+        if use_prefill_decode_attn:
+            # Compute attention and update output up to `num_actual_tokens`.
+            chunked_prefill_paged_decode(
+                query=query[:num_actual_tokens],
+                key=key[:num_actual_tokens],
+                value=value[:num_actual_tokens],
+                output=output[:num_actual_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_table=block_table,
+                query_start_loc=cu_seqlens_q,
+                seq_lens=seqused_k,
+                max_seq_len=max_seqlen_k,
+                max_query_len=max_seqlen_q,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
+                alibi_slopes=self.alibi_slopes,
+                sliding_window=self.sliding_window[0],
+                sm_scale=self.scale,
             )
         else:
-            paged_attention_3d(
-                output,
-                query,
-                key_cache,
-                value_cache,
-                scale,
-                k_scale,
-                v_scale,
-                kv_cache_dtype,
-                block_tables,
-                seq_lens,
-                alibi_slopes,
-                block_size,
-                batch_size,
-                num_query_heads,
-                num_queries_per_kv,
-                head_size,
+            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
+            unified_attention(
+                q=query[:num_actual_tokens],
+                k=key_cache,
+                v=value_cache,
+                out=output[:num_actual_tokens],
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                seqused_k=seqused_k,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.sliding_window,
+                block_table=block_table,
+                softcap=self.logits_soft_cap,
+                q_descale=None,  # Not supported
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
             )
 
         return output
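With the V1 layout returned by `get_kv_cache_shape`, splitting the cache no longer needs the old `split_kv_cache` view gymnastics: `unbind(0)` is enough. A minimal sketch of the two levels of indexing — shapes follow the tuple above, the concrete numbers and the `slot` value are made up, and the slot arithmetic assumes vLLM's usual `slot = block_id * block_size + offset` convention:

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 128
    kv_cache = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)

    # unbind(0) yields views, not copies: writes through key_cache land in kv_cache
    key_cache, value_cache = kv_cache.unbind(0)
    assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)

    # a slot index s (from slot_mapping) addresses block s // block_size,
    # offset s % block_size within that block
    slot = 37
    block_id, block_off = divmod(slot, block_size)
    token_keys = key_cache[block_id, block_off]  # (num_kv_heads, head_size)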
diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/utils.py b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/utils.py
index 5b15ff9bf..9be93b3d9 100644
--- a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/utils.py
+++ b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/fused_gqa_paged/utils.py
@@ -20,7 +20,7 @@ def compute_split_l(L, BLOCK_L, P=1, device=None):
         # no need to further split L
         return 1
 
-    # Find mininum num_splits that will result in enough triton programs
+    # Find minimum num_splits that will result in enough triton programs
     # TODO: does num_splits need to be power of 2?
     num_splits = 1
     split_size = L
diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py
index 4bc7c0b97..eb03b53c2 100644
--- a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py
+++ b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py
@@ -72,7 +72,7 @@ def cdiv_fn(x, y):
 
 
 @triton_dejavu.jitcache(
-    # remove cache_lock if dyanmic cache mode should be used
+    # remove cache_lock if dynamic cache mode should be used
     cache_lock=global_cache_lock,
     # list of `tl.constexpr` that should be used as cache index
     check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_flash_attention.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_flash_attention.py
index d8496f22b..41c922d23 100644
--- a/ibm-triton-lib/ibm_triton_lib/kernels/triton_flash_attention.py
+++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_flash_attention.py
@@ -1416,7 +1416,7 @@ def triton_wrapper_forward_prefill(
     # number of compute units available
     NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
 
-    # TODO: test persitent
+    # TODO: test persistent
     if metadata.persistent is not None:
         grid = lambda META: (
             min(
diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py
index e246e4f95..79d77d2ca 100644
--- a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py
+++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py
@@ -11,10 +11,7 @@
 import triton
 import triton.language as tl
 
-from vllm.logger import init_logger
-from vllm.triton_utils.jit_cache import jitcache
-
-logger = init_logger(__name__)
+import triton_dejavu
 
 
 @triton.jit
@@ -53,7 +50,7 @@ def find_seq_idx(
     return left - 1
 
 
-@jitcache(
+@triton_dejavu.jitcache(
     check_keys=[],
     check_specialization=["num_seqs"],
     assume_const=[
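The switch from vLLM's `jitcache` to `triton_dejavu.jitcache` keeps launch-overhead caching inside the plugin's own dependency. For orientation, a rough sketch of how such a decorated kernel is declared, mirroring the argument semantics visible in the two call sites above (`check_keys` index the cache by `tl.constexpr` values, `check_specialization` re-keys on specialization-relevant scalars, `assume_const` promises arguments that never change after the first launch); the kernel itself and all names in it are illustrative, not from this repository:

    import triton
    import triton.language as tl
    import triton_dejavu

    @triton_dejavu.jitcache(
        # cache lookups keyed only by these constexpr values
        check_keys=["CAUSAL"],
        # scalars whose Triton specialization (e.g. divisibility) must match
        check_specialization=["num_tokens"],
        # arguments asserted to keep the same value across calls
        assume_const=["scale"],
    )
    @triton.jit
    def scale_kernel(x_ptr, out_ptr, scale, num_tokens,
                     CAUSAL: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < num_tokens
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * scale, mask=mask)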
diff --git a/scripts/benchmark.py b/scripts/benchmark.py
index 534a97b00..edb6d7dab 100644
--- a/scripts/benchmark.py
+++ b/scripts/benchmark.py
@@ -39,18 +39,17 @@
     ref_prefix_prefill,
     ref_reshape_and_cache_flash,
     ref_reshape_and_cache,
+    ref_paged_attn,
 )
 from torch_utils import get_gpu_label, end2end_bench
 from ibm_triton_lib.utils.triton_utils import get_runtime_label
 from roofline.proton_viewer import parse
 
+# allow the three most used variables to be overwritten separately
 STORE_TEST_RESULT_PATH = os.environ.get("STORE_TEST_RESULT_PATH", None)
 MY_IUT = [
     e for e in os.environ.get("MY_IUT", "").split(",") if len(e) > 0
 ]  # my implementations under test (IUT)
-MY_MAX_VALUES = [
-    e for e in os.environ.get("MY_MAX_VALUES", "").split(",") if len(e) > 0
-]
 MY_METHODS = [e for e in os.environ.get("MY_METHODS", "").split(",") if len(e) > 0]
 
@@ -78,67 +77,51 @@ class BenchmarkMode(Enum):
     TORCH_COMPILE = 3
 
 
+class BatchComposition(Enum):
+    DEC_PRE = 0
+    PRE_DEC = 1
+    ALTERNATING = 2
+
+
+impl_translate = {i.name: i.value for i in Implementation}
+method_translate = {i.name: i.value for i in BenchmarkMode}
+batch_comp_translate = {i.name: i.value for i in BatchComposition}
+
+dtype_translate = {
+    "float16": torch.float16,
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "float8_e4m3fn": torch.float8_e4m3fn,
+    "float8_e5m2": torch.float8_e5m2,
+}
+
 # DTYPES = [torch.half, torch.bfloat16, torch.float]
 DTYPES = [torch.float16]
 SEEDS = [0]
 BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
-# BATCH_SIZES = [128]
-# BATCH_SIZES = [64]
-# BATCH_SIZES = [4]
-# BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-# BATCH_SIZES = [1, 2, 3, 4, 5, 7, 8, 12, 16, 32, 64, 128]
-
 # order: num_query_heads, num_kv_heads
-# NUM_HEADS = [(32, 32), (32, 8)]
-NUM_HEADS = [(32, 8)]
-# NUM_HEADS = [(32, 32)]
-
+NUM_HEADS = [(32, 32), (32, 8)]
 SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096]
-# SEQUENCE_LENGTHS = [8]
-# SEQUENCE_LENGTHS= [128]
-# SEQUENCE_LENGTHS = [16, 17]
-# SEQUENCE_LENGTHS = [2048]
-# SEQUENCE_LENGTHS = [4096]
-# SEQUENCE_LENGTHS = [4321]
-# SEQUENCE_LENGTHS = [16, 128, 512, 1024, 2048, 4096]
-# SEQUENCE_LENGTHS = [24, 128, 512, 1024, 2048, 4096]
-
-# CONTEXT_LENGTHS = [16, 128, 512, 1024, 2048, 4096]
-# QUERY_LENGTHS = [1, 16, 128, 512, 1024, 2048, 4096]
-# QUERY_LENGTHS = [1, 1024]
-# PREFIX_PREFILL_SHARE_OF_DECODE = [0.5]
-# PREFIX_PREFILL_SHARE_OF_DECODE = [1.0]
-# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0]
-PREFIX_PREFILL_SHARE_OF_DECODE = [1.0, 0.3]
-# PREFIX_PREFILL_SHARE_OF_DECODE = [0.8]
-# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0]
+PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0]
+PREFIX_PREFILL_BATCH_COMPOSITION = [BatchComposition.DEC_PRE]
 PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5]
-# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5]
-# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0]
-# HEAD_SIZES_FLASH = [32, 64, 128]  # only powers of 2!
 HEAD_SIZES = [128]  # only powers of 2! for llama2 & 3
 # head_size * head_numbers = hidden_size
-# BLOCK_SIZES = [8, 16, 32]
 BLOCK_SIZES = [16]
 # if torch.version.hip:
 #     BLOCK_SIZES = [16, 128]
-# NUM_BLOCKS = [8, 16, 32]
 NUM_BLOCKS = [4321]  # "arbitrary values for testing..."
 
 # options most likely not used...but keep for now?
 CAUSAL_FLASH = [True]  # vLLM only needs causal=True
 PROMPT_PATTERNS = []
-# PROMPT_PATTERNS.append([1.0])
-# PROMPT_PATTERNS.append([1.0, 0.4, 0.5, 1.0, 0.2])
+PROMPT_PATTERNS.append([1.0])
 PROMPT_PATTERNS.append([0.1, 0.4, 0.5, 1.0, 0.2])
 
-impl_translate = {i.name: i.value for i in Implementation}
-method_translate = {i.name: i.value for i in BenchmarkMode}
-
 IMPLEMENTATION_UT = [
     Implementation.TRITON_2D,
     Implementation.TRITON_3D,
@@ -157,85 +140,113 @@ class BenchmarkMode(Enum):
 MAX_VALUES = [1.0]
 BENCHMARK_MODES = [BenchmarkMode.CUDA_EVENTS, BenchmarkMode.CUDA_GRAPHS]
 
-if os.getenv("NGL_FULL_TEST", "0") == "1":
-    # IMPLEMENTATION_UT = [
-    #     Implementation.VLLM_CUDA_V1,
-    #     Implementation.ZRL_TRITON,
-    #     Implementation.ZRL_TRITON_3D,
-    # ]
-    BENCHMARK_MODES = [
-        BenchmarkMode.CUDA_EVENTS,
-        BenchmarkMode.END2END,
-        BenchmarkMode.CUDA_GRAPHS,
-    ]
-    # SEQUENCE_LENGTHS = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
-    SEQUENCE_LENGTHS = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
-    # BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
-elif os.getenv("NGL_FULL_TEST", "0") == "2":
-    # IMPLEMENTATION_UT = [
-    #     Implementation.VLLM_CUDA_V1,
-    #     Implementation.ZRL_TRITON,
-    #     Implementation.ZRL_TRITON_3D,
-    # ]
-    BENCHMARK_MODES = [
-        BenchmarkMode.CUDA_EVENTS,
-        BenchmarkMode.END2END,
-        BenchmarkMode.CUDA_GRAPHS,
-    ]
-    SEQUENCE_LENGTHS = [32, 44, 54, 64, 511, 512, 513, 648, 912, 1024, 2025, 3030, 4096]
-    # SEQUENCE_LENGTHS = [6321]
-    BATCH_SIZES = [
-        1,
-        2,
-        4,
-        8,
-        16,
-        28,
-        32,
-        54,
-        64,
-        96,
-        128,
-    ]
-    # BATCH_SIZES = [102]
-    MAX_VALUES = [1.0]
+device = "cuda:0"
+gpu_name = get_gpu_label()
+
+do_benchmarks = True
+# do_benchmarks = False
+quantiles = [0.5, 0.2, 0.8]
+# should maybe also be controlled via env variable
+force_dump_dataframes = False
+enforce_numerical_correctness = True
+# enforce_numerical_correctness = False
+do_profiling = False  # will add overhead to kernel runtime measured via CUDA_EVENTS
+store_hatchet = False
+add_triton_dejavu_envs = True
+debug_flag = False
+
+
+test_setup_vars = [
+    "SEEDS",
+    "BATCH_SIZES",
+    "NUM_HEADS",
+    "SEQUENCE_LENGTHS",
+    "PREFIX_PREFILL_SHARE_OF_DECODE",
+    "PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL",  # "PREFIX_PREFILL_BATCH_COMPOSITION",
+    "HEAD_SIZES",
+    "BLOCK_SIZES",
+    "NUM_BLOCKS",
+    "CAUSAL_FLASH",
+    "PROMPT_PATTERNS",
+    "MAX_VALUES",
+]
+# "BENCHMARK_MODES", "IMPLEMENTATION_UT" ]
+debug_env_vars = [
+    "STORE_TEST_RESULT_PATH",
+    "TEST_ALLOW_INCORRECT",
+    "TRITON_BACKEND_DEBUG",
+]
+
+# need to deal with envfile here
+if len(sys.argv) >= 1:
+    envfile_name = None
+    for ca in sys.argv[1:]:
+        if ".conf" in ca:
+            envfile_name = ca
+            break
+    if envfile_name is not None:
+        from dotenv import dotenv_values
+        import json
+
+        envfile_path = os.path.abspath(envfile_name)
+        print(f"\nApplied test config: {envfile_path}")
+        env_setting = dotenv_values(envfile_path)
+        # filter allowed keys, convert all values to lists
+        env_setting_filtered = {
+            k: json.loads(env_setting[k]) for k in test_setup_vars if k in env_setting
+        }
+        # update all
+        globals().update(env_setting_filtered)
+        # fix enums
+        if "DTYPES" in env_setting:
+            sl = json.loads(env_setting["DTYPES"])
+            DTYPES = [dtype_translate[v] for v in sl]
+        if "PREFIX_PREFILL_BATCH_COMPOSITION" in env_setting:
+            sl = json.loads(env_setting["PREFIX_PREFILL_BATCH_COMPOSITION"])
+            PREFIX_PREFILL_BATCH_COMPOSITION = [
+                BatchComposition(batch_comp_translate[v]) for v in sl
+            ]
+        # iut and methods could come here too, or are overwritten below
+        if "IMPLEMENTATION_UT" in env_setting:
+            sl = json.loads(env_setting["IMPLEMENTATION_UT"])
+            IMPLEMENTATION_UT = [Implementation(impl_translate[v]) for v in sl]
+        if "BENCHMARK_MODES" in env_setting:
+            sl = json.loads(env_setting["BENCHMARK_MODES"])
+            BENCHMARK_MODES = [BenchmarkMode(method_translate[v]) for v in sl]
+
+        # set additional flags
+        if "STORE_TEST_RESULT_PATH" in env_setting and STORE_TEST_RESULT_PATH is None:
+            STORE_TEST_RESULT_PATH = env_setting["STORE_TEST_RESULT_PATH"]
+        if (
+            "TEST_ALLOW_INCORRECT" in env_setting
+            and env_setting["TEST_ALLOW_INCORRECT"] == "1"
+        ):
+            enforce_numerical_correctness = False
+        if (
+            "TRITON_BACKEND_DEBUG" in env_setting
+            and env_setting["TRITON_BACKEND_DEBUG"] == "1"
+        ):
+            debug_flag = True
 
 if len(MY_IUT) > 0:
     IMPLEMENTATION_UT = []
     for ci_value in MY_IUT:
         IMPLEMENTATION_UT.append(Implementation(impl_translate[ci_value]))
-if len(MY_MAX_VALUES) > 0:
-    MAX_VALUES = []
-    for cm_value in MY_MAX_VALUES:
-        MAX_VALUES.append(float(cm_value))
 if len(MY_METHODS) > 0:
     BENCHMARK_MODES = []
     for cb_value in MY_METHODS:
         BENCHMARK_MODES.append(BenchmarkMode(method_translate[cb_value]))
 
+# only overwrite the .conf file settings if the environment variable is present!
+if "TEST_ALLOW_INCORRECT" in os.environ:
+    enforce_numerical_correctness = os.environ["TEST_ALLOW_INCORRECT"] == "1"
+if "TRITON_BACKEND_DEBUG" in os.environ:
+    debug_flag = os.environ["TRITON_BACKEND_DEBUG"] == "1"
 
 for varlen_p in PROMPT_PATTERNS:
     for e in varlen_p:
         assert e <= 1.0
 
-device = "cuda:0"
-gpu_name = get_gpu_label()
-
-do_benchmarks = True
-# do_benchmarks = False
-quantiles = [0.5, 0.2, 0.8]
-# should maybe also be controlled via env variable
-force_dump_dataframes = False
-enforce_numerical_correctness = True
-# enforce_numerical_correctness = False
-if os.getenv("TEST_ALLOW_INCORRECT", "0") == "1":
-    enforce_numerical_correctness = False
-do_profiling = False  # will add overhead to kernel runtime measured via CUDA_EVENTS
-store_hatchet = False
-debug_flag = os.getenv("TRITON_BACKEND_DEBUG", "0") == "1"
-add_triton_dejavu_envs = True
-
 
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
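Because the `.conf` file is read with `dotenv_values` and every whitelisted value is decoded with `json.loads`, each right-hand side must be valid JSON, with enum members given by name. A hypothetical example file (key names are the ones handled above; the concrete values are made up):

    # sweep.conf (illustrative)
    BATCH_SIZES=[1, 4, 16]
    SEQUENCE_LENGTHS=[128, 1024]
    DTYPES=["float16", "bfloat16"]
    PREFIX_PREFILL_BATCH_COMPOSITION=["ALTERNATING"]
    IMPLEMENTATION_UT=["TRITON_2D", "TRITON_3D"]
    TEST_ALLOW_INCORRECT="1"

It would be passed as a positional argument ending in `.conf` (e.g. `python scripts/benchmark.py sweep.conf`); the argv filter near the bottom of the script then drops it before the arguments are handed on.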
         [ol - 1 for ol in init_seq_lens[:decode_seqs]]  # init_seq_lens[:decode_seqs]
         + partial_prefill_ctx_lens[decode_seqs : decode_seqs + partial_prefill_seqs]
@@ -1081,6 +1092,24 @@ def test_prefix_vllm_v1_attention(
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     max_seq_len = max(seq_lens)
 
+    # BatchComposition.DEC_PRE (decodes first, then prefills) is the default order
+    if batch_composition == BatchComposition.PRE_DEC:
+        query_lens.reverse()
+        ctx_lens.reverse()
+        seq_lens.reverse()
+    elif batch_composition == BatchComposition.ALTERNATING:  # interleave both ends: [0, n-1, 1, n-2, ...]
+        alorder = []
+        indices_remaining = list(range(len(query_lens)))
+        for i in range(len(query_lens) // 2):
+            alorder.append(i)
+            alorder.append(len(query_lens) - i - 1)
+            indices_remaining.remove(i)
+            indices_remaining.remove(len(query_lens) - i - 1)
+        alorder.extend(indices_remaining)
+        query_lens = [query_lens[i] for i in alorder]
+        ctx_lens = [ctx_lens[i] for i in alorder]
+        seq_lens = [seq_lens[i] for i in alorder]
+
     if debug_flag:
         print(
             f"decode share: {decode_share}; prefill share {1-decode_share} -> of that: partial prefill share {partial_prefill_share}"
@@ -1093,8 +1122,8 @@ def test_prefix_vllm_v1_attention(
         print("partial_prefill_ctx_lens", partial_prefill_ctx_lens)
         print(f"\nAfter assembling the final batch:")
         print(f"\tquery_lens: {query_lens}")
-        print(f"\tctx_lens: {ctx_lens}")
-        print(f"\tseq_lens: {seq_lens}")
+        print(f"\t  ctx_lens: {ctx_lens}")
+        print(f"\t  seq_lens: {seq_lens}")
     assert len(ctx_lens) == len(query_lens)
     if not realistic_prompt_mode:
         # assert max_seq_len == seqlen or max_seq_len == seqlen + 1
@@ -1156,7 +1185,7 @@ def test_prefix_vllm_v1_attention(
         torch.tensor([0] + query_lens, dtype=torch.int), dim=0, dtype=torch.int
     )
     b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int), dim=0, dtype=torch.int
+        torch.tensor([0] + seq_lens, dtype=torch.int), dim=0, dtype=torch.int
     )
 
     # Create the block tables.
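For clarity, the ALTERNATING composition introduced above permutes the batch as [0, n-1, 1, n-2, ...], pairing each decode from the front of the batch with a prefill from the back. A self-contained sketch of the same index interleaving (illustrative, not part of the patch):

    def alternating_order(n: int) -> list[int]:
        # pair indices from both ends: first, last, second, second-to-last, ...
        order, remaining = [], list(range(n))
        for i in range(n // 2):
            order += [i, n - 1 - i]
            remaining.remove(i)
            remaining.remove(n - 1 - i)
        return order + remaining  # for odd n, the middle index is appended last

    assert alternating_order(5) == [0, 4, 1, 3, 2]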
@@ -1220,6 +1249,15 @@ def test_prefix_vllm_v1_attention(
         scale,
         dtype,
     )
+    # ref_output = ref_paged_attn(
+    #     query,
+    #     key_cache,
+    #     value_cache,
+    #     b_query_lens,
+    #     b_ctx_lens,
+    #     block_table_t,
+    #     scale,
+    # )
 
     if implementation == Implementation.FLASH_ATTN:
         from callers import FlashAttnPrefixPrefillCaller as Caller
@@ -1569,6 +1607,9 @@ def write_df_and_chmod(df, filename, mode=0o777):
     args = [__file__]
     filter_args = ""
     for ca in sys.argv[1:]:
+        if ".conf" in ca:
+            # already processed
+            continue
         if ca[0] == "-":
             args.append(ca)
         else:
diff --git a/scripts/callers/flash_attn.py b/scripts/callers/flash_attn.py
index 19419033c..6a63d778d 100644
--- a/scripts/callers/flash_attn.py
+++ b/scripts/callers/flash_attn.py
@@ -166,27 +166,51 @@ def make_call_func(
         max_query_len = query_lens.max()
         max_seqlen = seq_lens.max()
 
-        def call_and_process_output():
-            # k must have shape (num_blocks, page_block_size, num_heads_k, head_size)
-            return flash_attn_varlen_func(
-                q=query,
-                k=key_cache,
-                v=value_cache,
-                out=output,
-                cu_seqlens_q=start_loc,
-                max_seqlen_q=max_query_len,
-                seqused_k=seq_lens,
-                max_seqlen_k=max_seqlen,
-                softmax_scale=softmax_scale,
-                causal=True,
-                block_table=block_tables,
-                # window_size=(-1, 1),
-                # softcap=0,
-                # fa_version=2, # TODO
-            )
+        if torch.version.hip:
+
+            def call_and_process_output():
+                # k must have shape (num_blocks, page_block_size, num_heads_k, head_size)
+                return flash_attn_varlen_func(
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=start_loc,
+                    cu_seqlens_k=seq_start_loc,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seqlen,
+                    softmax_scale=softmax_scale,
+                    causal=True,
+                    block_table=block_tables,
+                    # window_size=(-1, 1),
+                    # softcap=0,
+                    # fa_version=2, # TODO
+                )
+
+        else:
+
+            def call_and_process_output():
+                # k must have shape (num_blocks, page_block_size, num_heads_k, head_size)
+                return flash_attn_varlen_func(
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    out=output,
+                    cu_seqlens_q=start_loc,
+                    max_seqlen_q=max_query_len,
+                    seqused_k=seq_lens,
+                    max_seqlen_k=max_seqlen,
+                    softmax_scale=softmax_scale,
+                    causal=True,
+                    block_table=block_tables,
+                    # window_size=(-1, 1),
+                    # softcap=0,
+                    # fa_version=2, # TODO
+                )
 
         return call_and_process_output
 
     @staticmethod
     def requires_allocated_output() -> bool:
+        if torch.version.hip:
+            return False
         return True
diff --git a/scripts/roofline/proton_viewer.py b/scripts/roofline/proton_viewer.py
index a583d806e..ff6c7815d 100644
--- a/scripts/roofline/proton_viewer.py
+++ b/scripts/roofline/proton_viewer.py
@@ -263,7 +263,7 @@ def get_time_seconds(df):
             )
         else:
             raise RuntimeError(
-                f"Metric {orig_metric} has unkown unit {unit}, cannot be derived (derivable metrics: {derivable_metrics})."
+                f"Metric {orig_metric} has unknown unit {unit}, cannot be derived (derivable metrics: {derivable_metrics})."
             )
     except IndexError:
         # we don't have a unit
@@ -469,7 +469,7 @@ def main():
 - flop[<8/16/32/64>]/s, gflop[<8/16/32/64>]/s, tflop[<8/16/32/64>]/s: flops / time
 - byte/s, gbyte/s, tbyte/s: bytes / time
 - util: max(sum(flops) / peak_flops_time, sum(bytes) / peak_bandwidth_time)
-- /%%: frame(metric) / sum(metric). Only availble for inclusive metrics (e.g. time)
+- /%%: frame(metric) / sum(metric). Only available for inclusive metrics (e.g. time)
 """,
     )
     argparser.add_argument(
diff --git a/scripts/setups/default.conf b/scripts/setups/default.conf
new file mode 100644
index 000000000..804ac752a
--- /dev/null
+++ b/scripts/setups/default.conf
@@ -0,0 +1,23 @@
+DTYPES = ["float16"]
+SEEDS = [0]
+BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
+# order: num_query_heads, num_kv_heads
+NUM_HEADS = [[32, 8]]
+
+SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096]
+PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0]
+PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5]
+PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"]
+
+HEAD_SIZES = [128]  # only powers of 2! for llama2 & 3
+# head_size * num_heads = hidden_size
+
+BLOCK_SIZES = [16]
+NUM_BLOCKS = [4321]  # "arbitrary values for testing..."
+
+PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]]
+
+MAX_VALUES = [1.0]
+BENCHMARK_MODES = ["CUDA_EVENTS"]
+
+IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_3D", "UNF_TRITON_2D"]
diff --git a/scripts/setups/prefix_correctness.conf b/scripts/setups/prefix_correctness.conf
new file mode 100644
index 000000000..3ee2b2da9
--- /dev/null
+++ b/scripts/setups/prefix_correctness.conf
@@ -0,0 +1,32 @@
+# BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
+BATCH_SIZES = [4]
+# order: num_query_heads, num_kv_heads
+NUM_HEADS = [[32, 8]]
+
+# SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096]
+# SEQUENCE_LENGTHS = [64]
+SEQUENCE_LENGTHS = [1024]
+# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0]
+PREFIX_PREFILL_SHARE_OF_DECODE = [0.5]
+# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5]
+PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5]
+PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"]
+
+HEAD_SIZES = [128]  # only powers of 2! for llama2 & 3
+# head_size * num_heads = hidden_size
+
+BLOCK_SIZES = [16]
+NUM_BLOCKS = [4321]  # "arbitrary values for testing..."
+
+# PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]]
+PROMPT_PATTERNS = [[1.0]]
+
+MAX_VALUES = [1.0]
+BENCHMARK_MODES = ["CUDA_EVENTS"]
+
+# IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_3D", "UNF_TRITON_2D"]
+# IMPLEMENTATION_UT = ["UNF_TRITON_3D", "UNF_TRITON_2D", "UNF_TRITON_AUTO"]
+# IMPLEMENTATION_UT = ["UNF_TRITON_2D"]
+IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"]
+
+TRITON_BACKEND_DEBUG = 1
diff --git a/scripts/setups/prefix_correctness_rocm.conf b/scripts/setups/prefix_correctness_rocm.conf
new file mode 100644
index 000000000..c0a9ca8cf
--- /dev/null
+++ b/scripts/setups/prefix_correctness_rocm.conf
@@ -0,0 +1,32 @@
+# BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
+BATCH_SIZES = [4]
+# order: num_query_heads, num_kv_heads
+NUM_HEADS = [[32, 8]]
+
+# SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096]
+# SEQUENCE_LENGTHS = [64]
+SEQUENCE_LENGTHS = [1024]
+# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0]
+PREFIX_PREFILL_SHARE_OF_DECODE = [0.5]
+# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5]
+PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5]
+PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"]
+
+HEAD_SIZES = [128]  # only powers of 2! for llama2 & 3
+# head_size * num_heads = hidden_size
+
+BLOCK_SIZES = [128]
+NUM_BLOCKS = [4321]  # "arbitrary values for testing..."
+
+# PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]]
+PROMPT_PATTERNS = [[1.0]]
+
+MAX_VALUES = [1.0]
+BENCHMARK_MODES = ["CUDA_EVENTS"]
+
+# IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_3D", "UNF_TRITON_2D"]
+# IMPLEMENTATION_UT = ["UNF_TRITON_3D", "UNF_TRITON_2D", "UNF_TRITON_AUTO"]
+# IMPLEMENTATION_UT = ["UNF_TRITON_2D"]
+IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"]
+
+TRITON_BACKEND_DEBUG = 1
diff --git a/scripts/vllm_utils.py b/scripts/vllm_utils.py
index 3b7bb9aea..27fba2690 100644
--- a/scripts/vllm_utils.py
+++ b/scripts/vllm_utils.py
@@ -313,7 +313,7 @@ def ref_prefix_prefill(
             ref_outputs.append(out)
         else:
             # prefix prefill
-            # construct continous context
+            # construct continuous context
             block_table = block_tables_lst[i]
             seq_len = int(seq_lens_lst[i])
             ctx_len = int(ctx_lens_lst[i])
@@ -427,3 +427,62 @@ def ref_reshape_and_cache(
         block_offset = block_offsets_lst[i]
         key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         value_cache[block_idx, :, :, block_offset] = value[i]
+
+
+# from https://github.com/vllm-project/vllm/blob/main/tests/kernels/attention/test_triton_unified_attention.py
+def ref_paged_attn(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    query_lens: list[int],
+    kv_lens: list[int],
+    block_tables: torch.Tensor,
+    scale: float,
+    sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    num_seqs = len(query_lens)
+    block_tables = block_tables.cpu().numpy()
+    _, block_size, num_kv_heads, head_size = key_cache.shape
+
+    outputs: list[torch.Tensor] = []
+    start_idx = 0
+    for i in range(num_seqs):
+        query_len = query_lens[i]
+        kv_len = kv_lens[i]
+        q = query[start_idx : start_idx + query_len]
+        q *= scale  # note: in-place, mutates the caller's query tensor (each row is touched only once)
+
+        num_kv_blocks = (kv_len + block_size - 1) // block_size
+        block_indices = block_tables[i, :num_kv_blocks]
+
+        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
+        k = k[:kv_len]
+        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
+        v = v[:kv_len]
+
+        if q.shape[1] != k.shape[1]:
+            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
+            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
+        attn = torch.einsum("qhd,khd->hqk", q, k).float()
+        empty_mask = torch.ones(query_len, kv_len)
+        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
+        if sliding_window is not None:
+            sliding_window_mask = (
+                torch.triu(
+                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
+                )
+                .bool()
+                .logical_not()
+            )
+            mask |= sliding_window_mask
+        if soft_cap is not None and soft_cap > 0:
+            attn = soft_cap * torch.tanh(attn / soft_cap)
+        attn.masked_fill_(mask, float("-inf"))
+        attn = torch.softmax(attn, dim=-1).to(v.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn, v)
+
+        outputs.append(out)
+        start_idx += query_len
+
+    return torch.cat(outputs, dim=0)
diff --git a/triton-dejavu b/triton-dejavu
index e97223599..dfb8afd13 160000
--- a/triton-dejavu
+++ b/triton-dejavu
@@ -1 +1 @@
-Subproject commit e97223599b0f8b8118198087cce8456c16bfa770
+Subproject commit dfb8afd1311c12eb55f0e0401d2f4e24a1ac7a58
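A closing note on the ref_paged_attn reference above: the triu-based mask places the query_len new tokens at the end of the kv_len sequence, so query token i may attend keys up to absolute position kv_len - query_len + i. A tiny standalone check of that construction (illustrative sizes only, not part of the patch):

    import torch

    query_len, kv_len = 2, 4
    empty_mask = torch.ones(query_len, kv_len)
    # True marks masked-out (future) keys, exactly as in ref_paged_attn
    mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
    # query token 0 sits at absolute position 2, so only key 3 is masked:
    # tensor([[False, False, False,  True],
    #         [False, False, False, False]])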