The aim of this repository is to implement bi-directional linear attention for non-causal modeling using Triton.
This project is currently maintained by an individual and remains a work in progress. As the maintainer is still in the early stages of learning Triton, many implementations may not be optimal. Contributions and suggestions are welcome!
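For reference, the computation that a non-causal (bidirectional) linear attention kernel targets can be written in a few lines of plain PyTorch. This is a minimal sketch, not FBi-LA's implementation. It assumes the elu(x) + 1 feature map from Katharopoulos et al. (2020) and inputs of shape (B, H, T, D), and the function names are hypothetical. Because there is no causal mask, all keys and values are summarized once into a D×D matrix, so the cost grows linearly in the sequence length T instead of quadratically.

```python
import torch
import torch.nn.functional as F


def feature_map(x: torch.Tensor) -> torch.Tensor:
    # Assumed kernel: elu(x) + 1 keeps every feature strictly positive.
    return F.elu(x) + 1.0


def bidirectional_linear_attention(q, k, v, eps: float = 1e-6):
    """Non-causal linear attention; q, k, v have shape (B, H, T, D)."""
    q, k = feature_map(q), feature_map(k)
    # Summarize all keys/values once: shape (B, H, D, D_v), cost O(T * D * D_v).
    kv = torch.einsum("bhtd,bhte->bhde", k, v)
    # Normalizer: sum of mapped keys over the whole sequence, shape (B, H, D).
    k_sum = k.sum(dim=2)
    num = torch.einsum("bhtd,bhde->bhte", q, kv)
    den = torch.einsum("bhtd,bhd->bht", q, k_sum).unsqueeze(-1)
    return num / (den + eps)


def quadratic_reference(q, k, v, eps: float = 1e-6):
    """Same result via the explicit T x T attention matrix, O(T^2)."""
    q, k = feature_map(q), feature_map(k)
    attn = q @ k.transpose(-1, -2)
    attn = attn / (attn.sum(dim=-1, keepdim=True) + eps)
    return attn @ v


# Usage: both formulations agree up to floating-point rounding.
torch.manual_seed(0)
B, H, T, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
out = bidirectional_linear_attention(q, k, v)
ref = quadratic_reference(q, k, v)
print(out.shape, torch.allclose(out, ref, atol=1e-5, rtol=1e-4))
```

Computing `kv` before multiplying by `q` is the reassociation that makes linear attention linear in T; a fused kernel avoids materializing the intermediate tensors.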
Models are roughly sorted by the date they were supported in FBi-LA:
Date | Model | Title | Paper | Code | FBi-LA impl |
---|---|---|---|---|---|
2024-11 | LinFusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arXiv | official | code |
2024-11 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arXiv | official | code |
2024-11 | Focused-LA | FLatten Transformer: Vision Transformer using Focused Linear Attention | arXiv | official | code |
More models will be implemented gradually.
P.S.: The current implementation of MLLA is relatively basic and will be updated soon.
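Focused-LA (FLatten Transformer) replaces softmax with a "focused" kernel function. As described in the FLatten paper, φ_p(x) = f_p(ReLU(x)) with f_p(x) = (‖x‖ / ‖x^p‖) · x^p, where x^p is an element-wise power: it sharpens each feature vector toward its dominant dimensions while preserving its norm. The PyTorch sketch below illustrates only this function. It is not FBi-LA's kernel, the name `focused_feature_map` is hypothetical, p = 3 is an assumed default, and the paper's depthwise-convolution branch for restoring feature rank is omitted.

```python
import torch
import torch.nn.functional as F


def focused_feature_map(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    """phi_p(x) = f_p(ReLU(x)), f_p(x) = (||x|| / ||x**p||) * x**p over the last dim."""
    x = F.relu(x)
    # Element-wise power amplifies large features relative to small ones.
    x_p = x ** p
    # Rescale so the output keeps the norm of ReLU(x); eps guards all-zero rows.
    scale = x.norm(dim=-1, keepdim=True) / (x_p.norm(dim=-1, keepdim=True) + eps)
    return scale * x_p


# Usage: a single 3-dim feature vector; the negative entry is zeroed by ReLU.
x = torch.tensor([[1.0, 2.0, -1.0]])
phi = focused_feature_map(x)
print(phi, phi.norm(dim=-1))
```

Here the ratio between the two positive features grows from 2 to 2³ = 8, while the vector's norm stays at √5, the norm of ReLU(x).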
To install from source in editable mode:

```bash
git clone https://github.com/hp-l33/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.
```
The library integrates several models that can be used directly. For example, LinFusion can be applied to a diffusers text-to-image pipeline:
```python
import torch
from diffusers import AutoPipelineForText2Image
from fbi_la.models import LinFusion

sd_repo = "Lykon/dreamshaper-8"

# Load the text-to-image pipeline in half precision on the GPU.
pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

# Attach LinFusion's linear-attention modules to the pipeline.
linfusion = LinFusion.construct_for(pipeline)

# Generate an image with a fixed seed for reproducibility.
image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123),
).images[0]
```
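Benchmarked on an NVIDIA A800 80GB GPU with batch size 8, 16 heads, and head dimension 64 (B8-H16-D64). T is the sequence length; each remaining column gives the forward (fwd) or backward (bwd) time of the PyTorch or Triton implementation: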
T | torch_fwd | triton_fwd | torch_bwd | triton_bwd |
---|---|---|---|---|
128 | 0.063488 | 0.049152 | 0.520192 | 0.651264 |
256 | 0.080896 | 0.056320 | 0.795648 | 0.599040 |
512 | 0.111616 | 0.070656 | 1.074176 | 1.065984 |
1024 | 0.169984 | 0.101376 | 1.014784 | 0.746496 |
2048 | 0.300032 | 0.165888 | 1.464320 | 1.364992 |
4096 | 0.532480 | 0.287744 | 2.741248 | 2.564096 |
8192 | 1.005568 | 0.521216 | 5.232128 | 4.940800 |
16384 | 1.924608 | 0.980992 | 10.235904 | 9.695744 |
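To read the benchmark table as relative gains, divide each PyTorch time by the matching Triton time. The short script below does this with the values copied from the B8-H16-D64 table; the dictionary and function names are illustrative only. A ratio above 1 means the Triton kernel is faster. The forward pass is faster at every length (about 1.29x at T = 128, rising to about 1.96x at T = 16384), while the backward pass is slower only at T = 128 (about 0.80x).

```python
# Timings copied from the B8-H16-D64 table:
# T -> (torch_fwd, triton_fwd, torch_bwd, triton_bwd)
BENCH_B8_H16_D64 = {
    128: (0.063488, 0.049152, 0.520192, 0.651264),
    256: (0.080896, 0.056320, 0.795648, 0.599040),
    512: (0.111616, 0.070656, 1.074176, 1.065984),
    1024: (0.169984, 0.101376, 1.014784, 0.746496),
    2048: (0.300032, 0.165888, 1.464320, 1.364992),
    4096: (0.532480, 0.287744, 2.741248, 2.564096),
    8192: (1.005568, 0.521216, 5.232128, 4.940800),
    16384: (1.924608, 0.980992, 10.235904, 9.695744),
}


def speedups(rows):
    """Return {T: (forward speedup, backward speedup)} of Triton over PyTorch."""
    return {
        t: (torch_fwd / triton_fwd, torch_bwd / triton_bwd)
        for t, (torch_fwd, triton_fwd, torch_bwd, triton_bwd) in rows.items()
    }


ratios = speedups(BENCH_B8_H16_D64)
for t, (fwd, bwd) in ratios.items():
    print(f"T={t:>5}: forward {fwd:.2f}x, backward {bwd:.2f}x")
```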
TODO:
- improve memory efficiency during backpropagation
- replace the `torch.sum()` and `torch.mean()` operations
- implement more models
  - VSSD
  - RALA
Thanks to the following repositories for their inspiration: