15 changes: 14 additions & 1 deletion README.md
@@ -20,7 +20,7 @@ _Now, you can process **1M context 10x faster in a single A100** using Long-cont

## 📰 News
- 🐝 [25/05/02] MMInference has been accepted at **ICML'25**.
- 👨‍💻‍ [25/04/14] [SGLang](https://github.com/sgl-project/sglang/pull/5327) and [vLLM](https://github.com/vllm-project/flash-attention/pull/33) have merged the MInference sparse attention kernel. Notably, SGLang also adapted it for FlashAttention-3. Special thanks to @zhyncs and @yinfan98 for their contributions!
- 👨‍💻‍ [25/04/14] [SGLang](https://github.com/sgl-project/sglang/pull/5327) and [vLLM](https://github.com/vllm-project/flash-attention/pull/33) have merged the MInference sparse attention kernel. _MInference already supports the optimized kernels._ Just try `pip install sglang`. You can achieve up to **1.64× (64K), 2.4× (96K), 2.9× (128K), 5.2× (256K), 8× (512K), and 15× (1M)** speedup. Notably, SGLang also adapted it for FlashAttention-3. Special thanks to @zhyncs and @yinfan98 for their contributions!
- 👾 [25/04/23] We are excited to announce the release of our multi-modality work, [MMInference](https://aka.ms/2504.16083), which uses **modality-aware permutation sparse attention** to accelerate long-context VLMs. We'll present MMInference at **Microsoft Booth** and **FW-Wild at ICLR'25**. See you in Singapore!
- 🤗 [25/01/27] MInference has been integrated into [Qwen2.5-1M](https://qwenlm.github.io/blog/qwen2.5-1m/) and online services. For details, refer to the [paper](https://arxiv.org/abs/2501.15383) and the [vLLM implementation](https://github.com/vllm-project/vllm/pull/11844).
- 🪸 [25/01/23] SCBench has been accepted at **ICLR'25**.
@@ -117,6 +117,19 @@ Currently, we support the following LLMs:

### How to use MInference

> [!TIP]
> To benefit from fast kernel implementations, we recommend installing **SGLang** or **vLLM**.
>
> For SGLang:
> ```bash
> uv pip install "sglang[all]>=0.4.6.post4"
> ```
>
> For vLLM:
> ```bash
> uv pip install "vllm>=0.9.0"
> uv pip install git+https://github.com/vllm-project/flash-attention
> ```

for HF,
```diff
from transformers import pipeline
```
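For reference, a minimal sketch of the HF usage that the truncated diff above points at. It assumes the `MInference` patching entry point exported by the package and one of the supported long-context models; treat the model name and arguments as illustrative rather than authoritative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference  # patching entry point assumed from the package

model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="cuda")

# Patch the HF model so long-context pre-filling uses MInference's sparse attention kernels
# (and, when installed, the SGLang/vLLM fused kernels referenced in the tip above).
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

inputs = tokenizer("Summarize the following document: ...", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```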
31 changes: 22 additions & 9 deletions experiments/README.md
@@ -42,6 +42,19 @@ Environment parameters:

To demonstrate the efficiency of our method, we conducted end-to-end latency tests using the [LLaMA-3-8B-Instruct-1M](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) model. The prompts were trimmed to different target token numbers, and we measured the pre-filling stage latency without using KV cache.
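The sketch below illustrates what this measurement amounts to; it is not the repository's `benchmark_e2e.py`. The model name comes from the link above, everything else is illustrative: trim the tokenized prompt to the target length and time a single forward pass with the KV cache disabled.

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda")

def prefill_latency_seconds(prompt: str, target_tokens: int) -> float:
    # Trim the tokenized prompt to the target context length.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids[:, :target_tokens].to("cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        model(input_ids, use_cache=False)  # pre-filling only, no KV cache reuse
    torch.cuda.synchronize()
    return time.perf_counter() - start
```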

> [!TIP]
> To benefit from fast kernel implementations, we recommend installing **SGLang** or **vLLM**.
>
> For SGLang:
> ```bash
> uv pip install "sglang[all]>=0.4.6.post4"
> ```
>
> For vLLM:
> ```bash
> uv pip install "vllm>=0.9.0"
> uv pip install git+https://github.com/vllm-project/flash-attention
> ```

1. Download the prompt:

@@ -67,15 +80,15 @@ python experiments/benchmarks/benchmark_e2e.py --run_benchmark
4. After that, you should get the end-to-end latency results like this:

```json
         FlashAttention-2    A-Shape      InfLLM       MInference
1K       0.54565             1.07110      2.94495      2.96450
10K      0.97590             1.18339      2.21052      2.77618
50K      8.52933             5.47972      14.63624     7.54537
100K     24.88319            10.86379     27.67215     13.98508
200K     79.39184            21.61490     55.64703     26.81303
300K     169.62441           32.44844     80.74326     41.09374
500K     456.78353           54.15910     167.91472    66.27691
1000K    1765.56387          107.85639    328.58551    179.12031
         FlashAttention-2    A-Shape      InfLLM       MInference    MInference w/ SGLang
1K       0.54565             1.07110      2.94495      2.96450       1.25478
10K      0.97590             1.18339      2.21052      2.77618       2.43798
50K      8.52933             5.47972      14.63624     7.54537       6.19598
100K     24.88319            10.86379     27.67215     13.98508      10.81580
200K     79.39184            21.61490     55.64703     26.81303      20.62303
300K     169.62441           32.44844     80.74326     41.09374      31.01629
500K     456.78353           54.15910     167.91472    66.27691      51.96293
1000K    1765.56387          107.85639    328.58551    179.12031     112.37610
```
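As a quick cross-check, a trivial sketch (not part of the benchmark script) showing that the 1M-token row is consistent with the roughly 15× figure quoted in the README news entry, comparing FlashAttention-2 against MInference with the SGLang kernel:

```python
# Latency values taken from the 1000K row of the table above (units as reported by the benchmark).
fa2_1m = 1765.56387
minference_sglang_1m = 112.37610

print(f"speedup at 1M tokens: {fa2_1m / minference_sglang_1m:.1f}x")  # -> 15.7x
```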

> [!TIP]
76 changes: 69 additions & 7 deletions minference/ops/pit_sparse_flash_attention_v2.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024 Microsoft
# Copyright (c) 2024-2025 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math
@@ -9,6 +9,26 @@

from ..cuda import convert_vertical_slash_indexes

# Prefer SGLang's fused sparse-attention kernel, then vLLM's flash-attention build,
# and fall back to the Triton kernels defined below when neither is installed.
try:
    from sgl_kernel.sparse_flash_attn import sparse_attn_func
except:
    try:
        from vllm_flash_attn import sparse_attn_func
    except:
        print("To benefit from fast kernel implementations, we recommend installing SGLang or vllm.")
        sparse_attn_func = None

try:
    from sgl_kernel.sparse_flash_attn import (
        convert_vertical_slash_indexes as convert_vertical_slash_indexes_opt,
    )
except:
    try:
        from vllm._custom_ops import (
            convert_vertical_slash_indexes as convert_vertical_slash_indexes_opt,
        )
    except:
        convert_vertical_slash_indexes_opt = None

# @triton.autotune(
# configs=[
@@ -181,8 +201,10 @@ def vertical_slash_sparse_attention(
    block_size_M: int = 64,
    block_size_N: int = 64,
):
    # Use the padding-free path when the optimized index-conversion kernel is available.
    if convert_vertical_slash_indexes_opt is not None:
        return vertical_slash_sparse_attention_wo_pad(query, key, value, v_idx, s_idx)
    batch_size, num_heads, context_size, head_dim = query.shape
**Copilot AI** (May 26, 2025): The new pad computation uses a bitwise operation that assumes `block_size_M` is a power of two; adding a clarifying comment or an explicit check could improve code clarity.

Suggested change:

```python
    batch_size, num_heads, context_size, head_dim = query.shape
    # Ensure block_size_M is a power of two, as required for the bitwise operation below.
    if block_size_M & (block_size_M - 1) != 0 or block_size_M <= 0:
        raise ValueError("block_size_M must be a power of two and greater than zero.")
    # Compute padding size. The bitwise operation assumes block_size_M is a power of two.
```
    pad = block_size_M - (context_size & (block_size_M - 1))
    pad = (block_size_M - context_size) & (block_size_M - 1)
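    # Why the formula changed (illustrative, assuming block_size_M is a power of two, e.g. 64):
    #   context_size = 130 -> (64 - 130) & 63 = 62, padding the sequence up to 192
    #   context_size = 128 -> (64 - 128) & 63 = 0, no padding needed
    # The previous expression returned a full extra block (64) in the already-aligned case.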
    query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
    key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
    value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
@@ -200,9 +222,49 @@ def vertical_slash_sparse_attention(
    block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
        seqlens, v_idx, s_idx, context_size, block_size_M, block_size_N,
    )
    out = _triton_mixed_sparse_attention(
        query, key, value, seqlens,
        block_count, block_offset, column_count, column_index,
        sm_scale, block_size_M, block_size_N,
    )

    if sparse_attn_func is not None:
        # The fused SGLang/vLLM kernel expects [batch, seq, heads, head_dim] inputs,
        # hence the transposes around the call.
        out = sparse_attn_func(
            query.transpose(1, 2).contiguous(),
            key.transpose(1, 2).contiguous(),
            value.transpose(1, 2).contiguous(),
            block_count, block_offset, column_count, column_index,
            return_softmax_lse=False,
            causal=True,
        ).transpose(1, 2).contiguous()
    else:
        # Fall back to MInference's Triton kernel.
        out = _triton_mixed_sparse_attention(
            query, key, value, seqlens,
            block_count, block_offset, column_count, column_index,
            sm_scale, block_size_M, block_size_N,
        )

    return out[..., :context_size, :head_dim]

def vertical_slash_sparse_attention_wo_pad(query, key, value, v_idx, s_idx, block_size_M: int = 64, block_size_N: int = 64):
    # Padding-free path: the optimized index-conversion kernel handles non-block-aligned
    # sequence lengths, so inputs are passed through without padding to block_size_M.
    batch_size, num_heads, context_size, head_dim = query.shape
    seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes_opt(
            seqlens,
            seqlens,
            v_idx.to(torch.int32),
            s_idx.to(torch.int32),
            context_size,
            block_size_M,
            block_size_N,
            causal=True,
        )
    )
    out = sparse_attn_func(
        query.transpose(1, 2).contiguous(),
        key.transpose(1, 2).contiguous(),
        value.transpose(1, 2).contiguous(),
        block_count,
        block_offset,
        column_count,
        column_index,
        causal=True,
        return_softmax_lse=False,
    )
    return out.transpose(1, 2).contiguous()
2 changes: 1 addition & 1 deletion minference/ops/xattention_fa.py
@@ -246,7 +246,7 @@ def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size
    try:
        assert k_len % segment_size == 0
    except:
        breakpoint()
        assert False, f"xAttention error, k_len: {k_len}, segment size: {segment_size}"
Comment on lines 248 to +249

**Copilot AI** (May 26, 2025): Using 'assert False' for error handling may be less informative in production; consider raising a specific exception (e.g., RuntimeError) with the same message.
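A self-contained sketch of that suggestion (the helper name `check_divisible` is hypothetical; the merged change keeps the assert):

```python
def check_divisible(k_len: int, segment_size: int) -> None:
    # Raising an explicit exception (instead of `assert False`) survives `python -O`,
    # which strips asserts, and keeps the diagnostic message in production logs.
    if k_len % segment_size != 0:
        raise RuntimeError(f"xAttention error, k_len: {k_len}, segment size: {segment_size}")
```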
assert segment_size % reshaped_block_size == 0
assert attn_weights_slice.stride(-1) == 1
