flux: mobius-style training via augmented guidance scale #880

Merged
merged 1 commit on Aug 27, 2024
Changes from all commits
17 changes: 16 additions & 1 deletion helpers/arguments.py
@@ -143,12 +143,17 @@ def parse_args(input_args=None):
parser.add_argument(
"--flux_guidance_mode",
type=str,
choices=["constant", "random-range"],
choices=["constant", "random-range", "mobius"],
default="constant",
help=(
"Flux has a 'guidance' value used during training time that reflects the CFG range of your training samples."
" The default mode 'constant' will use a single value for every sample."
" The mode 'random-range' will randomly select a value from the range of the CFG for each sample."
" The mode 'mobius' will use a value that is a function of the remaining steps in the epoch, constructively"
" deconstructing the constructed deconstructions to then Mobius them back into the constructed reconstructions,"
" possibly resulting in the exploration of what is known as the Mobius space, a new continuous"
" realm of possibility brought about by destroying the model so that you can make it whole once more."
" Or so according to DataVoid, anyway. This is just a Flux-specific implementation of Mobius."
" Set the range using --flux_guidance_min and --flux_guidance_max."
),
)
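
For reference, a minimal sketch of enabling the new mode (all values here are hypothetical, and whatever other arguments your configuration requires still apply; note that the validation further down bumps --flux_guidance_min up to 1.0 for this mode):

    from helpers.arguments import parse_args

    # Hypothetical invocation: parse_args accepts an argv-style list.
    args = parse_args([
        "--flux_guidance_mode", "mobius",
        "--flux_guidance_min", "1.0",
        "--flux_guidance_max", "4.0",
        # ...plus the rest of your usual training arguments
    ])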
@@ -2071,6 +2076,16 @@ def parse_args(input_args=None):
"Flux Schnell requires fewer inference steps. Consider reducing --validation_num_inference_steps to 4."
)

if args.flux_guidance_mode == "mobius":
logger.warning(
"Mobius training is only for the most elite. Pardon my English, but this is not for those who don't like to destroy something beautiful every now and then. If you feel perhaps this is not for you, please consider using a different guidance mode."
)
if args.flux_guidance_min < 1.0:
logger.warning(
"Flux minimum guidance value for Mobius training is 1.0. Updating value.."
)
args.flux_guidance_min = 1.0

if args.use_ema and args.ema_cpu_only:
args.ema_device = "cpu"

33 changes: 33 additions & 0 deletions helpers/models/flux/__init__.py
@@ -1,5 +1,38 @@
import torch
import random
from helpers.models.flux.pipeline import FluxPipeline
from helpers.training import steps_remaining_in_epoch


def get_mobius_guidance(args, global_step, steps_per_epoch, batch_size, device):
"""
state of the art
"""
steps_remaining = steps_remaining_in_epoch(global_step, steps_per_epoch)

# Start with a linear mapping from remaining steps to a scale between 0 and 1
scale_factor = steps_remaining / steps_per_epoch

# we want the last 10% of the epoch to have a guidance of 1.0
threshold_step_count = max(1, int(steps_per_epoch * 0.1))

if (
steps_remaining <= threshold_step_count
): # Last few steps in the epoch, set guidance to 1.0
guidance_values = torch.ones(batch_size, device=device)
else:
# Sample between flux_guidance_min and flux_guidance_max, blending each draw toward 1.0 as the epoch progresses
guidance_values = torch.tensor(
[
random.uniform(args.flux_guidance_min, args.flux_guidance_max)
* scale_factor
+ (1.0 - scale_factor)
for _ in range(batch_size)
],
device=device,
)

return guidance_values


def update_flux_schedule_to_fast(args, noise_scheduler_to_copy):
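To make the schedule concrete, here is a small usage sketch of the get_mobius_guidance helper above (the SimpleNamespace stands in for the parsed arguments; all values are placeholders):

    from types import SimpleNamespace
    from helpers.models.flux import get_mobius_guidance

    args = SimpleNamespace(flux_guidance_min=1.0, flux_guidance_max=4.0)
    steps_per_epoch = 100
    for step in (0, 25, 50, 75, 95):
        g = get_mobius_guidance(args, step, steps_per_epoch, batch_size=4, device="cpu")
        print(step, [round(v, 2) for v in g.tolist()])
    # Early steps draw from the full [min, max] range; later steps are pulled
    # toward 1.0, and the last ~10% of the epoch returns exactly 1.0.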
15 changes: 15 additions & 0 deletions helpers/training/__init__.py
@@ -112,3 +112,18 @@
},
},
}


def steps_remaining_in_epoch(current_step: int, steps_per_epoch: int) -> int:
"""
Calculate the number of steps remaining in the current epoch.

Args:
current_step (int): The current step; a global step counter is also accepted, since it is reduced modulo steps_per_epoch.
steps_per_epoch (int): Total number of steps in the epoch.

Returns:
int: Number of steps remaining in the current epoch.
"""
remaining_steps = steps_per_epoch - (current_step % steps_per_epoch)
return remaining_steps
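
A few example values make the modulo behavior concrete (a sketch, not part of the diff):

    steps_remaining_in_epoch(0, 100)    # 100 -- first step of an epoch
    steps_remaining_in_epoch(95, 100)   # 5   -- five steps left
    steps_remaining_in_epoch(200, 100)  # 100 -- global counters wrap: step 200 begins a new epoch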
29 changes: 26 additions & 3 deletions train.py
@@ -53,6 +53,7 @@
from helpers.training.wrappers import unwrap_model
from helpers.data_backend.factory import configure_multi_databackend
from helpers.data_backend.factory import random_dataloader_iterator
from helpers.training import steps_remaining_in_epoch
from helpers.training.custom_schedule import (
generate_timestep_weights,
segmented_timestep_selection,
@@ -119,6 +120,7 @@
prepare_latent_image_ids,
pack_latents,
unpack_latents,
get_mobius_guidance,
)

is_optimi_available = False
@@ -1821,12 +1823,33 @@ def main():
height=latents.shape[2],
width=latents.shape[3],
)
if args.flux_guidance_mode == "constant":
if args.flux_guidance_mode == "mobius":
guidance = get_mobius_guidance(
args,
global_step,
num_update_steps_per_epoch,
latents.shape[0],
accelerator.device,
)
elif args.flux_guidance_mode == "constant":
guidance_scale = float(args.flux_guidance_value)
guidance = torch.tensor(
[guidance_scale], device=accelerator.device
).expand(latents.shape[0])

elif args.flux_guidance_mode == "random-range":
# Generate a list of random values within the specified range for each latent
guidance_scales = [
random.uniform(
args.flux_guidance_min, args.flux_guidance_max
)
for _ in range(latents.shape[0])
]
guidance = torch.tensor(
guidance_scales, device=accelerator.device
)

# Now `guidance` holds one value per latent in `latents` (identical values in 'constant' mode).
transformer_config = None
if hasattr(transformer, "module"):
transformer_config = transformer.module.config
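As a sanity check on the dispatch above, a minimal sketch of the tensors the "constant" and "random-range" branches produce (placeholder batch size and values, with "cpu" standing in for accelerator.device; the "mobius" branch is illustrated under helpers/models/flux/__init__.py above):

    import random
    import torch

    batch_size = 4
    # "constant": one value broadcast across the batch
    guidance = torch.tensor([3.5], device="cpu").expand(batch_size)
    # "random-range": an independent draw per sample
    guidance = torch.tensor(
        [random.uniform(1.0, 4.0) for _ in range(batch_size)], device="cpu"
    )
    # In every mode, guidance is a 1-D tensor with one scalar per latent:
    assert guidance.shape == (batch_size,)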