Flash Bi-directional Linear Attention

The aim of this repository is to implement bi-directional linear attention for non-causal modeling using Triton.

This project is currently maintained by an individual and remains a work in progress. As the maintainer is still in the early stages of learning Triton, many implementations may not be optimal. Contributions and suggestions are welcome!
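
As background, here is what these kernels compute: non-causal linear attention replaces softmax with a feature map φ, so the K–V summary can be formed once over the whole sequence and reused for every query, giving cost linear in sequence length. The sketch below is a minimal PyTorch reference, not this repo's Triton kernels, and assumes φ(x) = elu(x) + 1 (the exact feature map varies per model):

```python
import torch
import torch.nn.functional as F

def bidirectional_linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, seq_len, head_dim); no causal mask is applied
    q, k = F.elu(q) + 1, F.elu(k) + 1                   # positive feature map
    kv = torch.einsum("bhnd,bhne->bhde", k, v)          # global K^T V summary
    z = torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2))  # normalizer Q (K^T 1)
    return torch.einsum("bhnd,bhde->bhne", q, kv) / (z.unsqueeze(-1) + eps)
```

Because the K^T V summary is a (head_dim × head_dim) matrix shared by all queries, the whole sequence is attended to in both directions without quadratic cost.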

Models

Models are roughly sorted by the date support was added in FBi-LA.

| Date | Model | Title | Paper | Code | FBi-LA impl |
| --- | --- | --- | --- | --- | --- |
| 2024-11 | LinFusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arXiv | official | code |
| 2024-11 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arXiv | official | code |
| 2024-11 | Focused-LA | FLatten Transformer: Vision Transformer using Focused Linear Attention | arXiv | official | code |

More models will be implemented gradually.

P.S.: The current implementation of MLLA is relatively basic and will be updated soon.

Usage

Installation

```bash
git clone https://github.com/hp-l33/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/
```

Integrated Models

This library integrates several models that can be called directly. Taking LinFusion as an example:

```python
import torch
from diffusers import AutoPipelineForText2Image
from fbi_la.models import LinFusion

sd_repo = "Lykon/dreamshaper-8"

# Load a Stable Diffusion pipeline in fp16 on the GPU
pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

# Patch the pipeline's attention with LinFusion's linear attention modules
linfusion = LinFusion.construct_for(pipeline)

image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123)
).images[0]
```
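
The pipeline returns a standard PIL image, so it can be saved as usual (the filename here is arbitrary):

```python
image.save("astronaut.png")
```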

Benchmarks

Tested on an A800 80GB GPU with batch size 8, 16 heads, and head dimension 64 (B8-H16-D64). T is the sequence length; times are in milliseconds.
```
         T  torch_fwd  triton_fwd  torch_bwd  triton_bwd
0    128.0   0.063488    0.049152   0.520192    0.651264
1    256.0   0.080896    0.056320   0.795648    0.599040
2    512.0   0.111616    0.070656   1.074176    1.065984
3   1024.0   0.169984    0.101376   1.014784    0.746496
4   2048.0   0.300032    0.165888   1.464320    1.364992
5   4096.0   0.532480    0.287744   2.741248    2.564096
6   8192.0   1.005568    0.521216   5.232128    4.940800
7  16384.0   1.924608    0.980992  10.235904    9.695744
```
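
Forward timings like these are typically collected with Triton's built-in benchmarking helper. The sketch below uses `triton.testing.do_bench`; the import path `fbi_la.ops.linear_attn` is a placeholder assumption, not necessarily this repo's actual API:

```python
import torch
from triton.testing import do_bench

# Assumed import path, for illustration only; check the repo for the real one.
from fbi_la.ops import linear_attn

B, H, D = 8, 16, 64  # the B8-H16-D64 configuration above

for T in [128, 256, 512, 1024, 2048, 4096, 8192, 16384]:
    q, k, v = (torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
               for _ in range(3))
    ms = do_bench(lambda: linear_attn(q, k, v))  # median runtime in ms
    print(f"T={T:7.1f}  triton_fwd={ms:.6f}")
```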

TODO

- Improve memory efficiency during backpropagation
- Replace `torch.sum()` and `torch.mean()` operations
- Implement more models:
  - VSSD
  - RALA

Acknowledgments

Thanks to the repositories that inspired this project.
