diff --git a/src/airunner/widgets/base_widget.py b/src/airunner/widgets/base_widget.py index 6fb6d2d95..afe081352 100644 --- a/src/airunner/widgets/base_widget.py +++ b/src/airunner/widgets/base_widget.py @@ -1,6 +1,7 @@ import os from PyQt6 import QtGui from PyQt6.QtWidgets import QWidget +from airunner.aihandler.logger import Logger from airunner.utils import get_main_window from airunner.mediator_mixin import MediatorMixin @@ -42,6 +43,7 @@ def __init__(self, *args, **kwargs): MediatorMixin.__init__(self) SettingsMixin.__init__(self) super().__init__(*args, **kwargs) + self.logger = Logger(prefix=self.__class__.__name__) if self.widget_class_: self.ui = self.widget_class_() diff --git a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py index 72d85fcd0..0917d0007 100644 --- a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py +++ b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py @@ -5,11 +5,10 @@ from PIL import Image, ImageGrab from PIL.ImageQt import ImageQt -from PyQt6.QtCore import Qt, QPoint, QRect +from PyQt6.QtCore import Qt, QPoint, QRect, pyqtSlot from PyQt6.QtGui import QBrush, QColor, QPen, QPixmap from PyQt6.QtWidgets import QGraphicsPixmapItem from PyQt6 import QtWidgets, QtCore -from PyQt6.QtCore import pyqtSlot from PyQt6.QtWidgets import QGraphicsItemGroup, QGraphicsItem from airunner.workers.image_data_worker import ImageDataWorker @@ -24,33 +23,15 @@ class CanvasResizeWorker(Worker): - queue_type = "get_last_item" - last_cell_count = (0, 0) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register("canvas_resize_signal", self) + self.last_cell_count = (0, 0) - def __init__(self, prefix): - super().__init__(prefix=prefix) - self.buffer = None - self.register("application_settings_changed_signal", self) - self.register("set_current_layer_signal", self) - self.register("update_canvas_signal", self) - @pyqtSlot(object) - def on_update_canvas_signal(self, _ignore): - self.update() - - @pyqtSlot(object) - def on_set_current_layer_signal(self, args): - self.set_current_layer(args) - - def set_current_layer(self, args): - index, current_layer_index = args - item = self.ui.container.layout().itemAt(current_layer_index) - if item: - item.widget().frame.setStyleSheet(self.css("layer_normal_style")) - if self.ui.container: - item = self.ui.container.layout().itemAt(index) - if item: - item.widget().frame.setStyleSheet(self.css("layer_highlight_style")) + def on_canvas_resize_signal(self, data): + self.logger.info("Adding to queue") + self.add_to_queue(data) def handle_message(self, data:dict): settings = data["settings"] @@ -198,10 +179,31 @@ def __init__(self, *args, **kwargs): self.register("image_generated_signal", self) self.register("load_image_from_path", self) self.register("canvas_handle_layer_click_signal", self) + self.register("update_canvas_signal", self) + self.register("set_current_layer_signal", self) + self.register("application_settings_changed_signal", self) self.register_service("canvas_drag_pos", self.canvas_drag_pos) self.register_service("canvas_current_active_image", self.canvas_current_active_image) + @pyqtSlot(object) + def on_set_current_layer_signal(self, args): + self.set_current_layer(args) + + def set_current_layer(self, args): + index, current_layer_index = args + item = self.ui.container.layout().itemAt(current_layer_index) + if item: + item.widget().frame.setStyleSheet(self.css("layer_normal_style")) + if self.ui.container: + item = self.ui.container.layout().itemAt(index) + if item: + item.widget().frame.setStyleSheet(self.css("layer_highlight_style")) + + @pyqtSlot(object) + def on_update_canvas_signal(self, _ignore): + self.update() + def canvas_drag_pos(self): return self.drag_pos @@ -241,8 +243,8 @@ def on_image_generated_signal(self, image_data: dict): def on_CanvasResizeWorker_response_signal(self, line_data: tuple): draw_grid = self.settings["grid_settings"]["show_grid"] + print("on_CanvasResizeWorker_response_signal", draw_grid) if not draw_grid: - print("not draw_grid") return line = self.scene.addLine(*line_data) self.line_group.addToGroup(line) @@ -315,15 +317,20 @@ def current_image(self): return Image.fromqpixmap(pixmap) def handle_resize_canvas(self): + self.do_resize_canvas() + + def do_resize_canvas(self): if not self.view: self.logger.warning("view not found") return - self.canvas_resize_worker.add_to_queue(dict( + data = dict( settings=self.settings, view_size=self.view.viewport().size(), scene=self.scene, line_group=self.line_group - )) + ) + #self.emit("canvas_resize_signal", data) + self.canvas_resize_worker.add_to_queue(data) def window_resized(self, event): self.handle_resize_canvas() @@ -407,6 +414,8 @@ def wheelEvent(self, event): def on_application_settings_changed_signal(self): do_draw = False + + self.do_resize_canvas() grid_settings = self.settings["grid_settings"] for k,v in grid_settings.items(): diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index 0e8f62083..c41972485 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -322,11 +322,9 @@ def do_generate(self, extra_options=None, seed=None, do_deterministic=False, ove # get the model from the database - print(model_data, self.generator_settings["model"]) name = model_data["name"] if "name" in model_data else self.generator_settings["model"] model = self.get_service("ai_model_by_name")(name) - print("MODEL:", model, name) # set the model data, first using model_data pulled from the override_data model_data = dict( name=model_data.get("name", model["name"]), diff --git a/src/airunner/widgets/model_manager/custom_widget.py b/src/airunner/widgets/model_manager/custom_widget.py index b6ada32a4..eaa3cd69c 100644 --- a/src/airunner/widgets/model_manager/custom_widget.py +++ b/src/airunner/widgets/model_manager/custom_widget.py @@ -1,9 +1,11 @@ import os from airunner.models.modeldata import ModelData +from airunner.service_locator import ServiceLocator from airunner.widgets.base_widget import BaseWidget from airunner.widgets.model_manager.model_widget import ModelWidget from airunner.widgets.model_manager.templates.custom_ui import Ui_custom_model_widget +from airunner.workers.worker import Worker from PyQt6 import QtWidgets from airunner.aihandler.logger import Logger @@ -11,24 +13,14 @@ logger = Logger(prefix="CustomModelWidget") -class CustomModelWidget(BaseWidget): - initialized = False - widget_class_ = Ui_custom_model_widget - model_widgets = [] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.show_items_in_scrollarea() - self.scan_for_models() - self.initialized = True - - def action_button_clicked_scan_for_models(self): +class ModelScannerWorker(Worker): + def handle_message(self, _message): self.scan_for_models() def scan_for_models(self): + self.logger.info("Scan for models") # look at model path and determine if we can import existing local models # first look at all files and folders inside of the model paths - base_model_path = self.path_settings["base_path"] txt2img_model_path = self.path_settings["txt2img_model_path"] depth2img_model_path = self.path_settings["depth2img_model_path"] pix2pix_model_path = self.path_settings["pix2pix_model_path"] @@ -38,6 +30,7 @@ def scan_for_models(self): llm_casuallm_model_path = self.path_settings["llm_casuallm_model_path"] llm_seq2seq_model_path = self.path_settings["llm_seq2seq_model_path"] diffusers_folders = ["scheduler", "text_encoder", "tokenizer", "unet", "vae"] + models = [] for key, model_path in { "txt2img": txt2img_model_path, "depth2img": depth2img_model_path, @@ -65,7 +58,7 @@ def scan_for_models(self): model.category = "stablediffusion" model.enabled = True model.pipeline_action = key - model.pipeline_class = self.get_service("get_pipeline_classname")( + model.pipeline_class = ServiceLocator.get("get_pipeline_classname")( model.pipeline_action, model.version, model.category ) @@ -85,25 +78,36 @@ def scan_for_models(self): model.name = entry.name if model: - self.save_model(model) + models.append(dict( + name=model.name, + path=model.path, + branch=model.branch, + version=model.version, + category=model.category, + pipeline_action=model.pipeline_action, + enabled=model.enabled, + is_default=False + )) - self.show_items_in_scrollarea() - self.update_generator_model_dropdown() - - def save_model(self, model): - self.emit("ai_model_save_or_update_signal", dict( - name=model.name, - path=model.path, - branch=model.branch, - version=model.version, - category=model.category, - pipeline_action=model.pipeline_action, - enabled=model.enabled, - is_default=False - )) - + self.emit("ai_models_save_or_update_signal", models) + + +class CustomModelWidget(BaseWidget): + initialized = False + widget_class_ = Ui_custom_model_widget + model_widgets = [] spacer = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.show_items_in_scrollarea() + self.initialized = True + self.model_scanner_worker = self.create_worker(ModelScannerWorker) + self.model_scanner_worker.add_to_queue("scan_for_models") + + def action_button_clicked_scan_for_models(self): + self.model_scanner_worker.add_to_queue("scan_for_models") + def show_items_in_scrollarea(self, search=None): if self.spacer: self.ui.scrollAreaWidgetContents.layout().removeItem(self.spacer) @@ -144,46 +148,6 @@ def show_items_in_scrollarea(self, search=None): self.spacer = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding) self.ui.scrollAreaWidgetContents.layout().addItem(self.spacer) - def models_changed(self, key, model, value): - model["enabled"] = True - self.emit("model_save_or_update_signal", model) - self.update_generator_model_dropdown() - - def handle_delete_model(self, model): - self.emit("ai_model_delete_signal", model) - self.show_items_in_scrollarea() - self.update_generator_model_dropdown() - - def update_generator_model_dropdown(self): - if self.initialized: - self.emit("refresh_available_models") - - def handle_edit_model(self, model, index): - print("edit button clicked", index) - self.toggle_model_form_frame(show=True) - - categories = self.get_service("ai_model_categories")() - self.ui.model_form.category.clear() - self.ui.model_form.category.addItems(categories) - self.ui.model_form.category.setCurrentText(model.category) - - actions = self.get_service("ai_model_pipeline_actions")() - self.ui.model_form.pipeline_action.clear() - self.ui.model_form.pipeline_action.addItems(actions) - self.ui.model_form.pipeline_action.setCurrentText(model.pipeline_action) - - self.ui.model_form.model_name.setText(model.name) - pipeline_class = self.get_service("get_pipeline_classname")( - model.pipeline_action, model.version, model.category) - self.ui.model_form.pipeline_class_line_edit.setText(pipeline_class) - self.ui.model_form.enabled.setChecked(True) - self.ui.model_form.path_line_edit.setText(model.path) - - versions = self.get_service("ai_model_versions")() - self.ui.model_form.versions.clear() - self.ui.model_form.versions.addItems(versions) - self.ui.model_form.versions.setCurrentText(model.version) - def mode_type_changed(self, val): print("mode_type_changed", val) diff --git a/src/airunner/widgets/model_manager/import_widget.py b/src/airunner/widgets/model_manager/import_widget.py index c2630f567..f7a17d322 100644 --- a/src/airunner/widgets/model_manager/import_widget.py +++ b/src/airunner/widgets/model_manager/import_widget.py @@ -132,7 +132,6 @@ def download_callback(self, current_size, total_size): def import_models(self): url = self.ui.import_url.text() - print("IMPORT MODELS") try: model_id = url.split("models/")[1] except IndexError: diff --git a/src/airunner/widgets/model_manager/model_manager_widget.py b/src/airunner/widgets/model_manager/model_manager_widget.py index 4e9a6dbc1..524cb75c6 100644 --- a/src/airunner/widgets/model_manager/model_manager_widget.py +++ b/src/airunner/widgets/model_manager/model_manager_widget.py @@ -98,3 +98,41 @@ def add_new_model(self): def tab_changed(self, val): print("tab_changed", val) + + def models_changed(self, key, model, value): + model["enabled"] = True + self.update_generator_model_dropdown() + + def handle_delete_model(self, model): + self.emit("ai_model_delete_signal", model) + self.show_items_in_scrollarea() + self.update_generator_model_dropdown() + + def update_generator_model_dropdown(self): + self.ui.generator_model_dropdown.clear() + self.ui.generator_model_dropdown.addItems(self.settings["ai_models"]) + + def handle_edit_model(self, model, index): + self.toggle_model_form_frame(show=True) + + categories = self.get_service("ai_model_categories")() + self.ui.model_form.category.clear() + self.ui.model_form.category.addItems(categories) + self.ui.model_form.category.setCurrentText(model.category) + + actions = self.get_service("ai_model_pipeline_actions")() + self.ui.model_form.pipeline_action.clear() + self.ui.model_form.pipeline_action.addItems(actions) + self.ui.model_form.pipeline_action.setCurrentText(model.pipeline_action) + + self.ui.model_form.model_name.setText(model.name) + pipeline_class = self.get_service("get_pipeline_classname")( + model.pipeline_action, model.version, model.category) + self.ui.model_form.pipeline_class_line_edit.setText(pipeline_class) + self.ui.model_form.enabled.setChecked(True) + self.ui.model_form.path_line_edit.setText(model.path) + + versions = self.get_service("ai_model_versions")() + self.ui.model_form.versions.clear() + self.ui.model_form.versions.addItems(versions) + self.ui.model_form.versions.setCurrentText(model.version) \ No newline at end of file diff --git a/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py b/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py index 792e0c491..45498b3ec 100644 --- a/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py +++ b/src/airunner/widgets/stablediffusion/stable_diffusion_settings_widget.py @@ -5,6 +5,10 @@ class StableDiffusionSettingsWidget(BaseWidget): widget_class_ = Ui_stable_diffusion_settings_widget + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register("models_changed_signal", self) + def showEvent(self, event): super().showEvent(event) steps = target_val = self.generator_settings["steps"] @@ -58,6 +62,7 @@ def handle_version_changed(self, val): self.load_models() def load_pipelines(self): + self.logger.info("load_pipelines") self.ui.pipeline.blockSignals(True) self.ui.pipeline.clear() pipeline_names = ["txt2img / img2img", "inpaint / outpaint", "depth2img", "pix2pix", "upscale", "superresolution", "txt2vid"] @@ -72,6 +77,7 @@ def load_pipelines(self): self.ui.pipeline.blockSignals(False) def load_versions(self): + self.logger.info("load_versions") self.ui.version.blockSignals(True) self.ui.version.clear() pipelines = self.get_service("get_pipelines")(category="stablediffusion") @@ -82,10 +88,17 @@ def load_versions(self): self.ui.version.setCurrentText(current_version) self.ui.version.blockSignals(False) + def on_models_changed_signal(self, _ignore): + self.load_pipelines() + self.load_versions() + self.load_models() + self.load_schedulers() + def clear_models(self): self.ui.model.clear() def load_models(self): + self.logger.info("load_models") self.ui.model.blockSignals(True) self.clear_models() @@ -110,6 +123,7 @@ def load_models(self): self.settings = settings def load_schedulers(self): + self.logger.info("load_schedulers") scheduler_names = [s["display_name"] for s in self.settings["schedulers"]] self.ui.scheduler.clear() self.ui.scheduler.addItems(scheduler_names) diff --git a/src/airunner/windows/main/ai_model_mixin.py b/src/airunner/windows/main/ai_model_mixin.py index 25a7755d7..368c08c72 100644 --- a/src/airunner/windows/main/ai_model_mixin.py +++ b/src/airunner/windows/main/ai_model_mixin.py @@ -1,7 +1,8 @@ -from airunner.service_locator import ServiceLocator - from PyQt6.QtCore import pyqtSlot +from airunner.service_locator import ServiceLocator +from airunner.data.bootstrap.model_bootstrap_data import model_bootstrap_data + class AIModelMixin: def __init__(self): @@ -23,9 +24,9 @@ def __init__(self): for service in services: ServiceLocator.register(service, getattr(self, service)) - self.register("ai_model_save_or_update_signal", self) + self.register("ai_models_save_or_update_signal", self) self.register("ai_model_delete_signal", self) - self.register("ai_model_create_signal", self) + self.register("ai_models_create_signal", self) def ai_model_get_by_filter(self, filter_dict): return [item for item in self.ai_models if all(item.get(k) == v for k, v in filter_dict.items())] @@ -36,6 +37,13 @@ def on_ai_model_create_signal(self, item): settings["ai_models"].append(item) self.settings = settings + @pyqtSlot(object) + def on_ai_models_create_signal(self, models): + settings = self.settings + settings["ai_models"] = models + self.settings = settings + self.emit("models_changed_signal", "models") + def ai_model_update(self, item): settings = self.settings for i, existing_item in enumerate(self.ai_models): @@ -63,13 +71,24 @@ def ai_model_get_disabled_default(self): return [model for model in self.ai_models if model["is_default"] == True and model["enabled"] == False] @pyqtSlot(object) - def on_ai_model_save_or_update_signal(self, model_data): - # find the model by name and path, if it exists, update it, otherwise insert it - existing_model = self.ai_model_get_by_filter({"name": model_data["name"], "path": model_data["path"]}) - if existing_model: - self.ai_model_update(model_data) - else: - self.emit("ai_model_create_signal", model_data) + def on_ai_models_save_or_update_signal(self, new_models): + settings = self.settings + default_models = model_bootstrap_data + existing_models = settings["ai_models"] + + # Convert list of models to dictionary with model name as key + model_dict = {model['name']: model for model in default_models} + + # Update the dictionary with existing models + model_dict.update({model['name']: model for model in existing_models}) + + # Update the dictionary with new models + model_dict.update({model['name']: model for model in new_models}) + + # Convert back to list + merged_models = list(model_dict.values()) + + self.emit("ai_models_create_signal", merged_models) def ai_model_paths(self, model_type=None, pipeline_action=None): models = self.ai_models diff --git a/src/airunner/windows/main/settings_mixin.py b/src/airunner/windows/main/settings_mixin.py index 26e8a6905..4721703f9 100644 --- a/src/airunner/windows/main/settings_mixin.py +++ b/src/airunner/windows/main/settings_mixin.py @@ -391,6 +391,7 @@ def set_settings(self, val): self.emit("application_settings_changed_signal") def on_reset_settings_signal(self): + self.logger.info("Resetting settings") self.application_settings.clear() self.application_settings.sync() self.set_settings(self.get_settings()) diff --git a/src/airunner/workers/canvas_resize_worker.py b/src/airunner/workers/canvas_resize_worker.py new file mode 100644 index 000000000..f3478008d --- /dev/null +++ b/src/airunner/workers/canvas_resize_worker.py @@ -0,0 +1,65 @@ +from PyQt6.QtCore import Qt, pyqtSlot +from PyQt6.QtGui import QBrush, QColor, QPen +from airunner.workers.worker import Worker + + +class CanvasResizeWorker(Worker): + queue_type = "get_last_item" + last_cell_count = (0, 0) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register("canvas_resize_signal", self) + + @pyqtSlot(object) + def on_canvas_resize_signal(self, data): + self.add_to_queue(data) + + def handle_message(self, data): + settings = data["settings"] + view_size = data["view_size"] + + cell_size = settings["grid_settings"]["cell_size"] + line_color = settings["grid_settings"]["line_color"] + line_width = settings["grid_settings"]["line_width"] + + width_cells = view_size.width() // cell_size + height_cells = view_size.height() // cell_size + + # Check if the number of cells has changed + if (width_cells, height_cells) == self.last_cell_count: + return + self.last_cell_count = (width_cells, height_cells) + + pen = QPen( + QBrush(QColor(line_color)), + line_width, + Qt.PenStyle.SolidLine + ) + + lines_data = [] + + # vertical lines + h = view_size.height() + abs(settings["canvas_settings"]["pos_y"]) % cell_size + y = 0 + x = settings["canvas_settings"]["pos_x"] % cell_size + for i in range(width_cells): + line_data = (x, y, x, h, pen) + lines_data.append(line_data) + x += cell_size + + # horizontal lines + w = view_size.width() + abs(settings["canvas_settings"]["pos_x"]) % cell_size + x = 0 + y = settings["canvas_settings"]["pos_y"] % cell_size + for i in range(height_cells): + line_data = (x, y, w, y, pen) + lines_data.append(line_data) + y += cell_size + + self.emit("canvas_clear_lines_signal") + + for line_data in lines_data: + self.emit("CanvasResizeWorker_response_signal", line_data) + + self.emit("canvas_do_draw_signal") \ No newline at end of file diff --git a/src/airunner/workers/worker.py b/src/airunner/workers/worker.py index 27b3aff2b..cb868ab96 100644 --- a/src/airunner/workers/worker.py +++ b/src/airunner/workers/worker.py @@ -10,9 +10,10 @@ class Worker(QObject, MediatorMixin, SettingsMixin): queue_type = "get_next_item" finished = pyqtSignal() + prefix = "Worker" - def __init__(self, prefix="Worker"): - self.prefix = prefix + def __init__(self, prefix=None): + self.prefix = prefix or self.__class__.__name__ super().__init__() MediatorMixin.__init__(self) SettingsMixin.__init__(self) @@ -42,10 +43,7 @@ def start(self): # if self.queue has more than one item, scrap everything other than the last item that # was added to the queue msg = self.get_item_from_queue() - if msg is not None: - self.handle_message(msg) - else: - self.logger.warning("No message") + self.handle_message(msg) except queue.Empty: msg = None if self.paused: