Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #381

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/airunner/widgets/base_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from PyQt6 import QtGui
from PyQt6.QtWidgets import QWidget
from PyQt6.QtCore import QThread

from airunner.utils import get_main_window

Expand All @@ -21,6 +22,19 @@ def is_dark(self):
def canvas(self):
return self.app.canvas

threads = []
def create_worker(self, worker_class_, response_signal_slot):
prefix = worker_class_.__name__
worker = worker_class_(prefix=prefix)
worker_thread = QThread()
worker.moveToThread(worker_thread)
worker.response_signal.connect(response_signal_slot)
worker.finished.connect(worker_thread.quit)
worker_thread.started.connect(worker.start)
worker_thread.start()
self.threads.append(worker_thread)
return worker

def add_to_grid(self, widget, row, column, row_span=1, column_span=1):
self.layout().addWidget(widget, row, column, row_span, column_span)

Expand Down
156 changes: 100 additions & 56 deletions src/airunner/widgets/canvas_plus/canvas_plus_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from PyQt6.QtWidgets import QGraphicsPixmapItem
from PyQt6 import QtWidgets, QtCore
from PyQt6.QtWidgets import QGraphicsItemGroup
from PyQt6.QtCore import pyqtSlot, QThread
from PyQt6.QtCore import pyqtSlot, pyqtSignal, QThread
from airunner.aihandler.enums import EngineResponseCode

from airunner.workers.image_data_worker import ImageDataWorker
Expand All @@ -20,6 +20,58 @@
from airunner.widgets.canvas_plus.draggables import DraggablePixmap, ActiveGridArea
from airunner.widgets.canvas_plus.custom_scene import CustomScene
from airunner.widgets.base_widget import BaseWidget
from airunner.workers.worker import Worker


class CanvasResizeWorker(Worker):
response_signal = pyqtSignal(tuple)
do_draw_signal = pyqtSignal()
clear_lines_signal = pyqtSignal()
queue_type = "get_last_item"

def handle_message(self, data):
self.draw_lines(data)

def draw_lines(self, data:dict):
self.clear_lines_signal.emit()
settings = data["settings"]
view_size = data["view_size"]

cell_size = settings["grid_settings"]["cell_size"]
line_color = settings["grid_settings"]["line_color"]
line_width = settings["grid_settings"]["line_width"]

width_cells = math.ceil(view_size.width() / cell_size)
height_cells = math.ceil(view_size.height() / cell_size)
pen = QPen(
QBrush(QColor(line_color)),
line_width,
Qt.PenStyle.SolidLine
)

# for line_data in lines:

# vertical lines
h = view_size.height() + abs(settings["canvas_settings"]["pos_y"]) % cell_size
y = 0
for i in range(width_cells):
x = i * cell_size + settings["canvas_settings"]["pos_x"] % cell_size
line_data = (x, y, x, h, pen)
# line = scene.addLine(x, y, x, h, pen)
# line_group.addToGroup(line)
self.response_signal.emit(line_data)

# # horizontal lines
w = view_size.width() + abs(settings["canvas_settings"]["pos_x"]) % cell_size
x = 0
for i in range(height_cells):
y = i * cell_size + settings["canvas_settings"]["pos_y"] % cell_size
line_data = (x, y, w, y, pen)
# line = scene.addLine(x, y, w, y, pen)
# line_group.addToGroup(line)
self.response_signal.emit(line_data)

self.do_draw_signal.emit()


class CanvasPlusWidget(BaseWidget):
Expand Down Expand Up @@ -112,18 +164,43 @@ def __init__(self, *args, **kwargs):
self._zoom_level = 1
self.canvas_container.resizeEvent = self.window_resized

self.image_data_worker = ImageDataWorker(prefix="ImageDataWorker")
self.image_data_worker_thread = QThread()
self.image_data_worker.moveToThread(self.image_data_worker_thread)
self.image_data_worker.response_signal.connect(self.image_data_worker_response_signal_slot)
self.image_data_worker.finished.connect(self.image_data_worker_thread.quit)
self.image_data_worker_thread.started.connect(self.image_data_worker.start)
self.image_data_worker_thread.start()
self.image_data_worker = self.create_worker(
ImageDataWorker,
self.image_data_worker_response_signal_slot
)
self.canvas_resize_worker = self.create_worker(
CanvasResizeWorker,
self.canvas_resize_worker_response_signal_slot
)
self.canvas_resize_worker.do_draw_signal.connect(self.do_draw_signal_slot)
self.canvas_resize_worker.clear_lines_signal.connect(self.clear_lines_slot)

@pyqtSlot()
def clear_lines_slot(self):
self.clear_lines()

@pyqtSlot()
def do_draw_signal_slot(self):
self.do_draw()

@pyqtSlot(dict)
def handle_image_data(self, image_data):
self.image_data_worker.add_to_queue(image_data)

@pyqtSlot(tuple)
def canvas_resize_worker_response_signal_slot(self, line_data):
# self.app.clear_status_message()
# self.app.stop_progress_bar()
# self.app.show_layers()
# self.app.set_status_label(f"Image resized")
# self.redraw_lines = True
# self.do_draw()
draw_grid = self.app.settings["grid_settings"]["show_grid"]
if not draw_grid:
return
line = self.scene.addLine(*line_data)
self.line_group.addToGroup(line)

@pyqtSlot()
def image_data_worker_response_signal_slot(self, message):
self.app.clear_status_message()
Expand Down Expand Up @@ -195,9 +272,18 @@ def current_image(self):
return None
return Image.fromqpixmap(pixmap)

def handle_resize_canvas(self):
if not self.view:
return
self.canvas_resize_worker.add_to_queue(dict(
settings=self.app.settings,
view_size=self.view.viewport().size(),
scene=self.scene,
line_group=self.line_group
))

def window_resized(self, event):
self.redraw_lines = True
self.do_draw()
self.handle_resize_canvas()

def toggle_grid(self, val):
self.do_draw()
Expand Down Expand Up @@ -325,7 +411,7 @@ def handle_mouse_event(self, original_mouse_event, event):

def resizeEvent(self, event):
if self.view:
self.do_draw()
self.handle_resize_canvas()
if self.scene:
self.scene.resize()

Expand Down Expand Up @@ -384,35 +470,6 @@ def clear_lines(self):
self.scene.removeItem(self.line_group)
self.line_group = QGraphicsItemGroup()

def draw_lines(self):
width_cells = math.ceil(self.view_size.width() / self.cell_size)
height_cells = math.ceil(self.view_size.height() / self.cell_size)

pen = QPen(
QBrush(QColor(self.line_color)),
self.line_width,
Qt.PenStyle.SolidLine
)

# vertical lines
h = self.view_size.height() + abs(self.app.settings["canvas_settings"]["pos_y"]) % self.cell_size
y = 0
for i in range(width_cells):
x = i * self.cell_size + self.app.settings["canvas_settings"]["pos_x"] % self.cell_size
line = self.scene.addLine(x, y, x, h, pen)
self.line_group.addToGroup(line)

# # horizontal lines
w = self.view_size.width() + abs(self.app.settings["canvas_settings"]["pos_x"]) % self.cell_size
x = 0
for i in range(height_cells):
y = i * self.cell_size + self.app.settings["canvas_settings"]["pos_y"] % self.cell_size
line = self.scene.addLine(x, y, w, y, pen)
self.line_group.addToGroup(line)

# Add the group to the scene
self.scene.addItem(self.line_group)

def draw_active_grid_area_container(self):
"""
Draw a rectangle around the active grid area of
Expand Down Expand Up @@ -441,28 +498,16 @@ def do_draw(self):
self.view_size = self.view.viewport().size()
self.set_scene_rect()
self.draw_grid()
self.draw_layers()
self.draw_active_grid_area_container()
#self.draw_layers()
#self.draw_active_grid_area_container()
self.ui.canvas_position.setText(
f"X {-self.app.settings['canvas_settings']['pos_x']: 05d} Y {self.app.settings['canvas_settings']['pos_y']: 05d}"
)
self.scene.update()
self.drawing = False

def draw_grid(self):
draw_grid = self.app.settings["grid_settings"]["show_grid"]

if draw_grid and self.redraw_lines:
self.clear_lines()
self.has_lines = False
self.redraw_lines = False

if draw_grid and not self.has_lines:
self.draw_lines()
self.has_lines = True
elif not draw_grid and self.has_lines:
self.clear_lines()
self.has_lines = False
self.scene.addItem(self.line_group)

def handle_image_data(self, data):
options = data["data"]["options"]
Expand Down Expand Up @@ -534,7 +579,6 @@ def handle_outpaint(self, outpaint_box_rect, outpainted_image, action=None):
return new_image, image_root_point, image_pivot_point

def load_image_from_path(self, image_path):
print("canvas_plus_widget load_image_from_path", image_path)
if image_path is None or image_path == "":
return
image = Image.open(image_path)
Expand Down
8 changes: 0 additions & 8 deletions src/airunner/widgets/canvas_plus/standard_image_widget.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QLabel
from PyQt6.QtGui import QPixmap
from PyQt6.QtWidgets import QVBoxLayout
from PyQt6.QtWidgets import QDialog
from PyQt6.QtGui import QImage

from PIL import Image
from PyQt6.QtCore import pyqtSlot, QThread

from airunner.widgets.canvas_plus.templates.standard_image_widget_ui import Ui_standard_image_widget
from airunner.utils import load_metadata_from_image, prepare_metadata
from airunner.widgets.slider.slider_widget import SliderWidget
from airunner.aihandler.logger import Logger
from airunner.widgets.base_widget import BaseWidget
from airunner.workers.worker import Worker


class ImageDataWorker(Worker):
def handle_message(self, message):
pass


class StandardImageWidget(BaseWidget):
Expand Down
2 changes: 1 addition & 1 deletion src/airunner/widgets/llm/llm_settings_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def set_dtype_description(self, dtype):
def update_model_version_combobox(self):
self.ui.model_version.blockSignals(True)
self.ui.model_version.clear()
ai_model_paths = self.app.ai_model_paths()
ai_model_paths = self.app.ai_model_paths(model_type="llm", pipeline_action=self.ui.model.currentText())
self.ui.model_version.addItems(ai_model_paths)
self.ui.model_version.blockSignals(False)

Expand Down
10 changes: 8 additions & 2 deletions src/airunner/windows/main/ai_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,14 @@ def ai_model_save_or_update(self, model_data):
else:
self.ai_model_create(model_data)

def ai_model_paths(self):
return [model["path"] for model in self.settings["ai_models"]]
def ai_model_paths(self, model_type=None, pipeline_action=None):
models = self.settings["ai_models"]
if model_type:
models = [model for model in models if "model_type" in model and model["model_type"] == model_type]
if pipeline_action:
models = [model for model in models if model["pipeline_action"] == pipeline_action]

return [model["path"] for model in models]

def ai_model_categories(self):
return [model["category"] for model in self.settings["ai_models"]]
Expand Down
32 changes: 28 additions & 4 deletions src/airunner/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class Worker(QtCore.QObject):
response_signal = QtCore.pyqtSignal(dict)
finished = QtCore.pyqtSignal()
queue_type = "get_next_item"

def __init__(self, prefix="Worker"):
super().__init__()
Expand All @@ -23,17 +24,40 @@ def start(self):
self.running = True
while self.running:
try:
index = self.queue.get(timeout=0.1)
msg = self.items.pop(index, None)
# if self.queue has more than one item, scrap everything other than the last item that
# was added to the queue
msg = self.get_item_from_queue()
if msg:
self.handle_message(msg)
except queue.Empty:
msg = None
if msg is not None:
self.handle_message(msg)
if self.paused:
self.logger.info("Paused")
while self.paused:
QtCore.QThread.msleep(100)
self.logger.info("Resumed")
QtCore.QThread.msleep(100)

def get_item_from_queue(self):
if self.queue_type == "get_last_item":
msg = self.get_last_item()
else:
msg = self.get_next_item()
return msg

def get_last_item(self):
msg = None
while not self.queue.empty():
index = self.queue.get(timeout=0.1)
if index is not None:
msg = self.items.pop(index, None)
return msg

def get_next_item(self):
index = self.queue.get(timeout=0.1)
msg = self.items.pop(index, None)
return msg


def pause(self):
self.paused = True
Expand Down