merge #1089 (Merged)

4 commits, merged Oct 23, 2024
helpers/caching/text_embeds.py: 5 additions & 3 deletions

```diff
@@ -1313,9 +1313,11 @@ def compute_embeddings_for_sd3_prompts(
             )
             add_text_embeds = pooled_prompt_embeds
             # StabilityAI say not to zero them out.
-            # if prompt == "":
-            #     prompt_embeds = torch.zeros_like(prompt_embeds)
-            #     add_text_embeds = torch.zeros_like(add_text_embeds)
+            if prompt == "":
+                if StateTracker.get_args().sd3_clip_uncond_behaviour == "zero":
+                    prompt_embeds = torch.zeros_like(prompt_embeds)
+                if StateTracker.get_args().sd3_t5_uncond_behaviour == "zero":
+                    add_text_embeds = torch.zeros_like(add_text_embeds)
             # Get the current size of the queue.
             current_size = self.write_queue.qsize()
             if current_size >= 2048:
```
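In isolation, the new dropout logic looks like the sketch below. This is a minimal rewrite for illustration only: the parsed argument values are passed in directly instead of being read from `StateTracker`, and the function name and signature are hypothetical.

```python
import torch


def apply_sd3_uncond_behaviour(
    prompt: str,
    prompt_embeds: torch.Tensor,
    add_text_embeds: torch.Tensor,
    clip_uncond_behaviour: str = "empty_string",
    t5_uncond_behaviour: str = "empty_string",
) -> tuple[torch.Tensor, torch.Tensor]:
    # Only unconditional (empty) prompts are affected; with the default
    # "empty_string", the encoded empty caption is kept untouched.
    if prompt == "":
        if clip_uncond_behaviour == "zero":
            prompt_embeds = torch.zeros_like(prompt_embeds)
        if t5_uncond_behaviour == "zero":
            add_text_embeds = torch.zeros_like(add_text_embeds)
    return prompt_embeds, add_text_embeds
```

Each embedding can now be zeroed independently, replacing the previous all-or-nothing commented-out block.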
helpers/configuration/cmd_args.py: 24 additions & 0 deletions

```diff
@@ -271,6 +271,27 @@ def get_argument_parser():
             " Additionally, 'diffusion' is offered as an option to reparameterise a model to v_prediction loss."
         ),
     )
+    parser.add_argument(
+        "--sd3_clip_uncond_behaviour",
+        type=str,
+        choices=["empty_string", "zero"],
+        default='empty_string',
+        help=(
+            "SD3 can be trained using zeroed prompt embeds during unconditional dropout,"
+            " or an encoded empty string may be used instead (the default). Changing this value may stabilise or"
+            " destabilise training. The default is 'empty_string'."
+        )
+    )
+    parser.add_argument(
+        "--sd3_t5_uncond_behaviour",
+        type=str,
+        choices=["empty_string", "zero"],
+        default=None,
+        help=(
+            "Override the value of unconditional prompts from T5 embeds."
+            " The default is to follow the value of --sd3_clip_uncond_behaviour."
+        )
+    )
     parser.add_argument(
         "--sd3_t5_mask_behaviour",
         type=str,
@@ -2098,6 +2119,9 @@ def parse_cmdline_args(input_args=None):
                 "MM-DiT requires an alignment value of 64px. Overriding the value of --aspect_bucket_alignment."
             )
             args.aspect_bucket_alignment = 64
+            if args.sd3_t5_uncond_behaviour is None:
+                args.sd3_t5_uncond_behaviour = args.sd3_clip_uncond_behaviour
+            info_log(f"SD3 embeds for unconditional captions: t5={args.sd3_t5_uncond_behaviour}, clip={args.sd3_clip_uncond_behaviour}")
     elif "deepfloyd" in args.model_type:
         deepfloyd_pixel_alignment = 8
         if args.aspect_bucket_alignment != deepfloyd_pixel_alignment:
```
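The T5 flag deliberately defaults to `None` so that, after parsing, it can inherit whatever the CLIP flag was set to. A self-contained sketch of that fallback pattern (the argument names are reused from the diff; everything else is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sd3_clip_uncond_behaviour",
    choices=["empty_string", "zero"],
    default="empty_string",
)
parser.add_argument(
    "--sd3_t5_uncond_behaviour",
    choices=["empty_string", "zero"],
    default=None,  # sentinel meaning "follow the CLIP flag"
)

args = parser.parse_args(["--sd3_clip_uncond_behaviour", "zero"])
# Resolve the sentinel after parsing, as the PR does in parse_cmdline_args.
if args.sd3_t5_uncond_behaviour is None:
    args.sd3_t5_uncond_behaviour = args.sd3_clip_uncond_behaviour
assert args.sd3_t5_uncond_behaviour == "zero"
```

Using `None` as the sentinel keeps "user explicitly chose empty_string" distinguishable from "user left the flag unset".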
helpers/training/diffusion_model.py: 6 additions & 1 deletion

```diff
@@ -67,8 +67,13 @@ def load_diffusion_model(args, weight_dtype):
         from flash_attn_interface import flash_attn_func
         import diffusers

+        from helpers.models.flux.attention import (
+            FluxAttnProcessor3_0,
+            FluxSingleAttnProcessor3_0,
+        )
+
         diffusers.models.attention_processor.FluxSingleAttnProcessor2_0 = (
-            FluxAttnProcessor3_0
+            FluxSingleAttnProcessor3_0
         )
         diffusers.models.attention_processor.FluxAttnProcessor2_0 = (
             FluxAttnProcessor3_0
```
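These assignments rebind class attributes on the `diffusers` attention-processor module, so any Flux blocks constructed afterwards resolve to the flash-attn v3 processors. A minimal sketch of the same monkey-patching pattern, with stub classes standing in for the repo's real processors from `helpers.models.flux.attention`:

```python
import diffusers


# Stubs standing in for the repo's flash-attn-3 processor classes.
class FluxAttnProcessor3_0:
    pass


class FluxSingleAttnProcessor3_0:
    pass


# Rebind the names diffusers looks up when it builds attention blocks;
# this must run before the transformer is instantiated to take effect.
diffusers.models.attention_processor.FluxSingleAttnProcessor2_0 = (
    FluxSingleAttnProcessor3_0
)
diffusers.models.attention_processor.FluxAttnProcessor2_0 = FluxAttnProcessor3_0
```

The diff itself shows the bug being fixed: the single-stream processor slot had been patched with the joint-attention class, so single-stream blocks received the wrong processor.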
helpers/training/quantisation/__init__.py: 27 additions & 21 deletions

```diff
@@ -1,4 +1,5 @@
 from helpers.training.multi_process import should_log
+from helpers.training.state_tracker import StateTracker
 import logging
 import torch, os

@@ -72,30 +73,35 @@ def _quanto_model(
     logger.info(f"Quantising {model.__class__.__name__}. Using {model_precision}.")
     weight_quant = _quanto_type_map(model_precision)
     extra_quanto_args = {}
-    extra_quanto_args["exclude"] = [
-        "*.norm",
-        "*.norm1",
-        "*.norm1_context",
-        "*.norm_q",
-        "*.norm_k",
-        "*.norm_added_q",
-        "*.norm_added_k",
-        "proj_out",
-        "pos_embed",
-        "norm_out",
-        "context_embedder",
-        "time_text_embed",
-    ]
+    if StateTracker.get_args().model_family == "sd3":
+        extra_quanto_args["exclude"] = [
+            "*.norm",
+            "*.norm1",
+            "*.norm1_context",
+            "*.norm_q",
+            "*.norm_k",
+            "*.norm_added_q",
+            "*.norm_added_k",
+            "proj_out",
+            "pos_embed",
+            "norm_out",
+            "context_embedder",
+            "time_text_embed",
+        ]
+    elif StateTracker.get_args().model_family == "flux":
+        extra_quanto_args["exclude"] = [
+            "*.norm",
+            "*.norm1",
+            "*.norm2",
+            "*.norm2_context",
+            "proj_out",
+            "x_embedder",
+            "norm_out",
+            "context_embedder",
+        ]
     if quantize_activations:
         logger.info("Freezing model weights and activations")
         extra_quanto_args["activations"] = weight_quant
-        # extra_quanto_args["exclude"] = [
-        #     "*.norm",
-        #     "*.norm1",
-        #     "*.norm2",
-        #     "*.norm2_context",
-        #     "proj_out",
-        # ]
     else:
         logger.info("Freezing model weights only")
```
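These `exclude` patterns are ultimately forwarded to optimum-quanto's `quantize()` so that precision-sensitive modules (norms, embedders, final projections) keep their original dtype. A hedged sketch of that call, assuming an optimum-quanto version whose `quantize` accepts an `exclude` list of glob patterns; the dispatch table and function name below are illustrative, not the repo's actual API:

```python
import torch
from optimum.quanto import freeze, qint8, quantize

# Per-family module patterns that should not be quantised (from the diff).
QUANTO_EXCLUDE = {
    "sd3": [
        "*.norm", "*.norm1", "*.norm1_context", "*.norm_q", "*.norm_k",
        "*.norm_added_q", "*.norm_added_k", "proj_out", "pos_embed",
        "norm_out", "context_embedder", "time_text_embed",
    ],
    "flux": [
        "*.norm", "*.norm1", "*.norm2", "*.norm2_context",
        "proj_out", "x_embedder", "norm_out", "context_embedder",
    ],
}


def quantise_transformer(model: torch.nn.Module, model_family: str) -> torch.nn.Module:
    # Modules matching the glob patterns are skipped during quantisation.
    quantize(model, weights=qint8, exclude=QUANTO_EXCLUDE.get(model_family, []))
    freeze(model)  # materialise the quantised weights
    return model
```

Splitting the list per model family matters because the module names differ between architectures: SD3's MM-DiT exposes `pos_embed` and `time_text_embed`, while Flux uses `x_embedder` and `norm2`/`norm2_context`.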