Commit
Up to 2x speedup on GPUs using memory efficient attention (open-mmlab#532)

* 2x speedup using memory efficient attention

* remove einops dependency

* Swap K, M in op instantiation

* Simplify code, remove unnecessary maybe_init call and function, remove unused self.scale parameter

* make xformers a soft dependency

* remove one-liner functions

* change one-letter variables to appropriate names

* Remove Env variable dependency, remove MemoryEfficientCrossAttention class and use enable_xformers_memory_efficient_attention method

* Add memory efficient attention toggle to img2img and inpaint pipelines

* Clearer management of xformers' availability

* update optimizations markdown to add info about memory efficient attention

* add benchmarks for TITAN RTX

* More detailed explanation of how the mem eff benchmarks were run

* Removing autocast from optimization markdown

* import_utils: import torch only if it is available

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
MatthieuToulemont and NouamaneTazi authored Nov 2, 2022
1 parent a793b1f commit 98c4213
Showing 8 changed files with 183 additions and 4 deletions.
39 changes: 39 additions & 0 deletions docs/source/optimization/fp16.mdx
@@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
| fp16 | 3.61s | x2.63 |
| channels last | 3.30s | x2.88 |
| traced UNet | 3.21s | x2.96 |
| memory efficient attention | 2.63s | x3.61 |

<em>
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
@@ -290,3 +291,41 @@ pipe.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```


## Memory Efficient Attention
Recent work on optimizing the bandwidth in the attention block has generated huge speedups and gains in GPU memory usage. The most recent of these is Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)).
Here are the speedups we obtain on a few NVIDIA GPUs when running the inference at 512x512 with a batch size of 1 (one prompt):

| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------ |--------------------- |--------------------------------- |
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
| NVIDIA A10G | 8.88it/s | 15.6it/s |
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
| A100-SXM4-40GB | 18.6it/s | 29it/s |
| A100-SXM-80GB | 18.7it/s | 29.5it/s |

To leverage it, just make sure you have:
- PyTorch >= 1.12
- CUDA available
- Installed the [xformers](https://github.com/facebookresearch/xformers) library
```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
sample = pipe("a small cat")

# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```
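
For context on the benchmark table above, the iteration-per-second numbers correspond to 512x512 inference with a batch size of 1 in fp16. Below is a rough sketch of how such a measurement could be reproduced; it is an illustration rather than the exact script used for the table, and the step count is an assumption:

```python
# Hedged sketch: time the denoising loop and report iterations per second.
import time

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

num_inference_steps = 50
with torch.inference_mode():
    pipe("a small cat", num_inference_steps=num_inference_steps)  # warmup
    torch.cuda.synchronize()
    start = time.time()
    pipe("a small cat", num_inference_steps=num_inference_steps)
    torch.cuda.synchronize()
    elapsed = time.time() - start

print(f"{num_inference_steps / elapsed:.2f} it/s")
```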
55 changes: 51 additions & 4 deletions src/diffusers/models/attention.py
@@ -18,6 +18,15 @@
import torch.nn.functional as F
from torch import nn

from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


class AttentionBlock(nn.Module):
"""
@@ -150,6 +159,10 @@ def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, weight = hidden_states.shape
@@ -206,6 +219,32 @@ def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
@@ -239,6 +278,7 @@ def __init__(
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -279,11 +319,13 @@ def forward(self, hidden_states, context=None, mask=None):
# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of

if self._use_memory_efficient_attention_xformers:
    hidden_states = self._memory_efficient_attention_xformers(query, key, value)
else:
    if self._slice_size is None or query.shape[0] // self._slice_size == 1:
        hidden_states = self._attention(query, key, value)
    else:
        hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

# linear proj
hidden_states = self.to_out[0](hidden_states)
@@ -341,6 +383,11 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_xformers(self, query, key, value):
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class FeedForward(nn.Module):
r"""
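
For reference, here is a minimal standalone sketch of the xformers call wrapped by `_memory_efficient_attention_xformers` above. The tensor sizes are hypothetical; the only assumption taken from the diff is the 3D layout with attention heads folded into the batch dimension, as implied by `reshape_batch_dim_to_heads`:

```python
# Hedged sketch, not part of the commit: call xformers' memory efficient attention
# with (batch * heads, sequence_length, head_dim) tensors.
import torch
import xformers.ops

batch, heads, seq_len, head_dim = 1, 8, 4096, 40  # hypothetical sizes
query = torch.randn(batch * heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
key = torch.randn(batch * heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
value = torch.randn(batch * heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

# Same call as in the diff; the output has the same shape as the query.
out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
```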
12 changes: 12 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -367,6 +367,10 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -542,6 +546,10 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()

@@ -1117,6 +1125,10 @@ def set_attention_slice(self, slice_size):

self.gradient_checkpointing = False

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(
self,
hidden_states,
11 changes: 11 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -225,6 +225,17 @@ def set_attention_slice(self, slice_size):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
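
As a usage note (not part of the commit): since the setter lives on the UNet, it can also be toggled on a bare `UNet2DConditionModel` outside of a pipeline. A minimal sketch, assuming the usual Stable Diffusion checkpoint layout with a `unet` subfolder:

```python
# Hedged sketch: flip the new flag directly on a UNet.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Propagates through the down, mid and up blocks to every attention layer.
unet.set_use_memory_efficient_attention_xformers(True)
```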
Original file line number Diff line number Diff line change
@@ -113,6 +113,24 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
Original file line number Diff line number Diff line change
@@ -151,6 +151,24 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
Original file line number Diff line number Diff line change
@@ -151,6 +151,24 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
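
Per the commit message, the same toggle is exposed on the img2img and inpaint pipelines. A hedged sketch for img2img follows; the checkpoint id mirrors the docs example, and the surrounding call arguments are omitted:

```python
# Hedged sketch, not part of the diff: enable the toggle on the img2img pipeline.
import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()
# ... run img2img as usual; pipe.disable_xformers_memory_efficient_attention() reverts it.
```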
16 changes: 16 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -168,6 +168,18 @@
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
import torch

if version.Version(torch.__version__) < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False


def is_torch_available():
return _torch_available
@@ -205,6 +217,10 @@ def is_scipy_available():
return _scipy_available


def is_xformers_available():
return _xformers_available


def is_accelerate_available():
return _accelerate_available

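
A small usage sketch (an assumption about typical calling code, not part of the commit) showing how the new soft-dependency check pairs with the pipeline toggle:

```python
# Hedged sketch: only enable the xformers path when the soft dependency is present.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
).to("cuda")

if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
```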
