Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for canvas updates and model loading #389

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/airunner/widgets/base_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from PyQt6 import QtGui
from PyQt6.QtWidgets import QWidget
from airunner.aihandler.logger import Logger

from airunner.utils import get_main_window
from airunner.mediator_mixin import MediatorMixin
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self, *args, **kwargs):
MediatorMixin.__init__(self)
SettingsMixin.__init__(self)
super().__init__(*args, **kwargs)
self.logger = Logger(prefix=self.__class__.__name__)

if self.widget_class_:
self.ui = self.widget_class_()
Expand Down
69 changes: 39 additions & 30 deletions src/airunner/widgets/canvas_plus/canvas_plus_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

from PIL import Image, ImageGrab
from PIL.ImageQt import ImageQt
from PyQt6.QtCore import Qt, QPoint, QRect
from PyQt6.QtCore import Qt, QPoint, QRect, pyqtSlot
from PyQt6.QtGui import QBrush, QColor, QPen, QPixmap
from PyQt6.QtWidgets import QGraphicsPixmapItem
from PyQt6 import QtWidgets, QtCore
from PyQt6.QtCore import pyqtSlot
from PyQt6.QtWidgets import QGraphicsItemGroup, QGraphicsItem

from airunner.workers.image_data_worker import ImageDataWorker
Expand All @@ -24,33 +23,15 @@


class CanvasResizeWorker(Worker):
queue_type = "get_last_item"
last_cell_count = (0, 0)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register("canvas_resize_signal", self)
self.last_cell_count = (0, 0)

def __init__(self, prefix):
super().__init__(prefix=prefix)
self.buffer = None
self.register("application_settings_changed_signal", self)
self.register("set_current_layer_signal", self)
self.register("update_canvas_signal", self)

@pyqtSlot(object)
def on_update_canvas_signal(self, _ignore):
self.update()

@pyqtSlot(object)
def on_set_current_layer_signal(self, args):
self.set_current_layer(args)

def set_current_layer(self, args):
index, current_layer_index = args
item = self.ui.container.layout().itemAt(current_layer_index)
if item:
item.widget().frame.setStyleSheet(self.css("layer_normal_style"))
if self.ui.container:
item = self.ui.container.layout().itemAt(index)
if item:
item.widget().frame.setStyleSheet(self.css("layer_highlight_style"))
def on_canvas_resize_signal(self, data):
self.logger.info("Adding to queue")
self.add_to_queue(data)

def handle_message(self, data:dict):
settings = data["settings"]
Expand Down Expand Up @@ -198,10 +179,31 @@ def __init__(self, *args, **kwargs):
self.register("image_generated_signal", self)
self.register("load_image_from_path", self)
self.register("canvas_handle_layer_click_signal", self)
self.register("update_canvas_signal", self)
self.register("set_current_layer_signal", self)
self.register("application_settings_changed_signal", self)

self.register_service("canvas_drag_pos", self.canvas_drag_pos)
self.register_service("canvas_current_active_image", self.canvas_current_active_image)

@pyqtSlot(object)
def on_set_current_layer_signal(self, args):
self.set_current_layer(args)

def set_current_layer(self, args):
index, current_layer_index = args
item = self.ui.container.layout().itemAt(current_layer_index)
if item:
item.widget().frame.setStyleSheet(self.css("layer_normal_style"))
if self.ui.container:
item = self.ui.container.layout().itemAt(index)
if item:
item.widget().frame.setStyleSheet(self.css("layer_highlight_style"))

@pyqtSlot(object)
def on_update_canvas_signal(self, _ignore):
self.update()

def canvas_drag_pos(self):
return self.drag_pos

Expand Down Expand Up @@ -241,8 +243,8 @@ def on_image_generated_signal(self, image_data: dict):

def on_CanvasResizeWorker_response_signal(self, line_data: tuple):
draw_grid = self.settings["grid_settings"]["show_grid"]
print("on_CanvasResizeWorker_response_signal", draw_grid)
if not draw_grid:
print("not draw_grid")
return
line = self.scene.addLine(*line_data)
self.line_group.addToGroup(line)
Expand Down Expand Up @@ -315,15 +317,20 @@ def current_image(self):
return Image.fromqpixmap(pixmap)

def handle_resize_canvas(self):
self.do_resize_canvas()

def do_resize_canvas(self):
if not self.view:
self.logger.warning("view not found")
return
self.canvas_resize_worker.add_to_queue(dict(
data = dict(
settings=self.settings,
view_size=self.view.viewport().size(),
scene=self.scene,
line_group=self.line_group
))
)
#self.emit("canvas_resize_signal", data)
self.canvas_resize_worker.add_to_queue(data)

def window_resized(self, event):
self.handle_resize_canvas()
Expand Down Expand Up @@ -407,6 +414,8 @@ def wheelEvent(self, event):

def on_application_settings_changed_signal(self):
do_draw = False

self.do_resize_canvas()

grid_settings = self.settings["grid_settings"]
for k,v in grid_settings.items():
Expand Down
2 changes: 0 additions & 2 deletions src/airunner/widgets/generator_form/generator_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,9 @@ def do_generate(self, extra_options=None, seed=None, do_deterministic=False, ove


# get the model from the database
print(model_data, self.generator_settings["model"])
name = model_data["name"] if "name" in model_data else self.generator_settings["model"]
model = self.get_service("ai_model_by_name")(name)

print("MODEL:", model, name)
# set the model data, first using model_data pulled from the override_data
model_data = dict(
name=model_data.get("name", model["name"]),
Expand Down
104 changes: 34 additions & 70 deletions src/airunner/widgets/model_manager/custom_widget.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,26 @@
import os

from airunner.models.modeldata import ModelData
from airunner.service_locator import ServiceLocator
from airunner.widgets.base_widget import BaseWidget
from airunner.widgets.model_manager.model_widget import ModelWidget
from airunner.widgets.model_manager.templates.custom_ui import Ui_custom_model_widget
from airunner.workers.worker import Worker

from PyQt6 import QtWidgets
from airunner.aihandler.logger import Logger

logger = Logger(prefix="CustomModelWidget")


class CustomModelWidget(BaseWidget):
initialized = False
widget_class_ = Ui_custom_model_widget
model_widgets = []

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.show_items_in_scrollarea()
self.scan_for_models()
self.initialized = True

def action_button_clicked_scan_for_models(self):
class ModelScannerWorker(Worker):
def handle_message(self, _message):
self.scan_for_models()

def scan_for_models(self):
self.logger.info("Scan for models")
# look at model path and determine if we can import existing local models
# first look at all files and folders inside of the model paths
base_model_path = self.path_settings["base_path"]
txt2img_model_path = self.path_settings["txt2img_model_path"]
depth2img_model_path = self.path_settings["depth2img_model_path"]
pix2pix_model_path = self.path_settings["pix2pix_model_path"]
Expand All @@ -38,6 +30,7 @@ def scan_for_models(self):
llm_casuallm_model_path = self.path_settings["llm_casuallm_model_path"]
llm_seq2seq_model_path = self.path_settings["llm_seq2seq_model_path"]
diffusers_folders = ["scheduler", "text_encoder", "tokenizer", "unet", "vae"]
models = []
for key, model_path in {
"txt2img": txt2img_model_path,
"depth2img": depth2img_model_path,
Expand Down Expand Up @@ -65,7 +58,7 @@ def scan_for_models(self):
model.category = "stablediffusion"
model.enabled = True
model.pipeline_action = key
model.pipeline_class = self.get_service("get_pipeline_classname")(
model.pipeline_class = ServiceLocator.get("get_pipeline_classname")(
model.pipeline_action, model.version, model.category
)

Expand All @@ -85,25 +78,36 @@ def scan_for_models(self):
model.name = entry.name

if model:
self.save_model(model)
models.append(dict(
name=model.name,
path=model.path,
branch=model.branch,
version=model.version,
category=model.category,
pipeline_action=model.pipeline_action,
enabled=model.enabled,
is_default=False
))

self.show_items_in_scrollarea()
self.update_generator_model_dropdown()

def save_model(self, model):
self.emit("ai_model_save_or_update_signal", dict(
name=model.name,
path=model.path,
branch=model.branch,
version=model.version,
category=model.category,
pipeline_action=model.pipeline_action,
enabled=model.enabled,
is_default=False
))

self.emit("ai_models_save_or_update_signal", models)


class CustomModelWidget(BaseWidget):
initialized = False
widget_class_ = Ui_custom_model_widget
model_widgets = []
spacer = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.show_items_in_scrollarea()
self.initialized = True
self.model_scanner_worker = self.create_worker(ModelScannerWorker)
self.model_scanner_worker.add_to_queue("scan_for_models")

def action_button_clicked_scan_for_models(self):
self.model_scanner_worker.add_to_queue("scan_for_models")

def show_items_in_scrollarea(self, search=None):
if self.spacer:
self.ui.scrollAreaWidgetContents.layout().removeItem(self.spacer)
Expand Down Expand Up @@ -144,46 +148,6 @@ def show_items_in_scrollarea(self, search=None):
self.spacer = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding)
self.ui.scrollAreaWidgetContents.layout().addItem(self.spacer)

def models_changed(self, key, model, value):
model["enabled"] = True
self.emit("model_save_or_update_signal", model)
self.update_generator_model_dropdown()

def handle_delete_model(self, model):
self.emit("ai_model_delete_signal", model)
self.show_items_in_scrollarea()
self.update_generator_model_dropdown()

def update_generator_model_dropdown(self):
if self.initialized:
self.emit("refresh_available_models")

def handle_edit_model(self, model, index):
print("edit button clicked", index)
self.toggle_model_form_frame(show=True)

categories = self.get_service("ai_model_categories")()
self.ui.model_form.category.clear()
self.ui.model_form.category.addItems(categories)
self.ui.model_form.category.setCurrentText(model.category)

actions = self.get_service("ai_model_pipeline_actions")()
self.ui.model_form.pipeline_action.clear()
self.ui.model_form.pipeline_action.addItems(actions)
self.ui.model_form.pipeline_action.setCurrentText(model.pipeline_action)

self.ui.model_form.model_name.setText(model.name)
pipeline_class = self.get_service("get_pipeline_classname")(
model.pipeline_action, model.version, model.category)
self.ui.model_form.pipeline_class_line_edit.setText(pipeline_class)
self.ui.model_form.enabled.setChecked(True)
self.ui.model_form.path_line_edit.setText(model.path)

versions = self.get_service("ai_model_versions")()
self.ui.model_form.versions.clear()
self.ui.model_form.versions.addItems(versions)
self.ui.model_form.versions.setCurrentText(model.version)

def mode_type_changed(self, val):
print("mode_type_changed", val)

Expand Down
1 change: 0 additions & 1 deletion src/airunner/widgets/model_manager/import_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def download_callback(self, current_size, total_size):

def import_models(self):
url = self.ui.import_url.text()
print("IMPORT MODELS")
try:
model_id = url.split("models/")[1]
except IndexError:
Expand Down
38 changes: 38 additions & 0 deletions src/airunner/widgets/model_manager/model_manager_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,41 @@ def add_new_model(self):

def tab_changed(self, val):
print("tab_changed", val)

def models_changed(self, key, model, value):
model["enabled"] = True
self.update_generator_model_dropdown()

def handle_delete_model(self, model):
self.emit("ai_model_delete_signal", model)
self.show_items_in_scrollarea()
self.update_generator_model_dropdown()

def update_generator_model_dropdown(self):
self.ui.generator_model_dropdown.clear()
self.ui.generator_model_dropdown.addItems(self.settings["ai_models"])

def handle_edit_model(self, model, index):
self.toggle_model_form_frame(show=True)

categories = self.get_service("ai_model_categories")()
self.ui.model_form.category.clear()
self.ui.model_form.category.addItems(categories)
self.ui.model_form.category.setCurrentText(model.category)

actions = self.get_service("ai_model_pipeline_actions")()
self.ui.model_form.pipeline_action.clear()
self.ui.model_form.pipeline_action.addItems(actions)
self.ui.model_form.pipeline_action.setCurrentText(model.pipeline_action)

self.ui.model_form.model_name.setText(model.name)
pipeline_class = self.get_service("get_pipeline_classname")(
model.pipeline_action, model.version, model.category)
self.ui.model_form.pipeline_class_line_edit.setText(pipeline_class)
self.ui.model_form.enabled.setChecked(True)
self.ui.model_form.path_line_edit.setText(model.path)

versions = self.get_service("ai_model_versions")()
self.ui.model_form.versions.clear()
self.ui.model_form.versions.addItems(versions)
self.ui.model_form.versions.setCurrentText(model.version)
Loading