remove xformers and replace with torch native memory efficient attention #82

Open · wants to merge 2 commits into main
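The diffs below replace every use of `xformers.ops.memory_efficient_attention` with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (SDPA, available since PyTorch 2.0), which dispatches to FlashAttention or memory-efficient kernels on its own when dtype and hardware allow, so the separate xformers dependency and the `enable_xformers_memory_efficient_attention` config flag are no longer needed. A minimal sketch of the mapping the changes rely on; shapes and tensors are illustrative only, not taken from the repo:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, seq_len=64, heads=8, head_dim=32.
q = torch.randn(2, 64, 8, 32)  # xformers layout: (B, T, nh, hs)
k = torch.randn(2, 64, 8, 32)
v = torch.randn(2, 64, 8, 32)

# Before (removed in this PR):
#   out = xformers.ops.memory_efficient_attention(q, k, v)  # (B, T, nh, hs)

# After: SDPA expects the head dimension before the sequence dimension,
# hence the extra permutes that show up in modeling_transformer.py below.
out = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3),  # (B, nh, T, hs)
    k.permute(0, 2, 1, 3),
    v.permute(0, 2, 1, 3),
).permute(0, 2, 1, 3)  # back to (B, T, nh, hs), matching the xformers output layout
```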
2 changes: 0 additions & 2 deletions README.md
@@ -190,7 +190,6 @@ model:
attention_dropout: 0.0

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: False


dataset:
@@ -259,7 +258,6 @@ __model__:
- `model.vq_model.pretrained`: The pretrained vq model to use. Can be a path to a saved checkpoint or a huggingface model name.
- `model.transformer`: The transformer model configuration.
- `model.gradient_checkpointing`: Enable gradient checkpointing for the transformer model.
- `enable_xformers_memory_efficient_attention`: Enable memory efficient attention or flash attention for the transformer model. For flash attention we need to use `fp16` or `bf16`. [xformers](https://github.com/facebookresearch/xformers) needs to be installed for this to work.

__dataset__:
- `dataset.params.train_shards_path_or_url`: The path or url to the `webdataset` training shards.
1 change: 0 additions & 1 deletion configs/cc12m.yaml
@@ -49,7 +49,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
2 changes: 0 additions & 2 deletions configs/cc12m_movq.yaml
@@ -51,8 +51,6 @@ model:
patch_size: 2

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
type: "text2image"
1 change: 0 additions & 1 deletion configs/cc12m_uvit.yaml
@@ -54,7 +54,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/cc12m_uvit_clip.yaml
@@ -54,7 +54,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/cc12m_uvit_larger_paellavq_f8_clip.yaml
@@ -54,7 +54,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/cc12m_uvit_paellavq.yaml
@@ -55,7 +55,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/cc12m_uvit_paellavq_larger.yaml
@@ -54,7 +54,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/imagenet.yaml
@@ -42,7 +42,6 @@ model:
attention_dropout: 0.0

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/imagenet_movq.yaml
@@ -42,7 +42,6 @@ model:
attention_dropout: 0.0

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/imagenet_text2image.yaml
@@ -47,7 +47,6 @@ model:
attention_dropout: 0.0

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/imagenet_text2image_movq_conv.yaml
@@ -51,7 +51,6 @@ model:
patch_size: 2

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/laiona6plus_uvit_clip.yaml
@@ -54,7 +54,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/laiona6plus_uvit_clip_f8.yaml
@@ -55,7 +55,6 @@ model:
use_codebook_size_for_output: True

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: True


dataset:
1 change: 0 additions & 1 deletion configs/template_config.yaml
@@ -39,7 +39,6 @@ model:
attention_dropout: 0.0

gradient_checkpointing: True
enable_xformers_memory_efficient_attention: False


dataset:
41 changes: 1 addition & 40 deletions muse/modeling_movq.py
@@ -10,13 +10,6 @@

from .modeling_utils import ConfigMixin, ModelMixin, register_to_config

try:
import xformers.ops as xops

is_xformers_available = True
except ImportError:
is_xformers_available = False


class SpatialNorm(nn.Module):
def __init__(
@@ -170,17 +163,6 @@ def __init__(self, in_channels, zq_ch=None, add_conv=False):
self.v = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)

self.use_memory_efficient_attention_xformers = False
self.xformers_attention_op = None

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers and not is_xformers_available:
raise ImportError("Please install xformers to use memory efficient attention")
self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.xformers_attention_op = attention_op

def forward(self, hidden_states, zq=None):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
@@ -190,33 +172,12 @@ def forward(self, hidden_states, zq=None):
hidden_states = self.norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-scale = 1.0 / torch.sqrt(torch.tensor(channel, dtype=hidden_states.dtype, device=hidden_states.device))

query = self.q(hidden_states)
key = self.k(hidden_states)
value = self.v(hidden_states)

-if self.use_memory_efficient_attention_xformers:
-# Memory efficient attention
-hidden_states = xops.memory_efficient_attention(
-query, key, value, attn_bias=None, op=self.xformers_attention_op
-)
-else:
-attention_scores = torch.baddbmm(
-torch.empty(
-query.shape[0],
-query.shape[1],
-key.shape[1],
-dtype=query.dtype,
-device=query.device,
-),
-query,
-key.transpose(-1, -2),
-beta=0,
-alpha=scale,
-)
-attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-hidden_states = torch.bmm(attention_probs, value)
+hidden_states = F.scaled_dot_product_attention(query, key, value)

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).view(batch, channel, height, width)
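In this MoVQ attention block the projections stay as single-head `(batch, height*width, channels)` tensors. `F.scaled_dot_product_attention` accepts 3-D inputs directly and applies the `1/sqrt(last_dim)` scaling and the softmax internally, which is why the explicit `scale`, `baddbmm`, and `softmax` lines above can be dropped. A small self-contained sketch of that equivalence (hypothetical shapes, not the repo's code):

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical single-head spatial attention: batch=1, tokens=height*width=256, channels=64.
q = torch.randn(1, 256, 64)
k = torch.randn(1, 256, 64)
v = torch.randn(1, 256, 64)

# What the deleted code computed by hand: scaled QK^T, softmax, weighted sum of V.
scale = 1.0 / math.sqrt(q.shape[-1])
manual = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * scale, dim=-1) @ v

# What the new line computes; SDPA's default scale is already 1/sqrt(last dim).
native = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, native, atol=1e-5))  # expected: True
```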
70 changes: 13 additions & 57 deletions muse/modeling_transformer.py
@@ -29,14 +29,6 @@
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, gumbel_sample, mask_by_random_topk, top_k

try:
import xformers.ops as xops

is_xformers_available = True
except ImportError:
is_xformers_available = False


# classifier free guidance functions


@@ -116,7 +108,7 @@ def forward(self, x):


class GlobalResponseNorm(nn.Module):
"Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

def __init__(self, dim):
super().__init__()
@@ -309,21 +301,7 @@ def __init__(self, hidden_size, num_heads, encoder_hidden_size=None, attention_d
self.out = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
self.dropout = nn.Dropout(attention_dropout)

self.use_memory_efficient_attention_xformers = False
self.xformers_attention_op = None

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers and not is_xformers_available:
raise ImportError("Please install xformers to use memory efficient attention")
self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.xformers_attention_op = attention_op

def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None):
if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers:
raise ValueError("Memory efficient attention does not yet support encoder attention mask")

context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
batch, q_seq_len, _ = hidden_states.shape
kv_seq_len = q_seq_len if encoder_hidden_states is None else encoder_hidden_states.shape[1]
@@ -332,43 +310,21 @@ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_m
key = self.key(context)
value = self.value(context)

-query = query.view(batch, q_seq_len, self.num_heads, self.head_dim) # (B, T, nh, hs)
-key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim) # (B, T, nh, hs)
-value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim) # (B, T, nh, hs)
+query = query.view(batch, q_seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (B, nh, T, hs)
+key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (B, nh, T, hs)
+value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (B, nh, T, hs)

-if self.use_memory_efficient_attention_xformers:
-attn_output = xops.memory_efficient_attention(query, key, value, op=self.xformers_attention_op)
-attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
-else:
-attention_mask = None
-if encoder_attention_mask is not None:
-src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
-attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask, dtype=query.dtype)
-attn_output = self.attention(query, key, value, attention_mask)
+attention_mask = None
+if encoder_attention_mask is not None:
+src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
+attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask, dtype=query.dtype)

-attn_output = self.out(attn_output)
-return attn_output
+attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) # (B, nh, T, hs)
+attn_output = attn_output.permute(0, 2, 1, 3).reshape(
+batch, q_seq_len, self.num_heads * self.head_dim
+) # (B, T, q_dim)

-def attention(self, query, key, value, attention_mask=None):
-batch, seq_len = query.shape[:2]
-kv_seq_len = key.shape[1]
-query, key, value = map(lambda t: t.transpose(1, 2).contiguous(), (query, key, value)) # (B, nh, T, hs)

-attn_weights = torch.baddbmm(
-input=torch.zeros(batch * self.num_heads, seq_len, kv_seq_len, dtype=query.dtype, device=query.device),
-batch1=query.view(batch * self.num_heads, seq_len, self.head_dim),
-batch2=key.view(batch * self.num_heads, kv_seq_len, self.head_dim).transpose(1, 2),
-alpha=1 / self.scale_attn,
-)
-attn_weights = attn_weights.view(batch, self.num_heads, seq_len, kv_seq_len) # -1 is kv_seq_len
-# Apply the attention mask
-if attention_mask is not None:
-attn_weights = torch.masked_fill(attn_weights, attention_mask, torch.finfo(query.dtype).min)
-attn_weights = F.softmax(attn_weights, dim=-1)
-attn_weights = self.dropout(attn_weights)
-attn_output = torch.matmul(attn_weights, value) # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-# re-assemble all head outputs side by side
-attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
+attn_output = self.out(attn_output)
+return attn_output


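For cross-attention the new code still builds `attention_mask` with `make_attention_mask` and hands it to SDPA via `attn_mask`. SDPA treats a boolean mask as "True means attend" and a floating-point mask as an additive bias on the attention scores, so whichever convention `make_attention_mask` produces has to match one of those. A hedged sketch of the additive-bias convention for a padded encoder sequence; `sdpa_cross_attention_mask` is a hypothetical helper for illustration, not the repo's function:

```python
import torch
import torch.nn.functional as F

def sdpa_cross_attention_mask(encoder_attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """encoder_attention_mask: (batch, kv_seq_len) with 1 for real tokens, 0 for padding."""
    bias = encoder_attention_mask[:, None, None, :].to(dtype)  # (B, 1, 1, Tkv), broadcast over heads and queries
    return (1.0 - bias) * torch.finfo(dtype).min               # 0 where allowed, very negative where masked

q = torch.randn(1, 8, 16, 32)   # (B, nh, Tq, hs)
k = torch.randn(1, 8, 4, 32)    # (B, nh, Tkv, hs)
v = torch.randn(1, 8, 4, 32)
mask = sdpa_cross_attention_mask(torch.tensor([[1, 1, 1, 0]]), q.dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)   # padded position gets ~zero weight
```

Note also that SDPA exposes a `dropout_p` argument if the attention-probability dropout applied by the old `self.attention` helper is still wanted.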
55 changes: 0 additions & 55 deletions muse/modeling_utils.py
@@ -273,61 +273,6 @@ def disable_gradient_checkpointing(self):
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))

def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid, attention_op)

for child in module.children():
fn_recursive_set_mem_eff(child)

for module in self.children():
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)

def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.

Parameters:
attention_op (`Callable`, *optional*):
Override the default `None` operator for use as `op` argument to the
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
function of xFormers.

Examples:

```py
>>> import torch
>>> from diffusers import UNet2DConditionModel
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

>>> model = UNet2DConditionModel.from_pretrained(
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
... )
>>> model = model.to("cuda")
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
```
"""
self.set_use_memory_efficient_attention_xformers(True, attention_op)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.set_use_memory_efficient_attention_xformers(False)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
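With the xformers toggles above removed there is nothing to install or enable per module: PyTorch selects a fused SDPA kernel automatically. If explicit control over the kernel is still wanted, it can be constrained globally; the sketch below uses the PyTorch 2.0/2.1 context manager (newer releases expose the equivalent `torch.nn.attention.sdpa_kernel`) and assumes a CUDA device:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)  # (B, nh, T, hs)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the fused flash / memory-efficient kernels, disallow the unfused math fallback.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
    out = F.scaled_dot_product_attention(q, k, v)
```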
2 changes: 1 addition & 1 deletion muse/pipeline_muse.py
@@ -205,7 +205,7 @@ def from_pretrained(
# TODO: Add config for pipeline to specify text encoder
is_clip = "clip" in text_encoder_args["pretrained_model_name_or_path"]
text_encoder_cls = CLIPTextModel if is_clip else T5EncoderModel

text_encoder = text_encoder_cls.from_pretrained(**text_encoder_args)
tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)

1 change: 0 additions & 1 deletion scripts/benchmark_models.py
@@ -36,7 +36,6 @@ def create_model_and_benchmark(args):
time_vanilla_fp16 = benchmark_torch_function(f)

print("Running benchmark for efficient attention in FP16 ...")
model.enable_xformers_memory_efficient_attention()
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_efficient_fp16 = benchmark_torch_function(f)

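With the xformers toggle gone, both timed branches in this script now exercise the same native SDPA path; forcing the math fallback for the "vanilla" number (as sketched after the modeling_utils.py diff) keeps the comparison meaningful. `benchmark_torch_function` itself is not shown in this diff; a simple CUDA-event timer along these lines would do the same job (hypothetical helper, assumes a GPU):

```python
import torch

def time_cuda_ms(fn, iters: int = 10) -> float:
    """Average milliseconds per call of fn(), measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()                          # warm-up call
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```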
1 change: 0 additions & 1 deletion scripts/calculate_fid.py
@@ -41,7 +41,6 @@ def generate_and_save_images(args):

print("Loading pipe")
pipeline = PipelineMuse.from_pretrained(args.model_name_or_path).to(args.device)
pipeline.transformer.enable_xformers_memory_efficient_attention()

print("Loading data")
dataset = Flickr8kDataset(args.dataset_root, args.dataset_captions_file)
1 change: 0 additions & 1 deletion scripts/log_generations_wandb.py
@@ -32,7 +32,6 @@ def generate_and_log(args):
transformer_path=args.transformer,
is_class_conditioned=args.is_class_conditioned,
).to(device=args.device)
pipe.transformer.enable_xformers_memory_efficient_attention()

imagenet_class_ids = list(range(1000))
with open(args.imagenet_class_mapping_path) as f: