implement regularisation dataset parent-student loss for LyCORIS training #1050

Merged · 3 commits · Oct 12, 2024
6 changes: 6 additions & 0 deletions documentation/DATALOADER.md
@@ -189,6 +189,12 @@ Images are not resized before cropping **unless** `maximum_image_size` and `targ

> ℹ️ This value behaves differently from the same option in Kohya's scripts, where a value of 1 means no repeats. **For SimpleTuner, a value of 0 means no repeats.** Subtract one from your Kohya config value to obtain the SimpleTuner equivalent; the effective sample count is `dataset_length + repeats * dataset_length`, so a value of **9** yields ten passes over the dataset.

### `is_regularisation_data`

- May also be spelt `is_regularization_data`.
- Enables parent-student training for LyCORIS adapters, so that the prediction target for this dataset prefers the base (parent) model's output. An abridged example entry is sketched below.
- Standard LoRA adapters are not currently supported.
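
For illustration, an abridged dataset entry with the flag enabled might look like the sketch below. Other required keys (storage backend, data path, and so on) are omitted here; the Dreambooth and Flux examples elsewhere in this PR show complete entries.

```json
{
  "id": "regularisation-data-1024px",
  "caption_strategy": "textfile",
  "repeats": 0,
  "resolution": 1024,
  "resolution_type": "pixel_area",
  "minimum_image_size": 768,
  "is_regularisation_data": true
}
```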

### `vae_cache_clear_each_epoch`

- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets, but combined with `crop_style=random` and/or `crop_aspect=random` you'll want this enabled to ensure you sample a full range of crops from each image.
10 changes: 7 additions & 3 deletions documentation/DREAMBOOTH.md
@@ -28,7 +28,7 @@ Since that time, the idea has evolved and debated, with an opposing camp decidin

The model contains something called a "prior" which could, in theory, be preserved during Dreambooth training. In experiments with Stable Diffusion however, it didn't seem to help - the model just overfits on its own knowledge.

> 🔴 Prior preservation loss is not supported in SimpleTuner, all regularisation data is treated as if it were usual training data.
> 🟢 ([#1031](https://github.com/bghira/SimpleTuner/issues/1031)) Prior preservation loss is supported in SimpleTuner when training LyCORIS adapters by setting `is_regularisation_data` on that dataset.

## Setup

@@ -105,7 +105,8 @@ Inside our dataloader config `multidatabackend-dreambooth.json`, it will look so
"repeats": 0,
"resolution": 512,
"resolution_type": "pixel_area",
"minimum_image_size": 192
"minimum_image_size": 192,
"is_regularisation_data": true
},
{
"id": "regularisation-data-1024px",
@@ -117,7 +118,8 @@ Inside our dataloader config `multidatabackend-dreambooth.json`, it will look so
"repeats": 0,
"resolution": 1024,
"resolution_type": "pixel_area",
"minimum_image_size": 768
"minimum_image_size": 768,
"is_regularisation_data": true
},
{
"id": "textembeds",
@@ -143,6 +145,8 @@ For a regularisation dataset:
- If your regularisation set has 1000 images and you have 10 images in your training set, you'd want a repeats value of at least 100 to see results quickly.
- `minimum_image_size` has been increased to ensure we don't introduce too many low-quality artifacts.
- Similarly, using more descriptive captions may help avoid forgetting. Switching from `instanceprompt` to `textfile` or other strategies will require creating `.txt` files for each image.
- When `is_regularisation_data` (or 🇺🇸 `is_regularization_data`, with a z, for American users) is set, the data from this set is fed through the frozen base model to obtain a prediction that is used as the loss target for the student LyCORIS model.
- Note that this currently only works with LyCORIS adapters; a rough sketch of the mechanism follows this list.
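
Conceptually, the parent-student mechanism can be pictured with the sketch below. This is not SimpleTuner's actual trainer code: the `disable_adapter` placeholder stands in for however the LyCORIS weights are actually toggled off, and the model call signature is simplified.

```python
# Illustrative sketch only - names and signatures here are simplified stand-ins.
from contextlib import contextmanager

import torch
import torch.nn.functional as F


@contextmanager
def disable_adapter(model):
    # Placeholder: in practice this would zero the LyCORIS multiplier so the
    # forward pass reflects the frozen base ("parent") model, then restore it.
    yield


def training_loss(model, batch, noisy_latents, timesteps, encoder_states, target):
    # Student prediction: the model with the LyCORIS adapter active.
    student_pred = model(noisy_latents, timesteps, encoder_states)

    if batch.get("is_regularisation_data", False):
        # Parent prediction: with the adapter disabled, the base model's output
        # replaces the usual target, nudging the student back towards the prior.
        with torch.no_grad(), disable_adapter(model):
            target = model(noisy_latents, timesteps, encoder_states)

    return F.mse_loss(student_pred.float(), target.float())
```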

## Selecting an instance prompt

4 changes: 3 additions & 1 deletion documentation/quickstart/FLUX.md
@@ -262,7 +262,8 @@ Create a `--data_backend_config` (`config/multidatabackend.json`) document conta
"skip_file_discovery": "",
"caption_strategy": "filename",
"metadata_backend": "discovery",
"repeats": 0
"repeats": 0,
"is_regularisation_data": true
},
{
"id": "dreambooth-subject",
@@ -469,6 +470,7 @@ We can partially reintroduce distillation to a de-distilled model by continuing
#### LoKr (--lora_type=lycoris)
- Higher learning rates are better for LoKr (`1e-3` with AdamW, `2e-4` with Lion)
- Other algorithms need more exploration.
- Setting `is_regularisation_data` on such datasets may help preserve the base model's prior and prevent concept bleed.

### Image artifacts
Flux will immediately absorb bad image artifacts. It's just how it is; a final training run on only high-quality data may be required at the end to fix them.
4 changes: 4 additions & 0 deletions helpers/data_backend/factory.py
@@ -835,13 +835,17 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
)

use_captions = True
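# accept either the British or American spelling of this dataloader flag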
is_regularisation_data = backend.get(
"is_regularisation_data", backend.get("is_regularization_data", False)
)
if "only_instance_prompt" in backend and backend["only_instance_prompt"]:
use_captions = False
elif args.only_instance_prompt:
use_captions = False
init_backend["train_dataset"] = MultiAspectDataset(
id=init_backend["id"],
datasets=[init_backend["metadata_backend"]],
is_regularisation_data=is_regularisation_data,
)

if "deepfloyd" in args.model_type:
5 changes: 4 additions & 1 deletion helpers/multiaspect/dataset.py
@@ -20,11 +20,13 @@ def __init__(
self,
id: str,
datasets: list,
print_names=False,
print_names: bool = False,
is_regularisation_data: bool = False,
):
self.id = id
self.datasets = datasets
self.print_names = print_names
self.is_regularisation_data = is_regularisation_data

def __len__(self):
# Sum the length of all data backends:
@@ -34,6 +36,7 @@ def __getitem__(self, image_tuple):
output_data = {
"training_samples": [],
"conditioning_samples": [],
"is_regularisation_data": self.is_regularisation_data,
}
first_aspect_ratio = None
for sample in image_tuple:
2 changes: 2 additions & 0 deletions helpers/training/collate.py
@@ -375,6 +375,7 @@ def collate_fn(batch):
batch = batch[0]
examples = batch["training_samples"]
conditioning_examples = batch["conditioning_samples"]
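# carry the dataset-level regularisation flag through to the collated batch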
is_regularisation_data = batch.get("is_regularisation_data", False)
if StateTracker.get_args().controlnet and len(examples) != len(
conditioning_examples
):
@@ -475,4 +476,5 @@
"batch_luminance": batch_luminance,
"conditioning_pixel_values": conditioning_latents,
"encoder_attention_mask": attn_mask,
"is_regularisation_data": is_regularisation_data,
}
10 changes: 10 additions & 0 deletions helpers/training/default_settings/safety_check.py
@@ -96,3 +96,13 @@ def safety_check(args, accelerator):
f"Your GPU has {total_memory_gb}GB of memory. The SOAP optimiser requires a GPU with at least 24G of memory."
)
sys.exit(1)

if (
args.model_type != "lora"
and args.base_model_precision != "no_change"
and not args.i_know_what_i_am_doing
):
logger.error(
f"{args.model_type} tuning is not compatible with quantisation. Please set --base_model_precision to 'no_change' or train LyCORIS/LoRA."
)
sys.exit(1)
99 changes: 51 additions & 48 deletions helpers/training/quantisation/quanto_workarounds.py
@@ -1,9 +1,10 @@
import torch

import optimum

if torch.cuda.is_available():
# the marlin fp8 kernel needs some help with dtype casting for some reason
# see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201
import optimum
from optimum.quanto.library.extensions.cuda import ext as quanto_ext

# Save the original operator
@@ -62,52 +63,54 @@ def forward(ctx, input, other, bias):

tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction

class WeightQBytesLinearFunction(
optimum.quanto.tensor.function.QuantizedLinearFunction
):
@staticmethod
def forward(ctx, input, other, bias=None):
ctx.save_for_backward(input, other)
if isinstance(input, optimum.quanto.tensor.QBytesTensor):
output = torch.ops.quanto.qbytes_mm(
input._data, other._data, input._scale * other._scale
)
else:
in_features = input.shape[-1]
out_features = other.shape[0]
output_shape = input.shape[:-1] + (out_features,)
output = torch.ops.quanto.qbytes_mm(
input.reshape(-1, in_features), other._data, other._scale
)
output = output.view(output_shape)
if bias is not None:
output = output + bias
return output

optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = (
WeightQBytesLinearFunction
)

def reshape_qlf_backward(ctx, gO):
# another one where we need .reshape instead of .view
input_gO = other_gO = bias_gO = None
input, other = ctx.saved_tensors
out_features, in_features = other.shape
if ctx.needs_input_grad[0]:
# grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B
input_gO = torch.matmul(gO, other)
if ctx.needs_input_grad[1]:
# grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
other_gO = torch.matmul(
gO.reshape(-1, out_features).t(),
                input.to(gO.dtype).reshape(-1, in_features),
class WeightQBytesLinearFunction(
optimum.quanto.tensor.function.QuantizedLinearFunction
):
@staticmethod
def forward(ctx, input, other, bias=None):
ctx.save_for_backward(input, other)
if isinstance(input, optimum.quanto.tensor.QBytesTensor):
output = torch.ops.quanto.qbytes_mm(
input._data, other._data, input._scale * other._scale
)
else:
in_features = input.shape[-1]
out_features = other.shape[0]
output_shape = input.shape[:-1] + (out_features,)
output = torch.ops.quanto.qbytes_mm(
input.reshape(-1, in_features), other._data, other._scale
)
if ctx.needs_input_grad[2]:
# Bias gradient is the sum on all dimensions but the last one
dim = tuple(range(gO.ndim - 1))
bias_gO = gO.sum(dim)
return input_gO, other_gO, bias_gO

optimum.quanto.tensor.function.QuantizedLinearFunction.backward = (
reshape_qlf_backward
)
output = output.view(output_shape)
if bias is not None:
output = output + bias
return output


optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = (
WeightQBytesLinearFunction
)


def reshape_qlf_backward(ctx, gO):
# another one where we need .reshape instead of .view
input_gO = other_gO = bias_gO = None
input, other = ctx.saved_tensors
out_features, in_features = other.shape
if ctx.needs_input_grad[0]:
# grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B
input_gO = torch.matmul(gO, other)
if ctx.needs_input_grad[1]:
# grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
other_gO = torch.matmul(
gO.reshape(-1, out_features).t(),
            input.to(gO.dtype).reshape(-1, in_features),
)
if ctx.needs_input_grad[2]:
# Bias gradient is the sum on all dimensions but the last one
dim = tuple(range(gO.ndim - 1))
bias_gO = gO.sum(dim)
return input_gO, other_gO, bias_gO


optimum.quanto.tensor.function.QuantizedLinearFunction.backward = reshape_qlf_backward