Skip to content

Commit

Permalink
prepare override_uses correctly.
Browse files Browse the repository at this point in the history
 * check best_para
  • Loading branch information
wkpark committed Dec 27, 2023
1 parent 01f5284 commit 2d8f5d5
Showing 1 changed file with 35 additions and 21 deletions.
56 changes: 35 additions & 21 deletions sd_modelmixer/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,20 @@ def hyper_score(localargs):
if variable_models is not None and len(variable_models) == 0:
variable_models = None

# setup override uses
override_uses = uses.copy()

k = 0
for i in range(len(uses)):
if uses[i] is not True:
continue

# is this model in the variable_models? e.g.) chr(0+66) == B
if variable_models is not None and chr(i+66) not in variable_models:
k += 1
continue
if variable_models is not None:
if chr(i+66) not in variable_models:
override_uses[i] = False
k += 1
continue

name = f"model_{chr(i + 98)}"
weight = weights[k]
Expand All @@ -352,7 +357,7 @@ def hyper_score(localargs):
search_space[f"{name}.{_BLOCKS[j]}"] = [*np.round(np.linspace(lower, upper, 5), 8)]
k += 1

#print(" - search_space keys =", search_space.keys())
print(" - search_space keys =", search_space.keys())

# setup warm_start
if warm_start:
Expand Down Expand Up @@ -420,7 +425,7 @@ def hyper_score(localargs):
pass_through = {
"tunables": [*search_space.keys()],
"weights": weights,
"uses": uses,
"uses": override_uses,
"classifier": classifier,
"payload_path": payload_path,
"tally_type": tally_type,
Expand Down Expand Up @@ -470,37 +475,46 @@ def hyper_score(localargs):

# run main optimizer
shared.state.begin(job="modelmixer-auto-merger")
best_para = None
try:
hyper.run(search_time*60)
best_para = hyper.best_para(hyper_score)
except Exception as e:
print(f"Error: {e}")
finally:
shared.state.end()

best_para = hyper.best_para(hyper_score)
best_weights = para_to_weights(best_para, isxl)
print(" - Best weights para = ", best_weights)
if best_para is not None:
best_weights = para_to_weights(best_para, isxl)
print(" - Best weights para = ", best_weights, override_uses)

# setup override weights. will be replaced with mm_weights
shared.modelmixer_overrides = {"weights": best_weights, "uses": uses}
# setup override weights. will be replaced with mm_weights

# generate image with the optimized parameter
ret = None
ret = txt2img.txt2img(*txt2img_args)
if ret and ret[0] is not None:
gallery = ret[0] # gallery
shared.modelmixer_overrides = {"weights": best_weights, "uses": override_uses}

score = score_func(classifier, gallery[0], prompt)
print("Result score =", score)
# generate image with the optimized parameter
ret = None
ret = txt2img.txt2img(*txt2img_args)
if ret and ret[0] is not None:
gallery = ret[0] # gallery

score = score_func(classifier, gallery[0], prompt)
print("Result score =", score)

shared._memory_warm_start = hyper.search_data(hyper_score)
shared._memory_warm_hash = warm_hash
shared._memory_warm_start = hyper.search_data(hyper_score)
shared._memory_warm_hash = warm_hash
msg = "merge completed."
else:
shared._memory_warm_start = None
shared._memory_warm_hash = None
msg = "Failed to call hyper.run()"

# search data save
#timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%I-%M%p-%S")
#collector = SearchDataCollector(os.path.join(folder_path, f"{model_O}-{pass}-{timestamp}.csv"))
#collector.save(hyper.search_data(hyper_score, times=True))

delattr(shared, "modelmixer_overrides")
if hasattr(shared, "modelmixer_overrides"):
delattr(shared, "modelmixer_overrides")

return "merge completed."
return msg

0 comments on commit 2d8f5d5

Please sign in to comment.