From 2d8f5d520de3119d36be0883ec7f78d4a3400201 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 27 Dec 2023 20:19:55 +0900 Subject: [PATCH] prepare override_uses correctly. * check best_para --- sd_modelmixer/hyper.py | 56 ++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/sd_modelmixer/hyper.py b/sd_modelmixer/hyper.py index 5b64efe..70eca2f 100644 --- a/sd_modelmixer/hyper.py +++ b/sd_modelmixer/hyper.py @@ -327,15 +327,20 @@ def hyper_score(localargs): if variable_models is not None and len(variable_models) == 0: variable_models = None + # setup override uses + override_uses = uses.copy() + k = 0 for i in range(len(uses)): if uses[i] is not True: continue # is this model in the variable_models? e.g.) chr(0+66) == B - if variable_models is not None and chr(i+66) not in variable_models: - k += 1 - continue + if variable_models is not None: + if chr(i+66) not in variable_models: + override_uses[i] = False + k += 1 + continue name = f"model_{chr(i + 98)}" weight = weights[k] @@ -352,7 +357,7 @@ def hyper_score(localargs): search_space[f"{name}.{_BLOCKS[j]}"] = [*np.round(np.linspace(lower, upper, 5), 8)] k += 1 - #print(" - search_space keys =", search_space.keys()) + print(" - search_space keys =", search_space.keys()) # setup warm_start if warm_start: @@ -420,7 +425,7 @@ def hyper_score(localargs): pass_through = { "tunables": [*search_space.keys()], "weights": weights, - "uses": uses, + "uses": override_uses, "classifier": classifier, "payload_path": payload_path, "tally_type": tally_type, @@ -470,37 +475,46 @@ def hyper_score(localargs): # run main optimizer shared.state.begin(job="modelmixer-auto-merger") + best_para = None try: hyper.run(search_time*60) + best_para = hyper.best_para(hyper_score) except Exception as e: print(f"Error: {e}") finally: shared.state.end() - best_para = hyper.best_para(hyper_score) - best_weights = para_to_weights(best_para, isxl) - print(" - Best weights para = ", best_weights) + if best_para is not None: + best_weights = para_to_weights(best_para, isxl) + print(" - Best weights para = ", best_weights, override_uses) - # setup override weights. will be replaced with mm_weights - shared.modelmixer_overrides = {"weights": best_weights, "uses": uses} + # setup override weights. will be replaced with mm_weights - # generate image with the optimized parameter - ret = None - ret = txt2img.txt2img(*txt2img_args) - if ret and ret[0] is not None: - gallery = ret[0] # gallery + shared.modelmixer_overrides = {"weights": best_weights, "uses": override_uses} - score = score_func(classifier, gallery[0], prompt) - print("Result score =", score) + # generate image with the optimized parameter + ret = None + ret = txt2img.txt2img(*txt2img_args) + if ret and ret[0] is not None: + gallery = ret[0] # gallery + + score = score_func(classifier, gallery[0], prompt) + print("Result score =", score) - shared._memory_warm_start = hyper.search_data(hyper_score) - shared._memory_warm_hash = warm_hash + shared._memory_warm_start = hyper.search_data(hyper_score) + shared._memory_warm_hash = warm_hash + msg = "merge completed." + else: + shared._memory_warm_start = None + shared._memory_warm_hash = None + msg = "Failed to call hyper.run()" # search data save #timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%I-%M%p-%S") #collector = SearchDataCollector(os.path.join(folder_path, f"{model_O}-{pass}-{timestamp}.csv")) #collector.save(hyper.search_data(hyper_score, times=True)) - delattr(shared, "modelmixer_overrides") + if hasattr(shared, "modelmixer_overrides"): + delattr(shared, "modelmixer_overrides") - return "merge completed." + return msg