From 211e13c7122348fbaa8f2fd883e575f31511b14c Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 11 Oct 2024 19:57:00 +0900 Subject: [PATCH] reduce memory usage, * and some cleanup --- scripts/model_mixer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 51a904e..95b27f0 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -4728,12 +4728,10 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): key = k[len(prefix):] if k in theta_0: base_dict[key] = theta_0[k] - if prefix == "conditioner.": - shared.sd_model.conditioner.load_state_dict(base_dict, strict=False) - elif prefix == "cond_stage_model.": - shared.sd_model.cond_stage_model.load_state_dict(base_dict, strict=False) - elif prefix == "text_encoders.": - shared.sd_model.text_encoders.load_state_dict(base_dict, strict=False) + + if prefix in ("conditioner.", "cond_stage_model.", "text_encoders."): + module = getattr(shared.sd_model, prefix[:-1]) + module.load_state_dict(base_dict, strict=False) print(" - \033[92mTextencoder(BASE) has been successfully updated\033[0m") shared.state.textinfo = "Update Textencoder..." @@ -4764,6 +4762,8 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): print(" - \033[93mReload full state_dict...\033[0m") shared.state.textinfo = "Reload full state_dict..." state_dict = shared.sd_model.state_dict().copy() + shared.sd_model.to("meta") # reduce memory usage + devices.torch_gc() else: state_dict = None