From fc6c0e8dd4dce4f6fd4738b1fbcbd699985561c9 Mon Sep 17 00:00:00 2001
From: bghira
Date: Fri, 4 Oct 2024 11:42:06 -0600
Subject: [PATCH 1/4] quanto: improve support for SDXL training

---
 .../quantisation/quanto_workarounds.py  | 31 +++++++++++++------
 .../quantisation/torchao_workarounds.py |  2 +-
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/helpers/training/quantisation/quanto_workarounds.py b/helpers/training/quantisation/quanto_workarounds.py
index 213ceaa1..4935c11d 100644
--- a/helpers/training/quantisation/quanto_workarounds.py
+++ b/helpers/training/quantisation/quanto_workarounds.py
@@ -6,8 +6,8 @@ import optimum
 from optimum.quanto.library.extensions.cuda import ext as quanto_ext
 
-    # torch tells us to do this because
-    torch._dynamo.config.optimize_ddp=False
+    # torch tells us to do this because
+    torch._dynamo.config.optimize_ddp = False
 
     # Save the original operator
     original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin
@@ -64,23 +64,31 @@ def forward(ctx, input, other, bias):
 
         tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction
 
-    class WeightQBytesLinearFunction(optimum.quanto.tensor.function.QuantizedLinearFunction):
+    class WeightQBytesLinearFunction(
+        optimum.quanto.tensor.function.QuantizedLinearFunction
+    ):
         @staticmethod
         def forward(ctx, input, other, bias=None):
             ctx.save_for_backward(input, other)
             if isinstance(input, optimum.quanto.tensor.QBytesTensor):
-                output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale)
+                output = torch.ops.quanto.qbytes_mm(
+                    input._data, other._data, input._scale * other._scale
+                )
             else:
                 in_features = input.shape[-1]
                 out_features = other.shape[0]
                 output_shape = input.shape[:-1] + (out_features,)
-                output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale)
+                output = torch.ops.quanto.qbytes_mm(
+                    input.reshape(-1, in_features), other._data, other._scale
+                )
                 output = output.view(output_shape)
             if bias is not None:
                 output = output + bias
             return output
 
-    optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = WeightQBytesLinearFunction
+    optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = (
+        WeightQBytesLinearFunction
+    )
 
     def reshape_qlf_backward(ctx, gO):
         # another one where we need .reshape instead of .view
@@ -92,11 +100,16 @@ def reshape_qlf_backward(ctx, gO):
             input_gO = torch.matmul(gO, other)
         if ctx.needs_input_grad[1]:
             # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
-            other_gO = torch.matmul(gO.reshape(-1, out_features).t(), input.reshape(-1, in_features))
+            other_gO = torch.matmul(
+                gO.reshape(-1, out_features).t(),
+                input.to(gO.dtype).reshape(-1, in_features),
+            )
         if ctx.needs_input_grad[2]:
             # Bias gradient is the sum on all dimensions but the last one
             dim = tuple(range(gO.ndim - 1))
             bias_gO = gO.sum(dim)
         return input_gO, other_gO, bias_gO
-
-    optimum.quanto.tensor.function.QuantizedLinearFunction.backward = reshape_qlf_backward
+
+    optimum.quanto.tensor.function.QuantizedLinearFunction.backward = (
+        reshape_qlf_backward
+    )
diff --git a/helpers/training/quantisation/torchao_workarounds.py b/helpers/training/quantisation/torchao_workarounds.py
index 4887c5b9..633bc689 100644
--- a/helpers/training/quantisation/torchao_workarounds.py
+++ b/helpers/training/quantisation/torchao_workarounds.py
@@ -32,7 +32,7 @@ def backward(ctx, grad_output):
         # here is the patch: we will cast the input to the grad_output dtype.
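+        # torch.matmul requires both operands to share a dtype; under quantised or
+        # mixed-precision training the saved input may not match grad_output's dtype,
+        # hence the cast below. The cast input is flattened with .reshape rather than
+        # .view, since .reshape also handles layouts that .view cannot re-interpret.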
grad_weight = grad_output.view(-1, weight.shape[0]).T @ input.to( grad_output.dtype - ).view(-1, weight.shape[1]) + ).reshape(-1, weight.shape[1]) grad_bias = grad_output.view(-1, weight.shape[0]).sum(0) if ctx.bias else None return grad_input, grad_weight, grad_bias From b0196cf03be3647c847bb89629a91e49aa894dc8 Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Fri, 4 Oct 2024 18:45:01 -0600 Subject: [PATCH 2/4] update min req for hw --- documentation/quickstart/FLUX.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 52df5ad9..2f7886d9 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -12,10 +12,11 @@ When you're training every component of a rank-16 LoRA (MLP, projections, multim - a bit more than 30G VRAM when not quantising the base model - a bit more than 18G VRAM when quantising to int8 + bf16 base/LoRA weights - a bit more than 13G VRAM when quantising to int4 + bf16 base/LoRA weights +- a bit more than 9G VRAM when quantising to NF4 + bf16 base/LoRA weights - a bit more than 9G VRAM when quantising to int2 + bf16 base/LoRA weights You'll need: -- **the absolute minimum** is a single 4060 Ti 16GB +- **the absolute minimum** is a single **3080 10G** - **a realistic minimum** is a single 3090 or V100 GPU - **ideally** multiple 4090, A6000, L40S, or better From 3f40593c54ff5b1141ddc5584274d3cc87b2a338 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 04:29:48 +0100 Subject: [PATCH 3/4] flux quickstart nf4 notes expansion --- documentation/quickstart/FLUX.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 2f7886d9..8e00985e 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -388,6 +388,22 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: Speed was approximately 1.4 iterations per second on a 4090. +### NF4-quantised training + +In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address. + +In early tests, the following holds true: +- Lion optimiser causes model collapse but uses least VRAM; AdamW variants help to hold it together; bnb-adamw8bit, adamw_bf16 are great choices + - AdEMAMix didn't fare well, but settings were not explored +- `--max_grad_norm=0.01` further helps reduce model breakage by preventing huge changes to the model in too short a time +- NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used +- Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090) +- Anything that's difficult to train on int8 or bf16 becomes harder in NF4 + +NF4 does not work with torch.compile, so whatever you get for speed is what you get. + +If VRAM is not a concern (eg. 48G or greater) then int8 with torch.compile is your best, fastest option. 
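+
+As a rough starting point, the notes above might translate into settings along these lines. This is a sketch rather than a tested recipe: the `train.py` entrypoint, the exact flag names (`--base_model_precision`, `--optimizer`, `--train_batch_size`, `--resolution`) and the `nf4-bnb` precision value are assumptions to verify against the options documentation; only `--max_grad_norm=0.01` and the optimiser names come directly from the notes above.
+
+```bash
+# illustrative NF4 LoRA settings for a ~10G card; verify flag names before use
+python train.py \
+  --base_model_precision=nf4-bnb \
+  --optimizer=adamw_bf16 \
+  --max_grad_norm=0.01 \
+  --train_batch_size=1 \
+  --resolution=512
+```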
+ ### Classifier-free guidance #### Problem From fc61bf8ca9479f47590232827c8ad5611232c317 Mon Sep 17 00:00:00 2001 From: Jimmy <39@🇺🇸.com> Date: Sat, 5 Oct 2024 20:21:26 -0400 Subject: [PATCH 4/4] Fix attention masking transformer for flux --- helpers/models/flux/transformer.py | 257 +++++++--------------------- helpers/training/diffusion_model.py | 6 +- 2 files changed, 62 insertions(+), 201 deletions(-) diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index be6d1cd6..4b1dacbe 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -3,7 +3,7 @@ # Originally licensed under the Apache License, Version 2.0 (the "License"); # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -14,7 +14,6 @@ from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import ( Attention, - apply_rope, ) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import ( @@ -33,6 +32,7 @@ from diffusers.models.embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, + FluxPosEmbed, ) from diffusers.models.modeling_outputs import Transformer2DModelOutput @@ -40,90 +40,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FluxSingleAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, _, _ = hidden_states.shape - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) - - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = (attention_mask > 0).bool() - attention_mask = attention_mask.to( - device=hidden_states.device, dtype=hidden_states.dtype - ) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - dropout_p=0.0, - is_causal=False, - attn_mask=attention_mask, - ) - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" @@ -141,20 +57,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape # `sample` projections. query = attn.to_q(hidden_states) @@ -173,41 +76,42 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) if attention_mask is not None: attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) @@ -229,71 +133,20 @@ def __call__( ) hidden_states = hidden_states.to(query.dtype) - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, 
: encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], ) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states, encoder_hidden_states - -# YiYi to-do: refactor rope related functions/classes -def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: - assert dim % 2 == 0, "The dimension must be even." + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - scale = ( - torch.arange( - 0, - dim, - 2, - dtype=torch.float32, - device=pos.device, - ) - / dim - ) - omega = 1.0 / (theta**scale) - - batch_size, seq_length = pos.shape - out = torch.einsum("...n,d->...nd", pos, omega) - cos_out = torch.cos(out) - sin_out = torch.sin(out) - - stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) - out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) - return out.float() - - -# YiYi to-do: refactor rope related functions/classes -class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: List[int]): - super().__init__() - self.dim = dim - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - emb = torch.cat( - [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], - dim=-3, - ) - - return emb.unsqueeze(1) + return hidden_states, encoder_hidden_states + return hidden_states def expand_flux_attention_mask( @@ -338,7 +191,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - processor = FluxSingleAttnProcessor2_0() + processor = FluxAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -536,7 +389,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], + axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() self.out_channels = in_channels @@ -544,8 +397,8 @@ def __init__( self.config.num_attention_heads * self.config.attention_head_dim ) - self.pos_embed = EmbedND( - dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope + self.pos_embed = FluxPosEmbed( + theta=10000, axes_dim=axes_dims_rope ) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings @@ -667,7 +520,13 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - ids = torch.cat((txt_ids, img_ids), dim=1) + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index ad707599..9de806b8 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -18,6 +18,7 @@ def load_diffusion_model(args, weight_dtype): "revision": args.revision, "variant": args.variant, "torch_dtype": weight_dtype, + "use_safetensors": True, } unet = None transformer = None @@ -64,8 +65,9 @@ def load_diffusion_model(args, weight_dtype): ) transformer = FluxTransformer2DModelWithMasking.from_pretrained( - 
args.pretrained_model_name_or_path, - subfolder="transformer", + args.pretrained_transformer_model_name_or_path + or args.pretrained_model_name_or_path, + subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) elif args.model_family == "pixart_sigma":