Add support for specifying rank for each layer in FLUX.1
kohya-ss committed Sep 14, 2024
1 parent 2d8ee3c commit c9ff4de
Showing 2 changed files with 161 additions and 7 deletions.
61 changes: 61 additions & 0 deletions README.md
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 14, 2024:
- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details.
- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details.

Sep 11, 2024:
Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev!

@@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
- [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model)
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
- [FLUX.1 OFT training](#flux1-oft-training)
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
@@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/

The compatibility of the saved model (state dict) is ensured by concatenating the weights of the multiple LoRAs. However, since some parts contain zero weights, the model size will be larger.

#### Specify rank for each layer in FLUX.1

You can specify the rank for each layer in FLUX.1 with the following `network_args`. If you specify `0`, LoRA is not applied to that layer.

When these `network_args` are not specified, the default value (`network_dim`) is applied, as before.

|network_args|target layer|
|---|---|
|img_attn_dim|img_attn in DoubleStreamBlock|
|txt_attn_dim|txt_attn in DoubleStreamBlock|
|img_mlp_dim|img_mlp in DoubleStreamBlock|
|txt_mlp_dim|txt_mlp in DoubleStreamBlock|
|img_mod_dim|img_mod in DoubleStreamBlock|
|txt_mod_dim|txt_mod in DoubleStreamBlock|
|single_dim|linear1 and linear2 in SingleStreamBlock|
|single_mod_dim|modulation in SingleStreamBlock|

example:
```
--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2"
"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2"
```
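
For reference, below is a minimal Python sketch of how a per-layer rank could be resolved from these options. It mirrors the substring matching added to `networks/lora_flux.py` in this commit; the `resolve_rank` helper and the module names in the example are illustrative assumptions, not part of the actual API.

```
# Minimal sketch (not the sd-scripts API) of per-layer rank resolution,
# mirroring the substring matching added to networks/lora_flux.py.
from typing import Optional

# (substring(s) that must appear in the LoRA module name, network_args key)
RULES = [
    (("img_attn",), "img_attn_dim"),
    (("txt_attn",), "txt_attn_dim"),
    (("img_mlp",), "img_mlp_dim"),
    (("txt_mlp",), "txt_mlp_dim"),
    (("img_mod",), "img_mod_dim"),
    (("txt_mod",), "txt_mod_dim"),
    (("single_blocks", "linear"), "single_dim"),
    (("modulation",), "single_mod_dim"),
]

def resolve_rank(lora_name: str, args: dict, network_dim: int) -> Optional[int]:
    """Return the rank for a module, or None if rank 0 disables LoRA for it."""
    for substrings, key in RULES:
        dim = args.get(key)
        if dim is not None and all(s in lora_name for s in substrings):
            return None if dim == 0 else dim
    return network_dim  # unspecified layers fall back to --network_dim

args = {"img_attn_dim": 4, "img_mlp_dim": 8, "txt_mlp_dim": 0}
print(resolve_rank("lora_unet_double_blocks_0_img_attn_qkv", args, 16))  # 4
print(resolve_rank("lora_unet_double_blocks_0_txt_mlp_0", args, 16))     # None (skipped)
print(resolve_rank("lora_unet_single_blocks_0_linear1", args, 16))       # 16 (default)
```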

You can apply LoRA to the conditioning layers of FLUX.1 by specifying `in_dims` in `network_args`. The value must be five comma-separated numbers enclosed in `[]`.

example:
```
--network_args "in_dims=[4,2,2,2,4]"
```

Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, and `txt_in`, in that order. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, rank 2 for `time_in`, `vector_in`, and `guidance_in`, and rank 4 for `txt_in`.

If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
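
As a worked example, the following sketch shows how such an `in_dims` string can be interpreted, mirroring the parsing added to `networks/lora_flux.py` in this commit. `parse_in_dims` is a hypothetical helper used only for illustration.

```
# Hypothetical helper mirroring the in_dims parsing added in this commit.
IN_LAYERS = ["img_in", "time_in", "vector_in", "guidance_in", "txt_in"]

def parse_in_dims(value: str) -> dict:
    value = value.strip()
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    dims = [int(d) for d in value.split(",")]
    assert len(dims) == 5, "in_dims must be 5 values (img, time, vector, guidance, txt)"
    # rank 0 means: do not apply LoRA to that conditioning layer
    return {layer: dim for layer, dim in zip(IN_LAYERS, dims) if dim != 0}

print(parse_in_dims("[4,2,2,2,4]"))  # all five conditioning layers get LoRA
print(parse_in_dims("[4,0,0,0,4]"))  # {'img_in': 4, 'txt_in': 4}
```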

### FLUX.1 OFT training

You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.

- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`.
- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend around 64 or 128. The output dimension of each layer targeted by OFT must be divisible by the value of `--network_dim` (an error occurs if it is not). Valid values are 64, 128, 256, 512, 1024, etc. A quick divisibility check is sketched after this list.
- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it.
- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`.
- `--network_args` specifies the hyperparameters of OFT. The following are valid:
  - Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention.
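
Because of the divisibility requirement above, it can help to sanity-check a candidate `--network_dim` against the output dimensions of the layers OFT will wrap. The sketch below assumes output dimensions of 3072 (attention/projection) and 12288 (MLP, relevant with `enable_all_linear=True`) for FLUX.1; these sizes are assumptions for illustration, so check your model's actual layer shapes.

```
# Sketch: which block counts divide the (assumed) FLUX.1 output dimensions.
# 3072 and 12288 are assumptions used for illustration only.
def usable_block_counts(candidates, out_dims=(3072, 12288)):
    return [n for n in candidates if all(d % n == 0 for d in out_dims)]

print(usable_block_counts([64, 100, 128, 256, 512, 1024]))
# -> [64, 128, 256, 512, 1024]  (100 is rejected: 3072 % 100 != 0)
```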

Currently, no other inference environment supports FLUX.1 OFT; inference is only possible with `flux_minimal_inference.py` (specify the OFT model with `--lora`).

A sample command is below. It works on GPUs with 24GB VRAM at a batch size of 1.

```
--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3
--network_args "enable_all_linear=True" --learning_rate 1e-5
```

Training is also possible on GPUs with 16GB VRAM by omitting `enable_all_linear=True` and using the Adafactor optimizer.

### Inference for FLUX.1 with LoRA model

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
107 changes: 100 additions & 7 deletions networks/lora_flux.py
@@ -316,6 +316,44 @@ def create_network(
else:
conv_alpha = float(conv_alpha)

# attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
img_attn_dim = kwargs.get("img_attn_dim", None)
txt_attn_dim = kwargs.get("txt_attn_dim", None)
img_mlp_dim = kwargs.get("img_mlp_dim", None)
txt_mlp_dim = kwargs.get("txt_mlp_dim", None)
img_mod_dim = kwargs.get("img_mod_dim", None)
txt_mod_dim = kwargs.get("txt_mod_dim", None)
single_dim = kwargs.get("single_dim", None) # SingleStreamBlock
single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock
if img_attn_dim is not None:
img_attn_dim = int(img_attn_dim)
if txt_attn_dim is not None:
txt_attn_dim = int(txt_attn_dim)
if img_mlp_dim is not None:
img_mlp_dim = int(img_mlp_dim)
if txt_mlp_dim is not None:
txt_mlp_dim = int(txt_mlp_dim)
if img_mod_dim is not None:
img_mod_dim = int(img_mod_dim)
if txt_mod_dim is not None:
txt_mod_dim = int(txt_mod_dim)
if single_dim is not None:
single_dim = int(single_dim)
if single_mod_dim is not None:
single_mod_dim = int(single_mod_dim)
type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim]
if all([d is None for d in type_dims]):
type_dims = None

# in_dims [img, time, vector, guidance, txt]
in_dims = kwargs.get("in_dims", None)
if in_dims is not None:
in_dims = in_dims.strip()
if in_dims.startswith("[") and in_dims.endswith("]"):
in_dims = in_dims[1:-1]
in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval?
assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"

# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
@@ -339,6 +377,11 @@ def create_network(
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False

# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False

# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
@@ -354,7 +397,9 @@ def create_network(
train_blocks=train_blocks,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
-varbose=True,
type_dims=type_dims,
in_dims=in_dims,
verbose=verbose,
)

loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
@@ -462,7 +507,9 @@ def __init__(
train_blocks: Optional[str] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
-varbose: Optional[bool] = False,
type_dims: Optional[List[int]] = None,
in_dims: Optional[List[int]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
@@ -478,12 +525,17 @@ def __init__(
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl

self.type_dims = type_dims
self.in_dims = in_dims

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None

if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.in_dims = [0] * 5 # create in_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
@@ -502,7 +554,12 @@ def __init__(

# create module instances
def create_modules(
-is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str]
is_flux: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_FLUX
@@ -513,16 +570,22 @@ def create_modules(
loras = []
skipped = []
for name, module in root_module.named_modules():
-if module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None: # dirty hack for all modules
module = root_module # search all modules

for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

if is_linear or is_conv2d:
-lora_name = prefix + "." + name + "." + child_name
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")

if filter is not None and not filter in lora_name:
continue

dim = None
alpha = None

@@ -534,8 +597,25 @@ def create_modules(
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
-dim = self.lora_dim
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha

if type_dims is not None:
identifier = [
("img_attn",),
("txt_attn",),
("img_mlp",),
("txt_mlp",),
("img_mod",),
("txt_mod",),
("single_blocks", "linear"),
("modulation",),
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d
break

elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
@@ -566,6 +646,9 @@ def create_modules(
split_dims=split_dims,
)
loras.append(lora)

if target_replace_modules is None:
break # all modules are searched
return loras, skipped

# create LoRA for text encoder
@@ -594,10 +677,20 @@ def create_modules(

self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)

# img, time, vector, guidance, txt
if self.in_dims:
for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims):
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
self.unet_loras.extend(loras)

logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")

skipped = skipped_te + skipped_un
-if varbose and len(skipped) > 0:
if verbose and len(skipped) > 0:
logger.warning(
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
