reduce memory usage,
 * and some cleanup
wkpark committed Oct 11, 2024
1 parent a073c7a commit 211e13c
Showing 1 changed file with 6 additions and 6 deletions.
scripts/model_mixer.py: 12 changes (6 additions, 6 deletions)
@@ -4728,12 +4728,10 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True):
                 key = k[len(prefix):]
                 if k in theta_0:
                     base_dict[key] = theta_0[k]
-            if prefix == "conditioner.":
-                shared.sd_model.conditioner.load_state_dict(base_dict, strict=False)
-            elif prefix == "cond_stage_model.":
-                shared.sd_model.cond_stage_model.load_state_dict(base_dict, strict=False)
-            elif prefix == "text_encoders.":
-                shared.sd_model.text_encoders.load_state_dict(base_dict, strict=False)
+
+            if prefix in ("conditioner.", "cond_stage_model.", "text_encoders."):
+                module = getattr(shared.sd_model, prefix[:-1])
+                module.load_state_dict(base_dict, strict=False)
             print(" - \033[92mTextencoder(BASE) has been successfully updated\033[0m")
             shared.state.textinfo = "Update Textencoder..."
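The cleanup in this hunk collapses the per-prefix if/elif chain into a single getattr lookup: each handled prefix ("conditioner.", "cond_stage_model.", "text_encoders.") is exactly the attribute name of the matching submodule on shared.sd_model plus a trailing dot, so prefix[:-1] resolves the target module directly. Below is a minimal, self-contained sketch of the same pattern; DummyEncoder, DummyModel, sd_model and theta_0 are illustrative stand-ins, not code from model_mixer.py.

import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conditioner = DummyEncoder()
        self.cond_stage_model = DummyEncoder()
        self.text_encoders = DummyEncoder()

sd_model = DummyModel()
theta_0 = {"conditioner.proj.weight": torch.zeros(4, 4)}

for prefix in ("conditioner.", "cond_stage_model.", "text_encoders."):
    # collect the keys that belong to this prefix, with the prefix stripped
    base_dict = {k[len(prefix):]: v for k, v in theta_0.items() if k.startswith(prefix)}
    if not base_dict:
        continue
    # the prefix minus its trailing dot is the submodule's attribute name
    module = getattr(sd_model, prefix[:-1])
    module.load_state_dict(base_dict, strict=False)

One membership test plus getattr keeps all three supported text-encoder layouts on a single code path, so handling another layout would only mean extending the tuple.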

@@ -4764,6 +4762,8 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True):
             print(" - \033[93mReload full state_dict...\033[0m")
             shared.state.textinfo = "Reload full state_dict..."
             state_dict = shared.sd_model.state_dict().copy()
+            shared.sd_model.to("meta") # reduce memory usage
+            devices.torch_gc()
         else:
             state_dict = None
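These two added lines are the memory-usage change named in the commit title: after the full state_dict has been captured, shared.sd_model is moved to the "meta" device so its parameters become shape-only placeholders without storage, and devices.torch_gc() (the WebUI helper that runs Python garbage collection and empties the CUDA cache) is called to return whatever memory is no longer referenced. Below is a minimal sketch of the same trick in plain PyTorch, under stated assumptions: net and saved are illustrative names, the copy is an explicit clone so the dropped storage really is unreferenced, and gc.collect() plus torch.cuda.empty_cache() stand in for devices.torch_gc().

import gc
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# keep an independent copy of the weights before discarding the live storage
saved = {k: v.clone() for k, v in net.state_dict().items()}

net.to("meta")                # parameters become storage-less meta tensors
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # roughly what devices.torch_gc() does on CUDA builds

print(next(net.parameters()).device)  # meta
print(saved["0.weight"].shape)        # torch.Size([1024, 1024]), still intact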
