diff --git a/README.md b/README.md
index 81cf1ba..5e66d63 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ _Now, you can process **1M context 10x faster in a single A100** using Long-cont
 
 ## 📰 News
 - 🐝 [25/05/02] MMInference has been accepted at **ICML'25**.
-- 👨‍💻‍ [25/04/14] [SGLang](https://github.com/sgl-project/sglang/pull/5327) and [vLLM](https://github.com/vllm-project/flash-attention/pull/33) have merged the MInference sparse attention kernel. Notably, SGLang also adapted it for FlashAttention-3. Special thanks to @zhyncs and @yinfan98 for their contributions!
+- 👨‍💻‍ [25/04/14] [SGLang](https://github.com/sgl-project/sglang/pull/5327) and [vLLM](https://github.com/vllm-project/flash-attention/pull/33) have merged the MInference sparse attention kernel. _MInference already supports these optimized kernels._ Just run `pip install sglang` to enable them. You can achieve up to **1.64× (64K), 2.4× (96K), 2.9× (128K), 5.2× (256K), 8× (512K), and 15× (1M)** speedups. Notably, SGLang also adapted it for FlashAttention-3. Special thanks to @zhyncs and @yinfan98 for their contributions!
 - 👾 [25/04/23] We are excited to announce the release of our multi-modality work, [MMInference](https://aka.ms/2504.16083), which uses **modality-aware permutation sparse attention** to accelerate long-context VLMs. We'll present MMInference at **Microsoft Booth** and **FW-Wild at ICLR'25**. See you in Singapore!
 - 🤗 [25/01/27] MInference has been integrated into [Qwen2.5-1M](https://qwenlm.github.io/blog/qwen2.5-1m/) and online services. For details, refer to the [paper](https://arxiv.org/abs/2501.15383) and the [vLLM implementation](https://github.com/vllm-project/vllm/pull/11844).
 - 🪸 [25/01/23] SCBench has been accepted at **ICLR'25**.
@@ -117,6 +117,19 @@ Currently, we support the following LLMs:
 
 ### How to use MInference
 
+> [!TIP]
+> To benefit from the fast kernel implementations, we recommend installing **SGLang** or **vLLM**.
+> for SGLang,
+> ```bash
+> uv pip install "sglang[all]>=0.4.6.post4"
+> ```
+>
+> for vLLM,
+> ```bash
+> uv pip install "vllm>=0.9.0"
+> uv pip install git+https://github.com/vllm-project/flash-attention
+> ```
+
 for HF,
 ```diff
 from transformers import pipeline
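Since the tip above only covers installation, a small standalone snippet can confirm which backend will actually be picked up at runtime. This is only an illustration (the helper name `detect_sparse_attn_backend` is not part of MInference); the import paths and fallback order mirror the ones this patch adds to `minference/ops/pit_sparse_flash_attention_v2.py`:

```python
# Probe the optimized sparse-attention kernels in the same order as the patched module:
# sgl_kernel (SGLang) first, then vllm_flash_attn (vLLM), else the Triton fallback.
def detect_sparse_attn_backend() -> str:
    try:
        from sgl_kernel.sparse_flash_attn import sparse_attn_func  # noqa: F401
        return "sgl_kernel (SGLang)"
    except ImportError:
        pass
    try:
        from vllm_flash_attn import sparse_attn_func  # noqa: F401
        return "vllm_flash_attn (vLLM)"
    except ImportError:
        return "none found; MInference will fall back to its Triton kernel"


if __name__ == "__main__":
    print(detect_sparse_attn_backend())
```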
diff --git a/experiments/README.md b/experiments/README.md
index 1598f3a..7a1c4b5 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -42,6 +42,19 @@ Environment parameters:
 
 To demonstrate the efficiency of our method, we conducted end-to-end latency tests using the [LLaMA-3-8B-Instruct-1M](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) model. The prompts were trimmed to different target token numbers, and we measured the pre-filling stage latency without using KV cache.
 
+> [!TIP]
+> To benefit from the fast kernel implementations, we recommend installing **SGLang** or **vLLM**.
+> for SGLang,
+> ```bash
+> uv pip install "sglang[all]>=0.4.6.post4"
+> ```
+>
+> for vLLM,
+> ```bash
+> uv pip install "vllm>=0.9.0"
+> uv pip install git+https://github.com/vllm-project/flash-attention
+> ```
+
 1. Download the prompt:
 
 ```bash
@@ -67,15 +80,15 @@ python experiments/benchmarks/benchmark_e2e.py --run_benchmark
 
 4. After that, you should get the end-to-end latency results (in seconds) like this:
 ```json
-       FlashAttention-2  A-Shape    InfLLM     MInference
-1K     0.54565           1.07110    2.94495    2.96450
-10K    0.97590           1.18339    2.21052    2.77618
-50K    8.52933           5.47972    14.63624   7.54537
-100K   24.88319          10.86379   27.67215   13.98508
-200K   79.39184          21.61490   55.64703   26.81303
-300K   169.62441         32.44844   80.74326   41.09374
-500K   456.78353         54.15910   167.91472  66.27691
-1000K  1765.56387        107.85639  328.58551  179.12031
+       FlashAttention-2  A-Shape    InfLLM     MInference  MInference w/ SGLang
+1K     0.54565           1.07110    2.94495    2.96450     1.25478
+10K    0.97590           1.18339    2.21052    2.77618     2.43798
+50K    8.52933           5.47972    14.63624   7.54537     6.19598
+100K   24.88319          10.86379   27.67215   13.98508    10.81580
+200K   79.39184          21.61490   55.64703   26.81303    20.62303
+300K   169.62441         32.44844   80.74326   41.09374    31.01629
+500K   456.78353         54.15910   167.91472  66.27691    51.96293
+1000K  1765.56387        107.85639  328.58551  179.12031   112.37610
 ```
 
 > [!TIP]
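For reference, the headline speedups quoted in the README news entry ("8× (512K)", "15× (1M)") appear to be ratios against the FlashAttention-2 column. A quick check against the 500K and 1000K rows of the table above (standalone Python, numbers copied verbatim from the table):

```python
# Speedups implied by the new "MInference w/ SGLang" column, using the 500K and 1000K rows above.
flash_attn_2 = {"500K": 456.78353, "1000K": 1765.56387}
minference_triton = {"500K": 66.27691, "1000K": 179.12031}
minference_sglang = {"500K": 51.96293, "1000K": 112.37610}

for n in ("500K", "1000K"):
    vs_fa2 = flash_attn_2[n] / minference_sglang[n]          # ~8.8x at 500K, ~15.7x at 1000K
    vs_triton = minference_triton[n] / minference_sglang[n]  # ~1.3x at 500K, ~1.6x at 1000K
    print(f"{n}: {vs_fa2:.1f}x vs FlashAttention-2, {vs_triton:.1f}x vs the Triton MInference kernel")
```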
diff --git a/minference/ops/pit_sparse_flash_attention_v2.py b/minference/ops/pit_sparse_flash_attention_v2.py
index 825a6b0..5c716d0 100644
--- a/minference/ops/pit_sparse_flash_attention_v2.py
+++ b/minference/ops/pit_sparse_flash_attention_v2.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024 Microsoft
+# Copyright (c) 2024-2025 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
 import math
@@ -9,6 +9,26 @@
 from ..cuda import convert_vertical_slash_indexes
 
+try:
+    from sgl_kernel.sparse_flash_attn import sparse_attn_func
+except ImportError:
+    try:
+        from vllm_flash_attn import sparse_attn_func
+    except ImportError:
+        print("To benefit from fast kernel implementations, we recommend installing SGLang or vLLM.")
+        sparse_attn_func = None
+
+try:
+    from sgl_kernel.sparse_flash_attn import (
+        convert_vertical_slash_indexes as convert_vertical_slash_indexes_opt,
+    )
+except ImportError:
+    try:
+        from vllm._custom_ops import (
+            convert_vertical_slash_indexes as convert_vertical_slash_indexes_opt,
+        )
+    except ImportError:
+        convert_vertical_slash_indexes_opt = None
 
 # @triton.autotune(
 #     configs=[
@@ -181,8 +201,10 @@ def vertical_slash_sparse_attention(
     block_size_M: int = 64,
     block_size_N: int = 64,
 ):
+    if convert_vertical_slash_indexes_opt is not None and sparse_attn_func is not None:
+        return vertical_slash_sparse_attention_wo_pad(query, key, value, v_idx, s_idx)
     batch_size, num_heads, context_size, head_dim = query.shape
-    pad = block_size_M - (context_size & (block_size_M - 1))
+    pad = (block_size_M - context_size) & (block_size_M - 1)
     query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
     key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
     value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
@@ -200,9 +222,49 @@
     block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
         seqlens, v_idx, s_idx, context_size, block_size_M, block_size_N,
     )
-    out = _triton_mixed_sparse_attention(
-        query, key, value, seqlens,
-        block_count, block_offset, column_count, column_index,
-        sm_scale, block_size_M, block_size_N,
-    )
+
+    if sparse_attn_func is not None:
+        out = sparse_attn_func(
+            query.transpose(1, 2).contiguous(),
+            key.transpose(1, 2).contiguous(),
+            value.transpose(1, 2).contiguous(),
+            block_count, block_offset, column_count, column_index,
+            return_softmax_lse=False,
+            causal=True,
+        ).transpose(1, 2).contiguous()
+    else:
+        out = _triton_mixed_sparse_attention(
+            query, key, value, seqlens,
+            block_count, block_offset, column_count, column_index,
+            sm_scale, block_size_M, block_size_N,
+        )
+
     return out[..., :context_size, :head_dim]
+
+def vertical_slash_sparse_attention_wo_pad(query, key, value, v_idx, s_idx, block_size_M: int = 64, block_size_N: int = 64):
+    batch_size, num_heads, context_size, head_dim = query.shape
+    seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
+    block_count, block_offset, column_count, column_index = (
+        convert_vertical_slash_indexes_opt(
+            seqlens,
+            seqlens,
+            v_idx.to(torch.int32),
+            s_idx.to(torch.int32),
+            context_size,
+            block_size_M,
+            block_size_N,
+            causal=True,
+        )
+    )
+    out = sparse_attn_func(
+        query.transpose(1, 2).contiguous(),
+        key.transpose(1, 2).contiguous(),
+        value.transpose(1, 2).contiguous(),
+        block_count,
+        block_offset,
+        column_count,
+        column_index,
+        causal=True,
+        return_softmax_lse=False,
+    )
+    return out.transpose(1, 2).contiguous()
diff --git a/minference/ops/xattention_fa.py b/minference/ops/xattention_fa.py
index 67ffb30..549d3fd 100644
--- a/minference/ops/xattention_fa.py
+++ b/minference/ops/xattention_fa.py
@@ -246,7 +246,7 @@ def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size
     try:
         assert k_len % segment_size == 0
     except:
-        breakpoint()
+        raise AssertionError(f"xAttention error, k_len: {k_len}, segment size: {segment_size}")
     assert segment_size % reshaped_block_size == 0
     assert attn_weights_slice.stride(-1) == 1
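Finally, a quick way to sanity-check the padding change in the `vertical_slash_sparse_attention` hunk above: the old expression returns a full extra block (`block_size_M`) whenever `context_size` is already a multiple of `block_size_M`, whereas the new expression returns 0 in that case and the same positive pad otherwise. A standalone check in plain Python (no MInference or GPU required):

```python
# Compare the old and new pad formulas from vertical_slash_sparse_attention.
block_size_M = 64

def pad_old(context_size: int) -> int:
    # formula before this patch
    return block_size_M - (context_size & (block_size_M - 1))

def pad_new(context_size: int) -> int:
    # formula after this patch
    return (block_size_M - context_size) & (block_size_M - 1)

for context_size in (1, 63, 64, 100, 128, 1000):
    print(context_size, pad_old(context_size), pad_new(context_size))
# Output differs only at multiples of 64 (64 -> 64 vs 0, 128 -> 64 vs 0):
# the old formula padded an unnecessary extra block when the length was already aligned.
```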