Skip to content

Commit

Permalink
Merge pull request #942 from Capsize-Games/devastator
Browse files Browse the repository at this point in the history
Devastator
  • Loading branch information
w4ffl35 authored Oct 13, 2024
2 parents e740907 + 95e4bb4 commit ceb767e
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 72 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="airunner",
version="3.0.20",
version="3.0.21",
author="Capsize LLC",
description="A Stable Diffusion GUI",
long_description=open("README.md", "r", encoding="utf-8").read(),
Expand Down
3 changes: 3 additions & 0 deletions src/airunner/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ class SignalCode(Enum):
HISTORY_UPDATED = enum.auto()
CANVAS_IMAGE_UPDATED_SIGNAL = enum.auto()

UNLOAD_NON_SD_MODELS = enum.auto()
LOAD_NON_SD_MODELS = enum.auto()

class EngineResponseCode(Enum):
STATUS = 100
ERROR = 200
Expand Down
5 changes: 5 additions & 0 deletions src/airunner/handlers/llm/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ def run(
is_first_message=True,
is_end_of_message=True,
name=self.botname,
action=action
)
)

Expand Down Expand Up @@ -614,6 +615,7 @@ def run_with_thread(
is_first_message=is_first_message,
is_end_of_message=False,
name=self.botname,
action=LLMActionType.CHAT
)
)
is_first_message = False
Expand Down Expand Up @@ -668,6 +670,7 @@ def run_with_thread(
is_first_message=is_first_message,
is_end_of_message=is_end_of_message,
name=self.botname,
action=action
)
)
else:
Expand All @@ -678,6 +681,7 @@ def run_with_thread(
is_first_message=is_first_message,
is_end_of_message=is_end_of_message,
name=self.botname,
action=action
)
)
is_first_message = False
Expand All @@ -702,6 +706,7 @@ def run_with_thread(
is_first_message=is_first_message,
is_end_of_message=is_end_of_message,
name=self.botname,
action=action
)
)
is_first_message = False
Expand Down
5 changes: 4 additions & 1 deletion src/airunner/handlers/stablediffusion/sd_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,10 @@ def _load_embeddings(self):
self.logger.error("Pipe is None, unable to load embeddings")
return
self.logger.debug("Loading embeddings")
self._pipe.unload_textual_inversion()
try:
self._pipe.unload_textual_inversion()
except RuntimeError as e:
self.logger.error(f"Failed to unload embeddings: {e}")
session = self.db_handler.get_db_session()
embeddings = session.query(Embedding).filter_by(
version=self.generator_settings_cached.version
Expand Down
59 changes: 16 additions & 43 deletions src/airunner/widgets/generator_form/generator_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from airunner.data.models.settings_models import ShortcutKeys
from airunner.enums import SignalCode, GeneratorSection, ImageCategory, ImagePreset, StableDiffusionVersion, \
ModelStatus, ModelType
ModelStatus, ModelType, LLMActionType
from airunner.mediator_mixin import MediatorMixin
from airunner.settings import PHOTO_REALISTIC_NEGATIVE_PROMPT, ILLUSTRATION_NEGATIVE_PROMPT
from airunner.utils.random_seed import random_seed
Expand Down Expand Up @@ -185,15 +185,15 @@ def on_llm_image_prompt_generated_signal(self, data):
message="Your image is generating...",
is_first_message=True,
is_end_of_message=True,
name=self.chatbot.name
name=self.chatbot.name,
action=LLMActionType.GENERATE_IMAGE
)
)

# Unload the LLM
if self.application_settings.llm_enabled:
self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
callback=self.unload_llm_callback
))
# Unload non-Stable Diffusion models
self.emit_signal(SignalCode.UNLOAD_NON_SD_MODELS, dict(
callback=self.unload_llm_callback
))

# Set the prompts in the generator form UI
data = self.extract_json_from_message(data["message"])
Expand Down Expand Up @@ -238,45 +238,18 @@ def finalize_image_generated_by_llm(self, data):
message="Your image has been generated",
is_first_message=True,
is_end_of_message=True,
name=self.chatbot.name
name=self.chatbot.name,
action=LLMActionType.GENERATE_IMAGE
)

# If SD is enabled, emit a signal to unload SD.
if self.application_settings.sd_enabled:
# If LLM is disabled, emit a signal to load it.
if not self.application_settings.llm_enabled:
self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
callback=lambda d: self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
callback=lambda d: self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
image_generated_message
)
))
))
else:
self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
callback=lambda d: self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
image_generated_message
)
))
else:
# If SD is disabled and LLM is disabled, emit a signal to load LLM
# with a callback to add the image generated message to the conversation.
if not self.application_settings.llm_enabled:
self.emit_signal(SignalCode.TOGGLE_LLM_SIGNAL, dict(
callback=lambda d: self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
image_generated_message
)
))
else:
# If SD is disabled and LLM is enabled, emit a signal to add
# the image generated message to the conversation.
self.emit_signal(
self.emit_signal(SignalCode.TOGGLE_SD_SIGNAL, dict(
callback=lambda d: self.emit_signal(SignalCode.LOAD_NON_SD_MODELS, dict(
callback=lambda d: self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
image_generated_message
)
))
))
##########################################################################
# End LLM Generated Image handlers
##########################################################################
Expand Down Expand Up @@ -382,10 +355,10 @@ def extract_json_from_message(self, message):
json_dict = json.loads(json_block)
return json_dict
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
self.logger.error(f"Error decoding JSON block: {e}")
return {}
else:
print("No JSON block found in the message.")
self.logger.error("No JSON block found in message")
return {}

def get_memory_options(self):
Expand Down
18 changes: 18 additions & 0 deletions src/airunner/windows/main/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ def register_signals(self):
(SignalCode.TOGGLE_TTS_SIGNAL, self.on_toggle_tts),
(SignalCode.TOGGLE_SD_SIGNAL, self.on_toggle_sd),
(SignalCode.TOGGLE_LLM_SIGNAL, self.on_toggle_llm),
(SignalCode.UNLOAD_NON_SD_MODELS, self.on_unload_non_sd_models),
(SignalCode.LOAD_NON_SD_MODELS, self.on_load_non_sd_models),
(SignalCode.APPLICATION_RESET_SETTINGS_SIGNAL, self.action_reset_settings),
(SignalCode.APPLICATION_RESET_PATHS_SIGNAL, self.on_reset_paths_signal),
(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal),
Expand Down Expand Up @@ -682,6 +684,22 @@ def on_toggle_fullscreen_signal(self):
else:
self.showFullScreen()

def on_unload_non_sd_models(self, data: dict = None):
    """Unload every non-Stable-Diffusion model (LLM, TTS, STT).

    Handler for ``SignalCode.UNLOAD_NON_SD_MODELS``; frees the memory held
    by the LLM, TTS and STT workers (emitted e.g. before an image
    generation). If ``data`` carries a ``callback``, it is invoked with
    ``data`` after all three workers have been told to unload.
    """
    # Guard: the signal may be emitted with no payload at all (data is
    # None by default), which would make data.get(...) raise.
    data = data or {}
    self._llm_generate_worker.on_llm_on_unload_signal()
    self._tts_generator_worker.unload()
    self._stt_audio_processor_worker.unload()
    callback = data.get("callback")
    if callback:
        callback(data)

def on_load_non_sd_models(self, data: dict = None):
    """Reload the non-Stable-Diffusion models (LLM, TTS, STT).

    Handler for ``SignalCode.LOAD_NON_SD_MODELS``, the counterpart of
    ``on_unload_non_sd_models``. If ``data`` carries a ``callback``, it
    is invoked with ``data`` after all three workers have been told to
    load.
    """
    # Guard: the signal may be emitted with no payload (data is None by
    # default), which would make data.get(...) raise.
    data = data or {}
    self._llm_generate_worker.load()
    self._tts_generator_worker.load()
    self._stt_audio_processor_worker.load()
    callback = data.get("callback")
    if callback:
        callback(data)

def on_toggle_llm(self, data:dict=None, val=None):
if val is None:
val = not self.application_settings.llm_enabled
Expand Down
4 changes: 3 additions & 1 deletion src/airunner/workers/agent_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import traceback
import torch

from airunner.enums import SignalCode
from airunner.enums import SignalCode, LLMActionType
from airunner.workers.worker import Worker


Expand Down Expand Up @@ -44,6 +44,7 @@ def handle_message(self, message):
is_first_message=True,
is_end_of_message=True,
name=message["botname"],
action=LLMActionType.CHAT
)
)
else:
Expand All @@ -58,6 +59,7 @@ def handle_message(self, message):
is_first_message=True,
is_end_of_message=True,
name=message["botname"],
action=LLMActionType.CHAT
)
)

Expand Down
45 changes: 31 additions & 14 deletions src/airunner/workers/audio_capture_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from PySide6.QtCore import QThread

from airunner.enums import SignalCode
from airunner.enums import SignalCode, ModelStatus
from airunner.settings import SLEEP_TIME_IN_MS
from airunner.workers.worker import Worker

Expand All @@ -21,6 +21,7 @@ def __init__(self):
(SignalCode.AUDIO_CAPTURE_WORKER_RESPONSE_SIGNAL, self.on_AudioCaptureWorker_response_signal),
(SignalCode.STT_START_CAPTURE_SIGNAL, self.on_stt_start_capture_signal),
(SignalCode.STT_STOP_CAPTURE_SIGNAL, self.on_stt_stop_capture_signal),
(SignalCode.MODEL_STATUS_CHANGED_SIGNAL, self.on_model_status_changed_signal),
))
self.listening: bool = False
self.voice_input_start_time: time.time = None
Expand All @@ -29,9 +30,6 @@ def __init__(self):
self.stream = None
self.running = False
self._audio_process_queue = queue.Queue()
#self._capture_thread = None
if self.application_settings.stt_enabled:
self._start_listening()

def on_AudioCaptureWorker_response_signal(self, message: dict):
item: np.ndarray = message["item"]
Expand All @@ -46,6 +44,14 @@ def on_stt_stop_capture_signal(self):
if self.listening:
self._stop_listening()

def on_model_status_changed_signal(self, message: dict):
    """Start or stop microphone capture as the STT model comes and goes.

    Capture begins once the "stt" model reports LOADED and stops when it
    reports UNLOADED or FAILED; status changes of other models are ignored.
    """
    model = message["model"]
    status = message["status"]
    if model != "stt":
        return
    if status is ModelStatus.LOADED:
        self._start_listening()
    elif status in (ModelStatus.UNLOADED, ModelStatus.FAILED):
        self._stop_listening()

def start(self):
self.logger.debug("Starting audio capture worker")
self.running = True
Expand All @@ -64,6 +70,10 @@ def start(self):
self.logger.error(f"PortAudioError: {e}")
QThread.msleep(SLEEP_TIME_IN_MS)
continue
except Exception as e:
self.logger.error(e)
QThread.msleep(SLEEP_TIME_IN_MS)
continue
if np.max(np.abs(chunk)) > volume_input_threshold: # check if chunk is not silence
self.logger.debug("Heard voice")
is_receiving_input = True
Expand Down Expand Up @@ -92,21 +102,19 @@ def start(self):

def _start_listening(self):
    """(Re)start microphone capture.

    Tears down any stale stream before opening a fresh one via
    ``_initialize_stream`` and marks the worker as listening.
    """
    self.logger.debug("Start listening")
    # Drop a previous stream first so we never hold two open devices.
    if self.stream is not None:
        self._end_stream()
    self._initialize_stream()
    self.listening = True

def _stop_listening(self):
self.logger.debug("Stop listening")
self.listening = False
self.running = False
self._end_stream()
# self._capture_thread.join()

def _end_stream(self):
try:
self.stream.stop()
except Exception as e:
Expand All @@ -115,4 +123,13 @@ def _stop_listening(self):
self.stream.close()
except Exception as e:
self.logger.error(e)
# self._capture_thread.join()
self.stream = None

def _initialize_stream(self):
    """Create a fresh sounddevice input stream and start it.

    Sample rate and channel count come from the STT settings. Any error
    raised by ``stream.start()`` is logged and swallowed so the capture
    worker stays alive.
    """
    # NOTE(review): assumes any previous stream has already been closed
    # (callers appear to run _end_stream() first) — confirm.
    fs = self.stt_settings.fs
    channels = self.stt_settings.channels
    self.stream = sd.InputStream(samplerate=fs, channels=channels)
    try:
        self.stream.start()
    except Exception as e:
        # Best-effort: a failed device start is logged, not raised.
        self.logger.error(e)
21 changes: 17 additions & 4 deletions src/airunner/workers/audio_processor_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ def __init__(self):
))

def start_worker_thread(self):
    """Create the Whisper STT handler, loading it right away when STT is enabled."""
    self._initialize_stt_handler()
    if self.application_settings.stt_enabled:
        self._stt.load()

def _initialize_stt_handler(self):
    """Lazily construct the Whisper handler; no-op when one already exists."""
    if self._stt is not None:
        return
    self._stt = WhisperHandler()

def on_stt_load_signal(self):
if self._stt:
threading.Thread(target=self._stt_load).start()
Expand All @@ -34,13 +38,22 @@ def on_stt_unload_signal(self):
if self._stt:
threading.Thread(target=self._stt_unload).start()

def unload(self):
    """Tear down the STT handler; called from MainWindow.on_unload_non_sd_models."""
    self._stt_unload()

def load(self):
    """Ensure the Whisper handler exists, then load the STT model.

    Called from MainWindow.on_load_non_sd_models.
    """
    self._initialize_stt_handler()
    self._stt_load()

def _stt_load(self):
    """Load the Whisper model, then signal audio capture to start.

    No-op when no handler has been constructed yet.
    """
    if not self._stt:
        return
    self._stt.load()
    self.emit_signal(SignalCode.STT_START_CAPTURE_SIGNAL)

def _stt_unload(self):
    """Signal audio capture to stop first, then unload the Whisper model if present."""
    self.emit_signal(SignalCode.STT_STOP_CAPTURE_SIGNAL)
    handler = self._stt
    if handler:
        handler.unload()

def on_stt_process_audio_signal(self, message):
    """Queue an incoming audio message for processing on the worker thread."""
    self.add_to_queue(message)
Expand Down
8 changes: 6 additions & 2 deletions src/airunner/workers/llm_generate_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def on_quit_application_signal(self):
def on_llm_request_worker_response_signal(self, message: dict):
self.add_to_queue(message)

def on_llm_on_unload_signal(self, data):
def on_llm_on_unload_signal(self, data=None):
data = data or {}
self.logger.debug("Unloading LLM")
self.llm.unload()
callback = data.get("callback", None)
Expand Down Expand Up @@ -80,7 +81,10 @@ def _load_llm_thread(self, data=None):
self._llm_thread = threading.Thread(target=self._load_llm, args=(data,))
self._llm_thread.start()

def _load_llm(self, data):
def load(self):
    """Synchronous load entry point (used by MainWindow.on_load_non_sd_models).

    Delegates to _load_llm with no payload.
    """
    self._load_llm()

def _load_llm(self, data=None):
data = data or {}
if self.llm is None:
self.llm = CausalLMTransformerBaseHandler(agent_options=self.agent_options)
Expand Down
Loading

0 comments on commit ceb767e

Please sign in to comment.