diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 3c8b75a08808..53ee0f89e280 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -837,11 +837,6 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
-    if args.train_norm_layers:
-        for name, param in flux_transformer.named_parameters():
-            if any(k in name for k in NORM_LAYER_PREFIXES):
-                param.requires_grad = True
-
     if args.lora_layers is not None:
         if args.lora_layers != "all-linear":
             target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
     )
     flux_transformer.add_adapter(transformer_lora_config)
 
+    if args.train_norm_layers:
+        for name, param in flux_transformer.named_parameters():
+            if any(k in name for k in NORM_LAYER_PREFIXES):
+                param.requires_grad = True
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
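
The reorder matters because PEFT's adapter injection, which `add_adapter` delegates to, marks every non-adapter parameter as non-trainable, so a `requires_grad = True` set on the norm layers before the adapter is added gets silently reverted. The sketch below is a minimal, hypothetical illustration of that behavior on a toy module (`TinyBlock` is made up for the example, and `inject_adapter_in_model` stands in for what `add_adapter` does under the hood); it is not taken from the training script.

```python
# Minimal sketch (not from the training script): PEFT freezes all non-adapter
# parameters when the adapter is injected, so norm layers must be re-enabled
# *after* injection, mirroring the patch above.
import torch
from peft import LoraConfig, inject_adapter_in_model


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)  # will be wrapped with LoRA
        self.norm = torch.nn.LayerNorm(8)  # the "norm layer" we want to train

    def forward(self, x):
        return self.norm(self.proj(x))


model = TinyBlock()

# Enabling the norm layer *before* injection has no lasting effect...
model.norm.weight.requires_grad_(True)
model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["proj"]), model)
print(model.norm.weight.requires_grad)  # False: injection froze non-LoRA params

# ...so, as in the patch, it is re-enabled after the adapter is added.
for name, param in model.named_parameters():
    if "norm" in name:
        param.requires_grad = True
print(model.norm.weight.requires_grad)  # True
```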