Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge #1036

Merged
merged 6 commits into from
Oct 8, 2024
Merged

merge #1036

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion documentation/quickstart/FLUX.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ When you're training every component of a rank-16 LoRA (MLP, projections, multim
- a bit more than 30G VRAM when not quantising the base model
- a bit more than 18G VRAM when quantising to int8 + bf16 base/LoRA weights
- a bit more than 13G VRAM when quantising to int4 + bf16 base/LoRA weights
- a bit more than 9G VRAM when quantising to NF4 + bf16 base/LoRA weights
- a bit more than 9G VRAM when quantising to int2 + bf16 base/LoRA weights

You'll need:
- **the absolute minimum** is a single 4060 Ti 16GB
- **the absolute minimum** is a single **3080 10G**
- **a realistic minimum** is a single 3090 or V100 GPU
- **ideally** multiple 4090, A6000, L40S, or better

Expand Down Expand Up @@ -387,6 +388,22 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with:

Speed was approximately 1.4 iterations per second on a 4090.

### NF4-quantised training

In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address.

In early tests, the following holds true:
- Lion optimiser causes model collapse but uses least VRAM; AdamW variants help to hold it together; bnb-adamw8bit, adamw_bf16 are great choices
- AdEMAMix didn't fare well, but settings were not explored
- `--max_grad_norm=0.01` further helps reduce model breakage by preventing huge changes to the model in too short a time
- NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used
- Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090)
- Anything that's difficult to train on int8 or bf16 becomes harder in NF4

NF4 does not work with torch.compile, so whatever you get for speed is what you get.

If VRAM is not a concern (eg. 48G or greater) then int8 with torch.compile is your best, fastest option.

### Classifier-free guidance

#### Problem
Expand Down
257 changes: 58 additions & 199 deletions helpers/models/flux/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Originally licensed under the Apache License, Version 2.0 (the "License");
# Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -14,7 +14,6 @@
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
Attention,
apply_rope,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (
Expand All @@ -33,97 +32,14 @@
from diffusers.models.embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
FluxPosEmbed,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)

batch_size, _, _ = hidden_states.shape
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)

if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (attention_mask > 0).bool()
attention_mask = attention_mask.to(
device=hidden_states.device, dtype=hidden_states.dtype
)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
dropout_p=0.0,
is_causal=False,
attn_mask=attention_mask,
)

hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

return hidden_states


class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

Expand All @@ -141,20 +57,7 @@ def __call__(
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)

batch_size = encoder_hidden_states.shape[0]
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

# `sample` projections.
query = attn.to_q(hidden_states)
Expand All @@ -173,41 +76,42 @@ def __call__(
if attn.norm_k is not None:
key = attn.norm_k(key)

# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)

# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
from diffusers.models.embeddings import apply_rotary_emb

query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
Expand All @@ -229,71 +133,20 @@ def __call__(
)
hidden_states = hidden_states.to(query.dtype)

encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

return hidden_states, encoder_hidden_states


# YiYi to-do: refactor rope related functions/classes
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

scale = (
torch.arange(
0,
dim,
2,
dtype=torch.float32,
device=pos.device,
)
/ dim
)
omega = 1.0 / (theta**scale)

batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)

stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()


# YiYi to-do: refactor rope related functions/classes
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim

def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)

return emb.unsqueeze(1)
return hidden_states, encoder_hidden_states
return hidden_states


def expand_flux_attention_mask(
Expand Down Expand Up @@ -338,7 +191,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

processor = FluxSingleAttnProcessor2_0()
processor = FluxAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
Expand Down Expand Up @@ -536,16 +389,16 @@ def __init__(
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = (
self.config.num_attention_heads * self.config.attention_head_dim
)

self.pos_embed = EmbedND(
dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
self.pos_embed = FluxPosEmbed(
theta=10000, axes_dim=axes_dims_rope
)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings
Expand Down Expand Up @@ -667,7 +520,13 @@ def forward(
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

ids = torch.cat((txt_ids, img_ids), dim=1)
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
img_ids = img_ids[0]

ids = torch.cat((txt_ids, img_ids), dim=0)

image_rotary_emb = self.pos_embed(ids)

for index_block, block in enumerate(self.transformer_blocks):
Expand Down
6 changes: 4 additions & 2 deletions helpers/training/diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def load_diffusion_model(args, weight_dtype):
"revision": args.revision,
"variant": args.variant,
"torch_dtype": weight_dtype,
"use_safetensors": True,
}
unet = None
transformer = None
Expand Down Expand Up @@ -64,8 +65,9 @@ def load_diffusion_model(args, weight_dtype):
)

transformer = FluxTransformer2DModelWithMasking.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
args.pretrained_transformer_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
**pretrained_load_args,
)
elif args.model_family == "pixart_sigma":
Expand Down
Loading
Loading