Fixes for speech to text #386

Merged: 8 commits, Jan 19, 2024
311 changes: 112 additions & 199 deletions src/airunner/aihandler/engine.py

Large diffs are not rendered by default.

23 changes: 0 additions & 23 deletions src/airunner/aihandler/enums.py
@@ -5,29 +5,6 @@ class FilterType(Enum):
     PIXEL_ART = "pixelart"
 
 
-class EngineResponseCode(Enum):
-    STATUS = 100
-    ERROR = 200
-    WARNING = 300
-    PROGRESS = 400
-    IMAGE_GENERATED = 500
-    CONTROLNET_IMAGE_GENERATED = 501
-    MASK_IMAGE_GENERATED = 502
-    EMBEDDING_LOAD_FAILED = 600
-    TEXT_GENERATED = 700
-    TEXT_STREAMED = 701
-    CAPTION_GENERATED = 800
-    ADD_TO_CONVERSATION = 900
-    CLEAR_MEMORY = 1000
-    NSFW_CONTENT_DETECTED = 1100
-
-
-class EngineRequestCode(Enum):
-    GENERATE_IMAGE = 100
-    GENERATE_TEXT = 200
-    GENERATE_CAPTION = 300
-
-
 class Scheduler(Enum):
     EULER_ANCESTRAL = "Euler a"
     EULER = "Euler"
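Reviewer note: these two enums backed the old numeric send_message() protocol, and most of their call sites in this PR move to named mediator signals instead. One loose end is visible below: the new generate() in llm.py still yields code=EngineResponseCode.TEXT_STREAMED, so the enum presumably lives on somewhere in the engine.py rewrite that is not rendered above. A rough before/after for a typical call site, with a purely illustrative Handler stub (the signal name is taken from the embedding_mixin diff further down):

class Handler:
    def emit(self, signal_name, payload=None):
        # Stand-in for the project's mediator; prints instead of dispatching.
        print(signal_name, payload)

# Before: self.send_message(payload, EngineResponseCode.EMBEDDING_LOAD_FAILED)
# After: the same event is a named signal, so no shared code enum is needed.
Handler().emit("embedding_load_failed_signal", {"embedding_name": "my-token"})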
86 changes: 47 additions & 39 deletions src/airunner/aihandler/llm.py
@@ -11,46 +11,45 @@
 from transformers import TextIteratorStreamer
 
 from PyQt6.QtCore import QObject
-from PyQt6.QtCore import pyqtSignal, pyqtSlot, QThread
 
 from airunner.aihandler.enums import EngineResponseCode
 from airunner.aihandler.logger import Logger
 from airunner.workers.worker import Worker
+from airunner.mediator_mixin import MediatorMixin
 
-class GenerateWorker(Worker):
-    def __init__(self, prefix):
+
+class LLMGenerateWorker(Worker):
+    def __init__(self, prefix="LLMGenerateWorker"):
         self.llm = LLM()
         super().__init__(prefix=prefix)
+        self.register("clear_history", self)
 
     def handle_message(self, message):
         for response in self.llm.do_generate(message):
-            self.response_signal.emit(response)
+            self.emit("llm_text_streamed_signal", response)
+
+    def on_clear_history(self):
+        self.llm.clear_history()
+
+
+class LLMRequestWorker(Worker):
+    def __init__(self, prefix="LLMRequestWorker"):
+        super().__init__(prefix=prefix)
+
+    def handle_message(self, message):
+        super().handle_message(message)
 
-class LLMController(QObject):
-    logger = Logger(prefix="LLMController")
-    response_signal = pyqtSignal(dict)
+
+class LLMController(QObject, MediatorMixin):
 
     def __init__(self, *args, **kwargs):
-        self.engine = kwargs.pop("engine", None)
-        self.app = self.engine.app
+        MediatorMixin.__init__(self)
+        self.engine = kwargs.pop("engine")
         super().__init__(*args, **kwargs)
+        self.logger = Logger(prefix="LLMController")
 
-        self.request_worker = Worker(prefix="LLM Request Worker")
-        self.request_worker_thread = QThread()
-        self.request_worker.moveToThread(self.request_worker_thread)
-        self.request_worker.response_signal.connect(self.request_worker_response_signal_slot)
-        self.request_worker.finished.connect(self.request_worker_thread.quit)
-        self.request_worker_thread.started.connect(self.request_worker.start)
-        self.request_worker_thread.start()
-
-        self.generate_worker = GenerateWorker(prefix="LLM Generate Worker")
-        self.generate_worker_thread = QThread()
-        self.generate_worker.moveToThread(self.generate_worker_thread)
-        self.generate_worker.response_signal.connect(self.generate_worker_response_signal_slot)
-        self.generate_worker.finished.connect(self.generate_worker_thread.quit)
-        self.generate_worker_thread.started.connect(self.generate_worker.start)
-        self.generate_worker_thread.start()
+        self.request_worker = self.create_worker(LLMRequestWorker)
+        self.generate_worker = self.create_worker(LLMGenerateWorker)
+        self.register("LLMRequestWorker_response_signal", self)
+        self.register("LLMGenerateWorker_response_signal", self)
 
     def pause(self):
         self.request_worker.pause()
@@ -63,20 +62,21 @@ def resume(self):
     def do_request(self, message):
         self.request_worker.add_to_queue(message)
 
-    @pyqtSlot(dict)
-    def request_worker_response_signal_slot(self, message):
+    def clear_history(self):
+        self.emit("clear_history")
+
+    def on_LLMRequestWorker_response_signal(self, message):
         self.generate_worker.add_to_queue(message)
 
-    @pyqtSlot(dict)
-    def generate_worker_response_signal_slot(self, message):
-        self.response_signal.emit(message)
+    def on_LLMGenerateWorker_response_signal(self, message:dict):
+        self.emit("llm_controller_response_signal", message)
 
     def do_unload_llm(self):
         self.generate_worker.llm.unload_model()
         self.generate_worker.llm.unload_tokenizer()
 
 
-class LLM(QObject):
+class LLM(QObject, MediatorMixin):
     logger = Logger(prefix="LLM")
     dtype = ""
     local_files_only = True
@@ -138,11 +138,11 @@ def has_gpu(self):
         if self.dtype == "32bit" or not self.use_gpu:
             return False
         return torch.cuda.is_available()
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # self.llm_api = LLMAPI(app=app)
+        MediatorMixin.__init__(self)
 
     def move_to_cpu(self):
         if self.model:
             self.logger.info("Moving model to CPU")
@@ -211,7 +211,6 @@ def load_model(self, local_files_only = None):
params["quantization_config"] = config

path = self.current_model_path
# self.engine.send_message(f"Loading {self.requested_generator_name} model from {path}")

auto_class_ = None
if self.requested_generator_name == "seq2seq":
@@ -508,6 +507,8 @@ def generate(self):
         n = 0
         streamed_template = ""
         replaced = False
+        is_end_of_message = False
+        is_first_message = True
         for new_text in self.streamer:
             # strip all newlines from new_text
             parsed_new_text = new_text.replace("\n", " ")
@@ -532,15 +533,22 @@
                     replaced = True
                     streamed_template = streamed_template.replace(rendered_template, "")
                 else:
+                    if "</s>" in new_text:
+                        streamed_template = streamed_template.replace("</s>", "")
+                        new_text = new_text.replace("</s>", "")
+                        is_end_of_message = True
                     yield dict(
                         code=EngineResponseCode.TEXT_STREAMED,
-                        message=new_text
+                        message=new_text,
+                        is_first_message=is_first_message,
+                        is_end_of_message=is_end_of_message,
+                        name=self.botname,
                     )
+                    is_first_message = False
 
-                if "</s>" in new_text:
+                if is_end_of_message:
                     self.history.append({
                         "role": "bot",
-                        "content": streamed_template.replace("</s>", "").strip()
+                        "content": streamed_template.strip()
                     })
                     streamed_template = ""
                     replaced = False
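The recurring pattern in this file is the MediatorMixin contract: register(signal_name, obj) subscribes an object, and emit(signal_name, data) calls on_<signal_name> on each subscriber (see register("clear_history", self) paired with on_clear_history above). A minimal sketch of that inferred contract, not the project's actual implementation:

from collections import defaultdict


class MediatorMixin:
    # Shared registry mapping signal name -> subscribed objects.
    _subscribers = defaultdict(list)

    def register(self, signal_name, obj):
        MediatorMixin._subscribers[signal_name].append(obj)

    def emit(self, signal_name, data=None):
        # Dispatch to on_<signal_name> on every subscriber that defines it.
        for obj in MediatorMixin._subscribers[signal_name]:
            handler = getattr(obj, f"on_{signal_name}", None)
            if handler is None:
                continue
            if data is None:
                handler()
            else:
                handler(data)

Under the same reading, create_worker() presumably hides the QThread/moveToThread boilerplate this diff deletes, and a streaming consumer can rebuild a full bot reply by concatenating the message chunks it receives until one arrives with is_end_of_message=True.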
103 changes: 0 additions & 103 deletions src/airunner/aihandler/llm_api.py

This file was deleted.

9 changes: 4 additions & 5 deletions src/airunner/aihandler/mixins/embedding_mixin.py
@@ -1,6 +1,5 @@
 import os
 from airunner.aihandler.logger import Logger as logger
-from airunner.aihandler.enums import EngineResponseCode
 
 
 class EmbeddingMixin:
@@ -25,10 +24,10 @@ def load_learned_embed_in_clip(self):
                     self.pipe.load_textual_inversion(path, token=token, weight_name=f)
                 except Exception as e:
                     if "already in tokenizer vocabulary" not in str(e):
-                        self.send_message({
-                            "embedding_name": token,
-                            "model_name": self.model,
-                        }, EngineResponseCode.EMBEDDING_LOAD_FAILED)
+                        self.emit("embedding_load_failed_signal", dict(
+                            embedding_name=token,
+                            model_name=self.model,
+                        ))
                     logger.warning(e)
                 except AttributeError as e:
                     if "load_textual_inversion" in str(e):
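For whoever wires up the UI side, a hypothetical subscriber for the new embedding_load_failed_signal; the field names come from the emit() call above, while the class and its mediator argument are assumptions:

class EmbeddingErrorListener:
    def __init__(self, mediator):
        # mediator is any object exposing the register()/emit() contract.
        mediator.register("embedding_load_failed_signal", self)

    def on_embedding_load_failed_signal(self, data: dict):
        print(f"Could not load embedding {data['embedding_name']} "
              f"for model {data['model_name']}")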
15 changes: 14 additions & 1 deletion src/airunner/aihandler/mixins/merge_mixin.py
@@ -1,11 +1,24 @@
 import os
 
+from PyQt6.QtCore import pyqtSlot
+
 from airunner.aihandler.logger import Logger
 
 logger = Logger(prefix="MergeMixin")
 
 
 class MergeMixin:
-    def merge_models(self, base_model_path, models_to_merge_path, weights, output_path, name, action):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.register("sd_merge_models_signal", self)
+
+    @pyqtSlot(object)
+    def on_sd_merge_models_signal(self, options):
+        print("TODO: on_sd_merge_models_signal")
+
+    @pyqtSlot(object)
+    def merge_models(self, options):
+        base_model_path, models_to_merge_path, weights, output_path, name, action = options
         from diffusers import (
             StableDiffusionPipeline,
             StableDiffusionInstructPix2PixPipeline,
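Since the new merge_models() unpacks a single positional tuple, callers must pack the arguments in exactly the order unpacked above. A sketch of the calling side with placeholder values (note that on_sd_merge_models_signal forwarding to merge_models is still a TODO in this diff):

merge_options = (
    "models/base.safetensors",          # base_model_path
    ["models/style-lora.safetensors"],  # models_to_merge_path
    [0.5],                              # weights, one per model to merge
    "models/output",                    # output_path
    "my-merged-model",                  # name
    "txt2img",                          # action
)
# mixin.merge_models(merge_options)  # positional unpack, so order matters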
3 changes: 1 addition & 2 deletions src/airunner/aihandler/mixins/scheduler_mixin.py
@@ -119,8 +119,7 @@ def change_scheduler(self):
     def prepare_scheduler(self):
         scheduler_name = self.options.get(f"scheduler", "euler_a")
         if self.scheduler_name != scheduler_name:
-            logger.info(f"Prepare scheduler {scheduler_name}")
-            self.send_message("Preparing scheduler...")
+            self.emit("status_signal", f"Preparing scheduler {scheduler_name}")
             self.scheduler_name = scheduler_name
             self.do_change_scheduler = True
         else:
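Status text now flows through the generic "status_signal" rather than send_message(), so anything that used to display engine status should subscribe to it. A minimal, hypothetical subscriber under the same inferred register()/emit() contract:

class StatusBar:
    def __init__(self, mediator):
        mediator.register("status_signal", self)

    def on_status_signal(self, message: str):
        # e.g. "Preparing scheduler Euler a"
        print(f"[status] {message}")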