A set of implementations of the Differential Transformer paper [1] using PyTorch's Scaled Dot Product Attention (SDPA) instead of the implementations provided there:
- Basic manual PyTorch
- Flash Attention 2
    - Custom kernel to handle differing `head_dim` more efficiently
    - Original kernel that is more optimized on the same `head_dim`
This implementation has four variations as of now (the underlying differential attention computation is sketched right after this list):
- Following the original Flash Attention 2 implementation more closely
- Following the custom Flash Attention 2 implementation more closely
- One forward pass for the attention calculations (transferable to the original Flash Attention 2 implementation)
- One forward pass for the attention calculations based on [2] (utilizing SDPA's support for differing `head_dim`)
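All four variations compute the differential attention from [1]: two softmax attention maps over separate query/key halves are subtracted with a weight λ (learned and re-parameterized per layer in the paper). Below is a minimal two-pass sketch on top of SDPA, assuming illustrative tensor names and a fixed λ; this is not the package's actual code:

```python
import torch
import torch.nn.functional as F

def diff_attn_two_passes(q1, q2, k1, k2, v, lam):
    # q1, q2, k1, k2: (bsz, num_heads, seq_len, head_dim)
    # v:              (bsz, num_heads, seq_len, 2 * head_dim) -- SDPA allows the
    #                 value head_dim (Ev) to differ from the query/key head_dim (E)
    # lam: fixed scalar here; learned and re-parameterized per layer in the paper
    a1 = F.scaled_dot_product_attention(q1, k1, v)
    a2 = F.scaled_dot_product_attention(q2, k2, v)
    return a1 - lam * a2

bsz, num_heads, seq_len, head_dim = 2, 6, 3, 64
q1, q2, k1, k2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, num_heads, seq_len, 2 * head_dim)
out = diff_attn_two_passes(q1, q2, k1, k2, v, lam=0.8)
assert out.shape == (bsz, num_heads, seq_len, 2 * head_dim)
```

The one-pass variations fold both SDPA calls above into a single call; the two possible layouts for that are sketched after the notes below.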
Notes:
- RoPE is optional as I only cared about equivalency first and foremost
- Needs external, proper handling of RoPE and attention masks
- It really needs benchmarks to see which variant works better, especially regarding the two one-pass versions (both layouts are sketched after this list):
    - Same `head_dim`, more `num_heads`, but concatenating and chunking/unbinding
    - Different `head_dim`, fewer `num_heads`, but possibly less utilization on the original Flash Attention 2
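A rough sketch of the two one-pass layouts referred to above; shapes, stacking order, and the fixed λ are assumptions for illustration and may differ from the repo's actual chunking/unbinding:

```python
import torch
import torch.nn.functional as F

def one_pass_same_head_dim(q1, q2, k1, k2, v1, v2, lam):
    # Same head_dim everywhere, more heads: stack every needed q/k/v combination
    # along the head axis, run a single SDPA call, then chunk and re-concatenate.
    q = torch.cat([q1, q1, q2, q2], dim=1)
    k = torch.cat([k1, k1, k2, k2], dim=1)
    v = torch.cat([v1, v2, v1, v2], dim=1)
    o11, o12, o21, o22 = F.scaled_dot_product_attention(q, k, v).chunk(4, dim=1)
    a1 = torch.cat([o11, o12], dim=-1)  # q1/k1 map applied to the full value width
    a2 = torch.cat([o21, o22], dim=-1)  # q2/k2 map applied to the full value width
    return a1 - lam * a2

def one_pass_diff_head_dim(q1, q2, k1, k2, v1, v2, lam):
    # Differing head_dim, fewer heads: q/k keep head_dim d while the value is
    # twice as wide, relying on SDPA accepting a value dim Ev != E.
    q = torch.cat([q1, q2], dim=1)
    k = torch.cat([k1, k2], dim=1)
    wide_v = torch.cat([v1, v2], dim=-1)    # (bsz, num_heads, seq_len, 2 * head_dim)
    v = torch.cat([wide_v, wide_v], dim=1)  # repeated for both query groups
    a1, a2 = F.scaled_dot_product_attention(q, k, v).chunk(2, dim=1)
    return a1 - lam * a2

# both layouts produce the same differential attention output
bsz, num_heads, seq_len, head_dim = 2, 6, 5, 64
q1, q2, k1, k2, v1, v2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(6))
torch.testing.assert_close(
    one_pass_same_head_dim(q1, q2, k1, k2, v1, v2, 0.8),
    one_pass_diff_head_dim(q1, q2, k1, k2, v1, v2, 0.8),
)
```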
I won't distribute a PyPI package, but you can use it as a package by cloning the repo and installing it at the root:
```bash
git clone https://github.com/vasqu/multihead-sdpadiff.git
cd multihead-sdpadiff
pip install .
```
```python
import torch

from multihead_sdpadiff import (
    MultiheadSdpaDiff1,  # multiple attn passes
    MultiheadSdpaDiff2,  # two attn passes
    MultiheadSdpaDiff3,  # one attn pass (v1)
    MultiheadSdpaDiff4,  # one attn pass (v2)
)

# some shape values
bsz = 2
seq_len = 3
depth = 12
embed_dim = 768
num_heads = 12  # this will be set to half as we double them for the diff

# random input
x = torch.randn(size=(bsz, seq_len, embed_dim))

# choose an implementation
#sdpa_mha_diff = MultiheadSdpaDiff1(embed_dim, depth, num_heads, num_heads)
#sdpa_mha_diff = MultiheadSdpaDiff2(embed_dim, depth, num_heads, num_heads)
#sdpa_mha_diff = MultiheadSdpaDiff3(embed_dim, depth, num_heads, num_heads)
sdpa_mha_diff = MultiheadSdpaDiff4(embed_dim, depth, num_heads, num_heads)

# pass and check
res = sdpa_mha_diff(x)
assert res.shape == x.shape
```
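Since the module maps `(bsz, seq_len, embed_dim)` back to the same shape, it can stand in for a regular multi-head attention layer. A hypothetical pre-norm residual wrapper, reusing the constructor and call signature from the example above (the wrapper itself is not part of the package):

```python
import torch
import torch.nn as nn
from multihead_sdpadiff import MultiheadSdpaDiff4

class DiffAttnBlock(nn.Module):
    # hypothetical wrapper: pre-norm residual block around the diff attention module
    def __init__(self, embed_dim, depth, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = MultiheadSdpaDiff4(embed_dim, depth, num_heads, num_heads)

    def forward(self, x):
        return x + self.attn(self.norm(x))

block = DiffAttnBlock(embed_dim=768, depth=12, num_heads=12)
x = torch.randn(2, 3, 768)
assert block(x).shape == x.shape
```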
TODOs:
- Make it a package structure
- Benchmark the speed/memory between the implementations (a possible starting point is sketched below)
- Transformer style RoPE + Attn Mask
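For the benchmarking TODO, a hypothetical starting point using `torch.utils.benchmark`, reusing the constructor signature from the usage example above (timings only; peak memory would additionally need e.g. `torch.cuda.max_memory_allocated` on GPU):

```python
import torch
import torch.utils.benchmark as benchmark

from multihead_sdpadiff import (
    MultiheadSdpaDiff1,
    MultiheadSdpaDiff2,
    MultiheadSdpaDiff3,
    MultiheadSdpaDiff4,
)

bsz, seq_len, depth, embed_dim, num_heads = 2, 512, 12, 768, 12
x = torch.randn(bsz, seq_len, embed_dim)

for cls in (MultiheadSdpaDiff1, MultiheadSdpaDiff2, MultiheadSdpaDiff3, MultiheadSdpaDiff4):
    module = cls(embed_dim, depth, num_heads, num_heads)
    timer = benchmark.Timer(stmt="module(x)", globals={"module": module, "x": x})
    print(cls.__name__, timer.timeit(10))  # wall-clock timing per implementation
```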
[1]
@misc{ye2024differentialtransformer,
title={Differential Transformer},
author={Tianzhu Ye and Li Dong and Yuqing Xia and Yutao Sun and Yi Zhu and Gao Huang and Furu Wei},
year={2024},
eprint={2410.05258},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2410.05258},
}
[2] Thanks to MarktHart for providing another version, which might be the most optimized one.