We are continuously adding new features. Star and Watch the repository to stay updated.
This repository provides the official implementation of SageAttention and SageAttention2.
SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
Paper: https://arxiv.org/abs/2410.02367
Jintao Zhang, Jia Wei, Haofeng Huang, Pengle Zhang, Jun Zhu, Jianfei Chen
SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization
Paper: https://arxiv.org/abs/2411.10958
Jintao Zhang, Haofeng Huang, Pengle Zhang, Jia Wei, Jun Zhu, Jianfei Chen
This is a beta release of SageAttention2. We welcome any feedback on accuracy, performance issues, bugs, feature requests, or suggestions. Please feel free to open an issue or submit a pull request!
Current Features:
- INT8 quantization for $QK^\top$ with support for varying granularities
- FP8 quantization for $PV$
- FP32 buffer for $PV$ to improve accuracy in FP8 MMA
- Support for `torch.compile` with non-cudagraphs mode and distributed inference
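To make "quantization granularity" concrete, here is a minimal pure-Python sketch of symmetric INT8 quantization. It contrasts one scale shared by a whole tensor with one scale per token (row), then approximates a single $QK^\top$ entry from INT8 values. The function names are hypothetical, and this is an illustration of the general idea rather than the kernels shipped in this repository.

```python
# Sketch only: symmetric INT8 quantization in pure Python.
# Helper names are hypothetical and not part of the sageattention package.

def int8_scale(values):
    """Scale that maps the largest magnitude in `values` to 127."""
    max_abs = max(abs(v) for v in values)
    return max_abs / 127.0 if max_abs > 0 else 1.0


def quantize(values, scale):
    """Round to the nearest integer and clamp to the INT8 range."""
    return [max(-127, min(127, round(v / scale))) for v in values]


def max_error(values, scale):
    """Largest absolute error after a quantize -> dequantize round trip."""
    restored = [x * scale for x in quantize(values, scale)]
    return max(abs(a - b) for a, b in zip(values, restored))


def int8_dot(q_row, k_row):
    """Approximate dot(q_row, k_row): integer accumulate, then rescale once."""
    q_scale, k_scale = int8_scale(q_row), int8_scale(k_row)
    acc = sum(a * b for a, b in zip(quantize(q_row, q_scale),
                                    quantize(k_row, k_scale)))
    return acc * q_scale * k_scale


# Two tokens: the second one contains outliers that dominate a shared scale.
q = [[0.10, -0.20, 0.05, 0.15],
     [50.0, -30.0, 20.0, 10.0]]

per_tensor_scale = int8_scale([v for row in q for v in row])
per_token_scale = int8_scale(q[0])

coarse_error = max_error(q[0], per_tensor_scale)  # scale shared with outliers
fine_error = max_error(q[0], per_token_scale)     # scale owned by this token
print(f"per-tensor error={coarse_error:.4f}, per-token error={fine_error:.5f}")

k_row = [0.80, 0.10, -0.25, 0.40]
exact = sum(a * b for a, b in zip(q[0], k_row))
print(f"exact={exact:.4f}, int8 approx={int8_dot(q[0], k_row):.4f}")
```

A finer granularity keeps small-magnitude tokens from being rounded away by an outlier elsewhere in the tensor, at the cost of storing more scales.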
For a stable version, please use the SageAttention-1 branch.
- News [2024-12-20]: Updated the SageAttention2 paper.
- News [2024-12-20]: We are excited to announce the release of SageAttention 2.0.1 Beta! In this version, we introduce a new feature: per-thread quantization, which offers finer granularity while maintaining hardware efficiency.
- News [2024-11-21]: SageAttention 2.0.0 beta is released! SageAttention now has measured speedups on L20, L40, A100, A800, and A6000, in addition to RTX3090 and RTX4090.
- News [2024-11-12]: Support for `sageattn_varlen` is available now.
- News [2024-11-11]: Support for different sequence lengths between `q` and `k,v`, `(batch_size, head_num, seq_len, head_dim)` or `(batch_size, seq_len, head_num, head_dim)` input shapes, and group-query attention is available now.
python>=3.9
torch>=2.3.0
triton>=3.0.0
CUDA>=12.4 if you want to use FP8, else CUDA>=12.0
flash-attn for benchmarking
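The minimum versions above can be checked before compiling. The following is a small sketch, not a script shipped with this repository: `parse_version`, `meets_minimum`, and `check_environment` are hypothetical helpers, and the example version strings are made up for illustration.

```python
import re

# Minimum versions copied from the requirement list above.
MINIMUMS = {"python": "3.9", "torch": "2.3.0", "triton": "3.0.0"}


def parse_version(text):
    """Turn '2.3.0+cu121' into (2, 3, 0); any non-numeric suffix is ignored."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """Compare component-wise, padding the shorter version with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment(installed, use_fp8):
    """Return a list of problems; an empty list means every check passed."""
    required = dict(MINIMUMS, cuda="12.4" if use_fp8 else "12.0")
    problems = []
    for name, minimum in required.items():
        found = installed.get(name)
        if found is None:
            problems.append(f"{name} not found (need >= {minimum})")
        elif not meets_minimum(found, minimum):
            problems.append(f"{name} {found} is older than {minimum}")
    return problems


# Made-up example environment, not a tested configuration.
example = {"python": "3.10.12", "torch": "2.3.0+cu121",
           "triton": "3.0.0", "cuda": "12.1"}
print(check_environment(example, use_fp8=False))
print(check_environment(example, use_fp8=True))
```

In a real environment, the installed versions could be gathered from `platform.python_version()`, `torch.__version__`, `triton.__version__`, and `torch.version.cuda`.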
For the stable version or Triton-only version, refer to SageAttention-1 and install using pip:
pip install sageattention==1.0.6
To use SageAttention 2.0.1, please compile from source:
git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention
python setup.py install # or pip install -e .
Note: Currently, SageAttention is optimized for excellent performance on RTX4090, RTX3090, L20, and L40 GPUs. On A100, A800, and A6000 GPUs, performance is best with `head_dim=128`, while `head_dim=64` is less optimal. Similarly, performance on the Hopper architecture is currently not optimal. We are actively working to enhance performance in these configurations.
from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
`q, k, v` are FP16/BF16 dtype with the shape `(batch_size, head_num, seq_len, head_dim)` using default `tensor_layout="HND"`. For shape `(batch_size, seq_len, head_num, head_dim)`, set `tensor_layout="NHD"`. `is_causal` determines the use of a causal mask.
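The two layouts differ only in the order of the head and sequence axes. The sketch below names the axes of a shape tuple for each `tensor_layout` value; `describe_shape` and `to_hnd_shape` are hypothetical helpers written for illustration and are not part of the `sageattention` package.

```python
# Hypothetical helpers: they only name and reorder axes of a shape tuple.
LAYOUT_AXES = {
    "HND": ("batch_size", "head_num", "seq_len", "head_dim"),
    "NHD": ("batch_size", "seq_len", "head_num", "head_dim"),
}


def describe_shape(shape, tensor_layout="HND"):
    """Map a 4-D shape tuple to named axes for the given tensor_layout."""
    if tensor_layout not in LAYOUT_AXES:
        raise ValueError(f"unknown tensor_layout: {tensor_layout!r}")
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D shape, got {len(shape)} dims")
    return dict(zip(LAYOUT_AXES[tensor_layout], shape))


def to_hnd_shape(shape, tensor_layout):
    """Shape the same tensor would have in the HND layout."""
    named = describe_shape(shape, tensor_layout)
    return tuple(named[axis] for axis in LAYOUT_AXES["HND"])


print(describe_shape((2, 24, 4096, 128), "HND"))
print(to_hnd_shape((2, 4096, 24, 128), "NHD"))  # (2, 24, 4096, 128)
```

For a PyTorch tensor, moving between the two layouts corresponds to `x.transpose(1, 2)`; passing the matching `tensor_layout` instead avoids that extra step.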
- `sageattn`: Automatically selects the optimal kernel based on the GPU to achieve a good performance-accuracy trade-off.
- `sageattn_qk_int8_pv_fp16_triton`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using Triton backend.
- `sageattn_qk_int8_pv_fp16_cuda`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using CUDA backend.
- `sageattn_qk_int8_pv_fp8_cuda`: INT8 quantization for $QK^\top$ and FP8 for $PV$ using CUDA backend.
- `sageattn_varlen`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using Triton backend. Support for varying sequence lengths within the same batch.
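Variable-length attention kernels commonly take sequences packed along one token dimension, together with cumulative offsets that mark where each sequence starts and ends. The sketch below computes such offsets in pure Python. Whether `sageattn_varlen` expects exactly this representation is an assumption here, so check its docstring before relying on it; the helper names are hypothetical.

```python
from itertools import accumulate


def cumulative_seqlens(seq_lens):
    """Offsets into a packed token dimension: [0, l0, l0 + l1, ...]."""
    if any(length <= 0 for length in seq_lens):
        raise ValueError("sequence lengths must be positive")
    return [0, *accumulate(seq_lens)]


def sequence_slices(seq_lens):
    """(start, end) token range of each sequence in the packed batch."""
    offsets = cumulative_seqlens(seq_lens)
    return list(zip(offsets[:-1], offsets[1:]))


seq_lens = [5, 3, 8]
print(cumulative_seqlens(seq_lens))  # [0, 5, 8, 16]
print(sequence_slices(seq_lens))     # [(0, 5), (5, 8), (8, 16)]
print(max(seq_lens))                 # 8, the longest sequence in the batch
```

Packing avoids padding every sequence to the longest one, which is where the savings for mixed-length batches come from.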
For optimal speed and accuracy on custom devices and models, we strongly recommend referring to this file for detailed guidance.
Note: Support for different sequence lengths between `q` and `k,v` and group-query attention is available.
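In group-query attention, several query heads share one key/value head. The sketch below shows the usual grouping convention, where consecutive query heads map to the same key/value head. It is a pure-Python illustration with a hypothetical helper name, and it assumes the consecutive-grouping convention rather than describing this library's internals.

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Index of the key/value head shared by query head `q_head`."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    if not 0 <= q_head < num_q_heads:
        raise IndexError("q_head out of range")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# 8 query heads sharing 2 key/value heads: groups of 4.
mapping = [kv_head_for(h, num_q_heads=8, num_kv_heads=2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With `num_kv_heads == num_q_heads` this reduces to ordinary multi-head attention, and with `num_kv_heads == 1` to multi-query attention.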
We can replace `scaled_dot_product_attention` easily.
We will take CogVideoX as an example:
Add the following code and run:
from sageattention import sageattn
import torch.nn.functional as F
F.scaled_dot_product_attention = sageattn
Specifically,
cd example
python sageattn_cogvideo.py --compile
You can get a lossless video in `./example` faster than by using `python original_cogvideo.py --compile`.
Note: Not all models work with `F.scaled_dot_product_attention = sageattn`. Technically, you should replace the original attention by modifying the `Attention` class of the target model. For image and video models, we suggest only replacing the attention in DiT.
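One reason a global replacement can fail is that a model may pass arguments the replacement does not handle. A more defensive variant is sketched here: a wrapper that uses the fast kernel only for plain calls and falls back to the original function otherwise. `with_fallback` is a hypothetical helper, and the rule "fall back when a mask, dropout, or any extra keyword is given" is a conservative assumption, not documented behavior of `sageattn`. Stand-in functions are used so the sketch runs without a GPU.

```python
import functools


def with_fallback(original, fast):
    """Build an SDPA-like callable that uses `fast` only for plain calls."""

    @functools.wraps(original)
    def patched(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, **kwargs):
        # Assumption: the fast kernel does not handle masks, dropout, or
        # other keywords, so those calls go to the original implementation.
        if attn_mask is not None or dropout_p != 0.0 or kwargs:
            return original(query, key, value, attn_mask=attn_mask,
                            dropout_p=dropout_p, is_causal=is_causal, **kwargs)
        return fast(query, key, value, is_causal=is_causal)

    return patched


# Stand-ins for F.scaled_dot_product_attention and sageattn.
def reference_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                        is_causal=False, **kwargs):
    return "reference"


def fast_attention(query, key, value, is_causal=False):
    return "fast"


attention = with_fallback(reference_attention, fast_attention)
print(attention("q", "k", "v"))                     # fast
print(attention("q", "k", "v", is_causal=True))     # fast
print(attention("q", "k", "v", attn_mask=[[True]])) # reference
print(attention("q", "k", "v", scale=0.5))          # reference
```

With PyTorch, this would be applied as `F.scaled_dot_product_attention = with_fallback(F.scaled_dot_product_attention, sageattn)`. Modifying the model's attention class directly, as recommended above, remains the more reliable route.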
`8+8` means the kernel with INT8 quantization for $QK^\top$ and FP8 quantization for $PV$. `8+16` uses FP16 with FP16 accumulator for $PV$.
Note: The TOPS results refer only to the Attention Kernel, excluding the quantization and smoothing.
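For reference, attention throughput is commonly computed from the operation count of the two matrix multiplications, $QK^\top$ and $PV$, divided by the kernel time. The sketch below uses that common convention; whether this repository's benchmark scripts use exactly the same formula is an assumption, and the timing value is a placeholder rather than a measurement.

```python
def attention_ops(batch_size, head_num, seq_len, head_dim, is_causal=False):
    """Operations in QK^T and PV: two matmuls, 2 ops per multiply-add."""
    ops = 4 * batch_size * head_num * seq_len * seq_len * head_dim
    # A causal mask skips roughly half of the score matrix.
    return ops // 2 if is_causal else ops


def tops(ops, seconds):
    """Tera-operations per second for a given kernel time."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return ops / seconds / 1e12


ops = attention_ops(batch_size=4, head_num=32, seq_len=8192, head_dim=128)
placeholder_seconds = 0.02  # placeholder value, NOT a real measurement
print(f"ops={ops}")
print(f"throughput={tops(ops, placeholder_seconds):.1f} TOPS (illustrative)")
```

Because quantization and smoothing are excluded from the reported numbers, end-to-end speedups in a model will generally be smaller than the kernel-only figures suggest.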
If you use this code or find our work valuable, please cite:
@misc{zhang2024sageattention,
title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration},
author={Jintao Zhang and Jia Wei and Haofeng Huang and Pengle Zhang and Jun Zhu and Jianfei Chen},
year={2024},
eprint={2410.02367},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.02367},
}
@misc{zhang2024sageattention2,
title={SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization},
author={Jintao Zhang and Haofeng Huang and Pengle Zhang and Jia Wei and Jun Zhu and Jianfei Chen},
year={2024},
eprint={2411.10958},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2411.10958},
}