add missing optimizer.eval() call for schedulefree
feffy380 committed Apr 9, 2024
1 parent 14f215b commit 317d335
Showing 1 changed file with 26 additions and 17 deletions.
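
Background for this change: schedule-free optimizers (the schedulefree package's AdamWScheduleFree / SGDScheduleFree) hold the module's weights at a training point while stepping and must be switched with .eval() before those weights are read for evaluation or saving, then back with .train() before further optimizer steps; a checkpoint written without the .eval() switch captures the wrong set of weights. A minimal standalone sketch of that contract (illustrative only, not this repository's code):

    import torch
    import schedulefree

    model = torch.nn.Linear(4, 4)
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

    for epoch in range(2):
        optimizer.train()  # put the module's weights at the training point
        for _ in range(10):
            loss = model(torch.randn(8, 4)).pow(2).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        # Switch to the evaluation weights before reading or saving them;
        # without this, the checkpoint would hold the training-point weights.
        optimizer.eval()
        torch.save(model.state_dict(), "ckpt.pt")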
train_network.py (43 changes: 26 additions, 17 deletions)
@@ -468,6 +468,7 @@ def train(self, args):
 trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)

 optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
+schedulefree = "schedulefree" in args.optimizer_type.lower()

 # prepare the dataloader
 # note: persistent_workers cannot be used when the DataLoader's num_workers is 0
@@ -993,6 +994,14 @@ def remove_model(old_ckpt_name):
     accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
     os.remove(old_ckpt_file)

+def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids):
+    # Update the embeddings map (for saving)
+    with torch.no_grad():
+        for emb_name in embeddings_map.keys():
+            for i, (t_enc, emb_token_ids) in enumerate(zip(text_encoders, embedding_to_token_ids[emb_name])):
+                updated_embs = accelerator.unwrap_model(t_enc).get_input_embeddings().weight[emb_token_ids].data.detach().clone()
+                embeddings_map[emb_name][i] = updated_embs
+
 # For --sample_at_first
 self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

@@ -1003,7 +1012,7 @@ def remove_model(old_ckpt_name):
 accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
 current_epoch.value = epoch + 1

-if "schedulefree" in args.optimizer_type.lower():
+if schedulefree:
     optimizer.optimizer.train()

 if args.continue_inversion:
@@ -1147,10 +1156,9 @@ def remove_model(old_ckpt_name):
 lr_scheduler.step()
 optimizer.zero_grad(set_to_none=True)

-# Let's make sure we don't update any embedding weights besides the added pivots
+# normalize embeddings
 if args.continue_inversion:
     with torch.no_grad():
-        # normalize embeddings
         if args.clip_ti_decay:
             for t_enc, index_updates in zip(text_encoders, index_updates_list):
                 pre_norm = (
@@ -1168,20 +1176,6 @@ def remove_model(old_ckpt_name):
     pre_norm + lambda_ * (0.4 - pre_norm)
 )

-# # Let's make sure we don't update any embedding weights besides the newly added token
-# for t_enc, orig_embeds_params, index_no_updates in zip(
-#     text_encoders, orig_embeds_params_list, index_no_updates_list
-# ):
-#     input_embeddings_weight = accelerator.unwrap_model(t_enc).get_input_embeddings().weight
-#     input_embeddings_weight[index_no_updates] = orig_embeds_params[index_no_updates]
-
-# Update embeddings map (for saving)
-# TODO: this is not optimal, might need to be refactored
-for emb_name in embeddings_map.keys():
-    for i, (t_enc, emb_token_ids) in enumerate(zip(text_encoders, embedding_to_token_ids[emb_name])):
-        updated_embs = accelerator.unwrap_model(t_enc).get_input_embeddings().weight[emb_token_ids].data.detach().clone()
-        embeddings_map[emb_name][i] = updated_embs
-
 if args.scale_weight_norms:
     keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
         args.scale_weight_norms, accelerator.device
@@ -1201,6 +1195,11 @@ def remove_model(old_ckpt_name):
 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        if schedulefree:
+            optimizer.optimizer.eval()
+        if args.continue_inversion:
+            update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)
+
         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
         save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch, embeddings_map=embeddings_map)

@@ -1238,6 +1237,11 @@ def remove_model(old_ckpt_name):
 if args.save_every_n_epochs is not None:
     saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
     if is_main_process and saving:
+        if schedulefree:
+            optimizer.optimizer.eval()
+        if args.continue_inversion:
+            update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)
+
         ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
         save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1, embeddings_map=embeddings_map)

@@ -1265,6 +1269,11 @@ def remove_model(old_ckpt_name):
 train_util.save_state_on_train_end(args, accelerator)

 if is_main_process:
+    if schedulefree:
+        optimizer.optimizer.eval()
+    if args.continue_inversion:
+        update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)
+
     ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
     save_model(ckpt_name, network, global_step, num_train_epochs, embeddings_map=embeddings_map, force_sync_upload=True)

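Note on the calls above: at these save sites, optimizer is presumably the accelerate-prepared wrapper, so the schedule-free methods are reached through its .optimizer attribute, matching the existing optimizer.optimizer.train() at the top of each epoch. Calling .eval() before update_embeddings_map should also mean the pivotal-tuning embeddings captured for the checkpoint are read from the evaluation-point weights rather than the training-point ones.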
