diff --git a/documentation/DATALOADER.md b/documentation/DATALOADER.md index 4dc915cc..eb239026 100644 --- a/documentation/DATALOADER.md +++ b/documentation/DATALOADER.md @@ -189,6 +189,12 @@ Images are not resized before cropping **unless** `maximum_image_size` and `targ > ℹ️ This value behaves differently to the same option in Kohya's scripts, where a value of 1 means no repeats. **For SimpleTuner, a value of 0 means no repeats**. Subtract one from your Kohya config value to obtain the equivalent for SimpleTuner, hence a value of **9** resulting from the calculation `(dataset_length + repeats * dataset_length)` . +### `is_regularisation_data` + +- May also be spelt `is_regularization_data` +- Enables parent-teacher training for LyCORIS adapters, where the base model's own prediction on this dataset is used as the loss target. + - Standard LoRA adapters are not currently supported. + ### `vae_cache_clear_each_epoch` - When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets, but combined with `crop_style=random` and/or `crop_aspect=random` you'll want this enabled to ensure you sample a full range of crops from each image. diff --git a/documentation/DREAMBOOTH.md b/documentation/DREAMBOOTH.md index 615570bb..c276090b 100644 --- a/documentation/DREAMBOOTH.md +++ b/documentation/DREAMBOOTH.md @@ -28,7 +28,7 @@ Since that time, the idea has evolved and debated, with an opposing camp decidin The model contains something called a "prior" which could, in theory, be preserved during Dreambooth training. In experiments with Stable Diffusion however, it didn't seem to help - the model just overfits on its own knowledge. -> 🔴 Prior preservation loss is not supported in SimpleTuner, all regularisation data is treated as if it were usual training data. +> 🟢 ([#1031](https://github.com/bghira/SimpleTuner/issues/1031)) Prior preservation loss is supported in SimpleTuner when training LyCORIS adapters by setting `is_regularisation_data` on that dataset. ## Setup @@ -105,7 +105,8 @@ Inside our dataloader config `multidatabackend-dreambooth.json`, it will look so "repeats": 0, "resolution": 512, "resolution_type": "pixel_area", - "minimum_image_size": 192 + "minimum_image_size": 192, + "is_regularisation_data": true }, { "id": "regularisation-data-1024px", @@ -117,7 +118,8 @@ Inside our dataloader config `multidatabackend-dreambooth.json`, it will look so "repeats": 0, "resolution": 1024, "resolution_type": "pixel_area", - "minimum_image_size": 768 + "minimum_image_size": 768, + "is_regularisation_data": true }, { "id": "textembeds", @@ -143,6 +145,8 @@ For a regularisation dataset: - If your Regularisation set has 1000 images, and you have 10 images in your training set, you'd want a repeats value of at least 100 to get fast results - `minimum_image_size` has been increased to ensure we don't introduce too many low-quality artifacts - Similarly, using more descriptive captions may help avoid forgetting. Switching from `instanceprompt` to `textfile` or other strategies will require creating `.txt` files for each image. +- When `is_regularisation_data` (or 🇺🇸 `is_regularization_data`, with a z, for American users) is set, the data from this set is fed through the base model to obtain a prediction that is used as the loss target for the student LyCORIS model. + - Note that this currently only works with a LyCORIS adapter.
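The flag documented above is implemented by the `trainer.py` hunks later in this patch: when a batch is flagged as regularisation data and a LyCORIS adapter is being trained, the adapter is temporarily detached so the frozen base (parent) model produces the prediction that becomes the loss target, and it is re-attached before the student prediction. The following is a minimal sketch only; the `predict_with_regularisation` wrapper and `pred_kwargs` are illustrative names, while `restore()`, `apply_to()` and `model_predict()` are the calls used in the hunks below.

```python
import torch

def predict_with_regularisation(trainer, batch, **pred_kwargs):
    """Illustrative wrapper: parent (teacher) target first, then the student prediction."""
    target = None
    if batch.get("is_regularisation_data", False) and trainer.config.lora_type.lower() == "lycoris":
        with torch.no_grad():
            # Detach the LyCORIS adapter so the frozen base model makes the parent prediction.
            trainer.accelerator._lycoris_wrapped_network.restore()
            target = trainer.model_predict(batch=batch, **pred_kwargs)
            # Re-attach the adapter before the trainable student pass.
            trainer.accelerator._lycoris_wrapped_network.apply_to()
    model_pred = trainer.model_predict(batch=batch, **pred_kwargs)
    return model_pred, target
```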
## Selecting an instance prompt diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 6326e012..1e55fee1 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -262,7 +262,8 @@ Create a `--data_backend_config` (`config/multidatabackend.json`) document conta "skip_file_discovery": "", "caption_strategy": "filename", "metadata_backend": "discovery", - "repeats": 0 + "repeats": 0, + "is_regularisation_data": true }, { "id": "dreambooth-subject", @@ -469,6 +470,7 @@ We can partially reintroduce distillation to a de-distilled model by continuing #### LoKr (--lora_type=lycoris) - Higher learning rates are better for LoKr (`1e-3` with AdamW, `2e-4` with Lion) - Other algo need more exploration. +- Setting `is_regularisation_data` on such datasets may help preserve the base model's prior knowledge and prevent concept bleed. ### Image artifacts Flux will immediately absorb bad image artifacts. It's just how it is - a final training run on just high quality data may be required to fix it at the end. diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 6c9f37b9..f80c11c3 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -835,6 +835,9 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize ) use_captions = True + is_regularisation_data = backend.get( + "is_regularisation_data", backend.get("is_regularization_data", False) + ) if "only_instance_prompt" in backend and backend["only_instance_prompt"]: use_captions = False elif args.only_instance_prompt: @@ -842,6 +845,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize init_backend["train_dataset"] = MultiAspectDataset( id=init_backend["id"], datasets=[init_backend["metadata_backend"]], + is_regularisation_data=is_regularisation_data, ) if "deepfloyd" in args.model_type: diff --git a/helpers/multiaspect/dataset.py b/helpers/multiaspect/dataset.py index 21536ba6..a04da36a 100644 --- a/helpers/multiaspect/dataset.py +++ b/helpers/multiaspect/dataset.py @@ -20,11 +20,13 @@ def __init__( self, id: str, datasets: list, - print_names=False, + print_names: bool = False, + is_regularisation_data: bool = False, ): self.id = id self.datasets = datasets self.print_names = print_names + self.is_regularisation_data = is_regularisation_data def __len__(self): # Sum the length of all data backends: @@ -34,6 +36,7 @@ def __getitem__(self, image_tuple): output_data = { "training_samples": [], "conditioning_samples": [], + "is_regularisation_data": self.is_regularisation_data, } first_aspect_ratio = None for sample in image_tuple: diff --git a/helpers/training/collate.py b/helpers/training/collate.py index a892c954..3ece5331 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -375,6 +375,7 @@ def collate_fn(batch): batch = batch[0] examples = batch["training_samples"] conditioning_examples = batch["conditioning_samples"] + is_regularisation_data = batch.get("is_regularisation_data", False) if StateTracker.get_args().controlnet and len(examples) != len( conditioning_examples ): @@ -475,4 +476,5 @@ def collate_fn(batch): "batch_luminance": batch_luminance, "conditioning_pixel_values": conditioning_latents, "encoder_attention_mask": attn_mask, + "is_regularisation_data": is_regularisation_data, } diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index e6779e6e..15e3ae81 100644 --- a/helpers/training/default_settings/safety_check.py +++
b/helpers/training/default_settings/safety_check.py @@ -96,3 +96,13 @@ def safety_check(args, accelerator): f"Your GPU has {total_memory_gb}GB of memory. The SOAP optimiser requires a GPU with at least 24G of memory." ) sys.exit(1) + + if ( + args.model_type != "lora" + and args.base_model_precision != "no_change" + and not args.i_know_what_i_am_doing + ): + logger.error( + f"{args.model_type} tuning is not compatible with quantisation. Please set --base_model_precision to 'no_change' or train LyCORIS/LoRA." + ) + sys.exit(1) diff --git a/helpers/training/quantisation/quanto_workarounds.py b/helpers/training/quantisation/quanto_workarounds.py index 8890064e..254a4dbe 100644 --- a/helpers/training/quantisation/quanto_workarounds.py +++ b/helpers/training/quantisation/quanto_workarounds.py @@ -1,9 +1,10 @@ import torch +import optimum + if torch.cuda.is_available(): # the marlin fp8 kernel needs some help with dtype casting for some reason # see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201 - import optimum from optimum.quanto.library.extensions.cuda import ext as quanto_ext # Save the original operator @@ -62,52 +63,54 @@ def forward(ctx, input, other, bias): tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction - class WeightQBytesLinearFunction( - optimum.quanto.tensor.function.QuantizedLinearFunction - ): - @staticmethod - def forward(ctx, input, other, bias=None): - ctx.save_for_backward(input, other) - if isinstance(input, optimum.quanto.tensor.QBytesTensor): - output = torch.ops.quanto.qbytes_mm( - input._data, other._data, input._scale * other._scale - ) - else: - in_features = input.shape[-1] - out_features = other.shape[0] - output_shape = input.shape[:-1] + (out_features,) - output = torch.ops.quanto.qbytes_mm( - input.reshape(-1, in_features), other._data, other._scale - ) - output = output.view(output_shape) - if bias is not None: - output = output + bias - return output - optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = ( - WeightQBytesLinearFunction - ) - - def reshape_qlf_backward(ctx, gO): - # another one where we need .reshape instead of .view - input_gO = other_gO = bias_gO = None - input, other = ctx.saved_tensors - out_features, in_features = other.shape - if ctx.needs_input_grad[0]: - # grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B - input_gO = torch.matmul(gO, other) - if ctx.needs_input_grad[1]: - # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A - other_gO = torch.matmul( - gO.reshape(-1, out_features).t(), - input.to(g0.dtype).reshape(-1, in_features), +class WeightQBytesLinearFunction( + optimum.quanto.tensor.function.QuantizedLinearFunction +): + @staticmethod + def forward(ctx, input, other, bias=None): + ctx.save_for_backward(input, other) + if isinstance(input, optimum.quanto.tensor.QBytesTensor): + output = torch.ops.quanto.qbytes_mm( + input._data, other._data, input._scale * other._scale + ) + else: + in_features = input.shape[-1] + out_features = other.shape[0] + output_shape = input.shape[:-1] + (out_features,) + output = torch.ops.quanto.qbytes_mm( + input.reshape(-1, in_features), other._data, other._scale ) - if ctx.needs_input_grad[2]: - # Bias gradient is the sum on all dimensions but the last one - dim = tuple(range(gO.ndim - 1)) - bias_gO = gO.sum(dim) - return input_gO, other_gO, bias_gO - - optimum.quanto.tensor.function.QuantizedLinearFunction.backward = ( - reshape_qlf_backward - ) + output = output.view(output_shape) + if bias is not None: + output 
= output + bias + return output + + +optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = ( + WeightQBytesLinearFunction +) + + +def reshape_qlf_backward(ctx, gO): + # another one where we need .reshape instead of .view + input_gO = other_gO = bias_gO = None + input, other = ctx.saved_tensors + out_features, in_features = other.shape + if ctx.needs_input_grad[0]: + # grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B + input_gO = torch.matmul(gO, other) + if ctx.needs_input_grad[1]: + # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A + other_gO = torch.matmul( + gO.reshape(-1, out_features).t(), + input.to(gO.dtype).reshape(-1, in_features), + ) + if ctx.needs_input_grad[2]: + # Bias gradient is the sum on all dimensions but the last one + dim = tuple(range(gO.ndim - 1)) + bias_gO = gO.sum(dim) + return input_gO, other_gO, bias_gO + + +optimum.quanto.tensor.function.QuantizedLinearFunction.backward = reshape_qlf_backward diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 6f321ffe..5f452eb2 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1783,6 +1783,251 @@ def abort(self): backend["sampler"].should_abort = True self.should_abort = True + def model_predict( + self, + batch, + latents, + noisy_latents, + encoder_hidden_states, + added_cond_kwargs, + add_text_embeds, + timesteps, + ): + if self.config.controlnet: + training_logger.debug( + f"Extra conditioning dtype: {batch['conditioning_pixel_values'].dtype}" + ) + if not self.config.disable_accelerator: + if self.config.controlnet: + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to( + dtype=self.config.weight_dtype + ) + training_logger.debug(f"Image shape: {controlnet_image.shape}") + down_block_res_samples, mid_block_res_sample = self.controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + controlnet_cond=controlnet_image, + return_dict=False, + ) + # Predict the noise residual + if self.unet is not None: + model_pred = self.unet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=[ + sample.to(dtype=self.config.weight_dtype) + for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to( + dtype=self.config.weight_dtype + ), + return_dict=False, + )[0] + if self.transformer is not None: + raise Exception( + "ControlNet predictions for transformer models are not yet implemented."
+ ) + elif self.config.model_family == "flux": + # handle guidance + packed_noisy_latents = pack_latents( + noisy_latents, + batch_size=latents.shape[0], + num_channels_latents=latents.shape[1], + height=latents.shape[2], + width=latents.shape[3], + ).to( + dtype=self.config.base_weight_dtype, + device=self.accelerator.device, + ) + if self.config.flux_guidance_mode == "mobius": + guidance_scales = get_mobius_guidance( + self.config, + self.state["global_step"], + self.config.num_update_steps_per_epoch, + latents.shape[0], + self.accelerator.device, + ) + elif self.config.flux_guidance_mode == "constant": + guidance_scales = [ + float(self.config.flux_guidance_value) + ] * latents.shape[0] + + elif self.config.flux_guidance_mode == "random-range": + # Generate a list of random values within the specified range for each latent + guidance_scales = [ + random.uniform( + self.config.flux_guidance_min, + self.config.flux_guidance_max, + ) + for _ in range(latents.shape[0]) + ] + self.guidance_values_list.append(guidance_scales) + + # Now `guidance` will have different values for each latent in `latents`. + transformer_config = None + if hasattr(self.transformer, "module"): + transformer_config = self.transformer.module.config + elif hasattr(self.transformer, "config"): + transformer_config = self.transformer.config + if transformer_config is not None and getattr( + transformer_config, "guidance_embeds", False + ): + guidance = torch.tensor( + guidance_scales, device=self.accelerator.device + ) + else: + guidance = None + img_ids = prepare_latent_image_ids( + latents.shape[0], + latents.shape[2], + latents.shape[3], + self.accelerator.device, + self.config.weight_dtype, + ) + timesteps = ( + torch.tensor(timesteps) + .expand(noisy_latents.shape[0]) + .to(device=self.accelerator.device) + / 1000 + ) + + text_ids = torch.zeros( + batch["prompt_embeds"].shape[1], + 3, + ).to( + device=self.accelerator.device, + dtype=self.config.base_weight_dtype, + ) + training_logger.debug( + "DTypes:" + f"\n-> Text IDs shape: {text_ids.shape if hasattr(text_ids, 'shape') else None}, dtype: {text_ids.dtype if hasattr(text_ids, 'dtype') else None}" + f"\n-> Image IDs shape: {img_ids.shape if hasattr(img_ids, 'shape') else None}, dtype: {img_ids.dtype if hasattr(img_ids, 'dtype') else None}" + f"\n-> Timesteps shape: {timesteps.shape if hasattr(timesteps, 'shape') else None}, dtype: {timesteps.dtype if hasattr(timesteps, 'dtype') else None}" + f"\n-> Guidance: {guidance}" + f"\n-> Packed Noisy Latents shape: {packed_noisy_latents.shape if hasattr(packed_noisy_latents, 'shape') else None}, dtype: {packed_noisy_latents.dtype if hasattr(packed_noisy_latents, 'dtype') else None}" + ) + + flux_transformer_kwargs = { + "hidden_states": packed_noisy_latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + "timestep": timesteps, + "guidance": guidance, + "pooled_projections": batch["add_text_embeds"].to( + device=self.accelerator.device, + dtype=self.config.base_weight_dtype, + ), + "encoder_hidden_states": batch["prompt_embeds"].to( + device=self.accelerator.device, + dtype=self.config.base_weight_dtype, + ), + "txt_ids": text_ids.to( + device=self.accelerator.device, + dtype=self.config.base_weight_dtype, + ), + "img_ids": img_ids, + "joint_attention_kwargs": None, + "return_dict": False, + } + if self.config.flux_attention_masked_training: + flux_transformer_kwargs["attention_mask"] = 
batch[ + "encoder_attention_mask" + ] + if flux_transformer_kwargs["attention_mask"] is None: + raise ValueError( + "No attention mask was discovered when attempting validation - this means you need to recreate your text embed cache." + ) + + model_pred = self.transformer(**flux_transformer_kwargs)[0] + + elif self.config.model_family == "sd3": + # Stable Diffusion 3 uses a MM-DiT model where the VAE-produced + # image embeds are passed in with the TE-produced text embeds. + model_pred = self.transformer( + hidden_states=noisy_latents.to( + device=self.accelerator.device, + dtype=self.config.base_weight_dtype, + ), + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states.to( + device=self.accelerator.device, + dtype=self.config.base_weight_dtype, + ), + pooled_projections=add_text_embeds.to( + device=self.accelerator.device, + dtype=self.config.weight_dtype, + ), + return_dict=False, + )[0] + elif self.config.model_family == "pixart_sigma": + model_pred = self.transformer( + noisy_latents, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=batch["encoder_attention_mask"], + timestep=timesteps, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + model_pred = model_pred.chunk(2, dim=1)[0] + elif self.config.model_family == "smoldit": + first_latent_shape = noisy_latents.shape + height = first_latent_shape[1] * 8 + width = first_latent_shape[2] * 8 + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size + ) + inputs = { + "hidden_states": noisy_latents, + "timestep": timesteps, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": batch["encoder_attention_mask"], + "image_rotary_emb": get_2d_rotary_pos_embed( + self.transformer.inner_dim + // self.transformer.config.num_attention_heads, + grid_crops_coords, + (grid_height, grid_width), + ), + } + model_pred = self.transformer(**inputs).sample + elif self.unet is not None: + if self.config.model_family == "legacy": + # SD 1.5 or 2.x + model_pred = self.unet( + noisy_latents, + timesteps, + encoder_hidden_states, + ).sample + else: + # SDXL, Kolors, other default unet prediction. + model_pred = self.unet( + noisy_latents, + timesteps, + encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + ).sample + else: + raise Exception("Unknown error occurred, no prediction could be made.") + + if self.config.model_family == "flux": + model_pred = unpack_latents( + model_pred, + height=latents.shape[2] * 8, + width=latents.shape[3] * 8, + vae_scale_factor=16, + ) + else: + # Dummy model prediction for debugging. + model_pred = torch.randn_like(noisy_latents) + + return model_pred + def train(self): self.init_trackers() self._train_initial_msg() @@ -2088,11 +2333,9 @@ def train(self): "Supported types are 'epsilon', `sample`, and 'v_prediction'." ) + added_cond_kwargs = None # Predict the noise residual and compute loss - if self.config.model_family == "sd3": - # Even if we're using DDPM process, we don't add in extra kwargs, which are SDXL-specific. 
- added_cond_kwargs = None - elif ( + if ( StateTracker.get_model_family() == "sdxl" or self.config.model_family == "kolors" ): @@ -2120,259 +2363,47 @@ def train(self): dtype=self.config.weight_dtype, ) - training_logger.debug("Predicting noise residual.") - - if self.config.controlnet: - training_logger.debug( - f"Extra conditioning dtype: {batch['conditioning_pixel_values'].dtype}" - ) - if not self.config.disable_accelerator: - if self.config.controlnet: - # ControlNet conditioning. - controlnet_image = batch["conditioning_pixel_values"].to( - dtype=self.config.weight_dtype - ) - training_logger.debug( - f"Image shape: {controlnet_image.shape}" - ) - down_block_res_samples, mid_block_res_sample = ( - self.controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - added_cond_kwargs=added_cond_kwargs, - controlnet_cond=controlnet_image, - return_dict=False, - ) - ) - # Predict the noise residual - if self.unet is not None: - model_pred = self.unet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=[ - sample.to(dtype=self.config.weight_dtype) - for sample in down_block_res_samples - ], - mid_block_additional_residual=mid_block_res_sample.to( - dtype=self.config.weight_dtype - ), - return_dict=False, - )[0] - if self.transformer is not None: - raise Exception( - "ControlNet predictions for transformer models are not yet implemented." - ) - elif self.config.model_family == "flux": - # handle guidance - packed_noisy_latents = pack_latents( - noisy_latents, - batch_size=latents.shape[0], - num_channels_latents=latents.shape[1], - height=latents.shape[2], - width=latents.shape[3], - ).to( - dtype=self.config.base_weight_dtype, - device=self.accelerator.device, - ) - if self.config.flux_guidance_mode == "mobius": - guidance_scales = get_mobius_guidance( - self.config, - self.state["global_step"], - self.config.num_update_steps_per_epoch, - latents.shape[0], - self.accelerator.device, - ) - elif self.config.flux_guidance_mode == "constant": - guidance_scales = [ - float(self.config.flux_guidance_value) - ] * latents.shape[0] - - elif self.config.flux_guidance_mode == "random-range": - # Generate a list of random values within the specified range for each latent - guidance_scales = [ - random.uniform( - self.config.flux_guidance_min, - self.config.flux_guidance_max, - ) - for _ in range(latents.shape[0]) - ] - self.guidance_values_list.append(guidance_scales) - - # Now `guidance` will have different values for each latent in `latents`. - transformer_config = None - if hasattr(self.transformer, "module"): - transformer_config = self.transformer.module.config - elif hasattr(self.transformer, "config"): - transformer_config = self.transformer.config - if transformer_config is not None and getattr( - transformer_config, "guidance_embeds", False - ): - guidance = torch.tensor( - guidance_scales, device=self.accelerator.device + # a marker to know whether we had a model capable of regularised data training. + handled_regularisation = False + is_regularisation_data = batch.get("is_regularisation_data", False) + if is_regularisation_data and self.config.model_type == "lora": + training_logger.debug("Predicting parent model residual.") + handled_regularisation = True + with torch.no_grad(): + if self.config.lora_type.lower() == "lycoris": + logger.info( + "Detaching LyCORIS adapter for parent prediction." 
) + self.accelerator._lycoris_wrapped_network.restore() else: - guidance = None - img_ids = prepare_latent_image_ids( - latents.shape[0], - latents.shape[2], - latents.shape[3], - self.accelerator.device, - self.config.weight_dtype, - ) - timesteps = ( - torch.tensor(timesteps) - .expand(noisy_latents.shape[0]) - .to(device=self.accelerator.device) - / 1000 - ) - - text_ids = torch.zeros( - batch["prompt_embeds"].shape[1], - 3, - ).to( - device=self.accelerator.device, - dtype=self.config.base_weight_dtype, - ) - training_logger.debug( - "DTypes:" - f"\n-> Text IDs shape: {text_ids.shape if hasattr(text_ids, 'shape') else None}, dtype: {text_ids.dtype if hasattr(text_ids, 'dtype') else None}" - f"\n-> Image IDs shape: {img_ids.shape if hasattr(img_ids, 'shape') else None}, dtype: {img_ids.dtype if hasattr(img_ids, 'dtype') else None}" - f"\n-> Timesteps shape: {timesteps.shape if hasattr(timesteps, 'shape') else None}, dtype: {timesteps.dtype if hasattr(timesteps, 'dtype') else None}" - f"\n-> Guidance: {guidance}" - f"\n-> Packed Noisy Latents shape: {packed_noisy_latents.shape if hasattr(packed_noisy_latents, 'shape') else None}, dtype: {packed_noisy_latents.dtype if hasattr(packed_noisy_latents, 'dtype') else None}" - ) - - flux_transformer_kwargs = { - "hidden_states": packed_noisy_latents, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - "timestep": timesteps, - "guidance": guidance, - "pooled_projections": batch["add_text_embeds"].to( - device=self.accelerator.device, - dtype=self.config.base_weight_dtype, - ), - "encoder_hidden_states": batch["prompt_embeds"].to( - device=self.accelerator.device, - dtype=self.config.base_weight_dtype, - ), - "txt_ids": text_ids.to( - device=self.accelerator.device, - dtype=self.config.base_weight_dtype, - ), - "img_ids": img_ids, - "joint_attention_kwargs": None, - "return_dict": False, - } - if self.config.flux_attention_masked_training: - flux_transformer_kwargs["attention_mask"] = batch[ - "encoder_attention_mask" - ] - if flux_transformer_kwargs["attention_mask"] is None: - raise ValueError( - "No attention mask was discovered when attempting validation - this means you need to recreate your text embed cache." - ) - - model_pred = self.transformer(**flux_transformer_kwargs)[0] - - elif self.config.model_family == "sd3": - # Stable Diffusion 3 uses a MM-DiT model where the VAE-produced - # image embeds are passed in with the TE-produced text embeds. - model_pred = self.transformer( - hidden_states=noisy_latents.to( - device=self.accelerator.device, - dtype=self.config.base_weight_dtype, - ), - timestep=timesteps, - encoder_hidden_states=encoder_hidden_states.to( - device=self.accelerator.device, - dtype=self.config.base_weight_dtype, - ), - pooled_projections=add_text_embeds.to( - device=self.accelerator.device, - dtype=self.config.weight_dtype, - ), - return_dict=False, - )[0] - elif self.config.model_family == "pixart_sigma": - model_pred = self.transformer( - noisy_latents, + raise ValueError( + f"Cannot train parent-student networks on {self.config.lora_type} model. Only LyCORIS is supported." 
+ ) + target = self.model_predict( + batch=batch, + latents=latents, + noisy_latents=noisy_latents, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=batch["encoder_attention_mask"], - timestep=timesteps, added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - model_pred = model_pred.chunk(2, dim=1)[0] - elif self.config.model_family == "smoldit": - first_latent_shape = noisy_latents.shape - height = first_latent_shape[1] * 8 - width = first_latent_shape[2] * 8 - grid_height = ( - height // 8 // self.transformer.config.patch_size - ) - grid_width = ( - width // 8 // self.transformer.config.patch_size - ) - base_size = 512 // 8 // self.transformer.config.patch_size - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size - ) - inputs = { - "hidden_states": noisy_latents, - "timestep": timesteps, - "encoder_hidden_states": encoder_hidden_states, - "encoder_attention_mask": batch[ - "encoder_attention_mask" - ], - "image_rotary_emb": get_2d_rotary_pos_embed( - self.transformer.inner_dim - // self.transformer.config.num_attention_heads, - grid_crops_coords, - (grid_height, grid_width), - ), - } - model_pred = self.transformer(**inputs).sample - elif self.unet is not None: - if self.config.model_family == "legacy": - # SD 1.5 or 2.x - model_pred = self.unet( - noisy_latents, - timesteps, - encoder_hidden_states, - ).sample - else: - # SDXL, Kolors, other default unet prediction. - model_pred = self.unet( - noisy_latents, - timesteps, - encoder_hidden_states, - added_cond_kwargs=added_cond_kwargs, - ).sample - else: - raise Exception( - "Unknown error occurred, no prediction could be made." - ) - # if we're quantising with quanto, we need to dequantise the result - # if "quanto" in self.config.base_model_precision: - # if hasattr(model_pred, "dequantize") and isinstance( - # model_pred, QTensor - # ): - # model_pred = model_pred.dequantize() - - if self.config.model_family == "flux": - model_pred = unpack_latents( - model_pred, - height=latents.shape[2] * 8, - width=latents.shape[3] * 8, - vae_scale_factor=16, + add_text_embeds=add_text_embeds, + timesteps=timesteps, ) + if self.config.lora_type.lower() == "lycoris": + logger.info( + "Attaching LyCORIS adapter for student prediction." + ) + self.accelerator._lycoris_wrapped_network.apply_to() - else: - # Dummy model prediction for debugging. - model_pred = torch.randn_like(noisy_latents) + training_logger.debug("Predicting noise residual.") + model_pred = self.model_predict( + batch=batch, + latents=latents, + noisy_latents=noisy_latents, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + add_text_embeds=add_text_embeds, + timesteps=timesteps, + ) # x-prediction requires that we now subtract the noise residual from the prediction to get the target sample. if ( @@ -2382,6 +2413,7 @@ def train(self): ): model_pred = model_pred - noise + parent_loss = None if self.config.flow_matching: loss = torch.mean( ((model_pred.float() - target.float()) ** 2).reshape( @@ -2436,6 +2468,9 @@ def train(self): * mse_loss_weights ).mean() + if is_regularisation_data: + parent_loss = loss + # Gather the losses across all processes for logging (if we use distributed training). 
avg_loss = self.accelerator.gather( loss.repeat(self.config.train_batch_size) @@ -2506,6 +2541,8 @@ def train(self): "learning_rate": self.lr, "epoch": epoch, } + if parent_loss is not None: + wandb_logs["regularisation_loss"] = parent_loss if self.config.model_family == "flux" and self.guidance_values_list: # avg the values guidance_values = torch.tensor(self.guidance_values_list).mean() @@ -2584,6 +2621,7 @@ def train(self): structured_data = { "state": self.state, "loss": round(self.train_loss, 4), + "parent_loss": parent_loss, "learning_rate": self.lr, "epoch": epoch, "final_epoch": self.config.num_train_epochs,
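Putting the trainer changes together: the parent prediction obtained under `torch.no_grad()` becomes the loss target for a regularisation batch, and the resulting value is surfaced as `regularisation_loss` in the wandb logs and `parent_loss` in the structured webhook payload. The snippet below is a simplified sketch with dummy tensors standing in for the real predictions; the MSE reduction mirrors the flow-matching branch already present in `trainer.py`, and the `train_loss` key is illustrative.

```python
import torch

# Dummy stand-ins for the student (LyCORIS) prediction and the detached parent target.
model_pred = torch.randn(1, 4, 64, 64)
target = torch.randn(1, 4, 64, 64)
is_regularisation_data = True  # propagated from the collated batch

# Per-sample MSE, reduced the same way as the flow-matching loss above.
loss = torch.mean(
    ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
).mean()

parent_loss = loss if is_regularisation_data else None

wandb_logs = {"train_loss": loss.item()}
if parent_loss is not None:
    # Reported alongside the regular training loss for tracking.
    wandb_logs["regularisation_loss"] = parent_loss
```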