Develop #329

Merged 9 commits on Dec 17, 2023
6 changes: 3 additions & 3 deletions setup.py
@@ -22,8 +22,8 @@
"omegaconf==2.3.0",
"accelerate==0.23.0",
"controlnet_aux==0.0.7",
"diffusers==0.23.1",
"huggingface-hub==0.17.3",
"diffusers==0.24.0",
"huggingface-hub==0.19.4",
"numpy==1.23.5",
"Pillow==9.5.0",
"pip==23.3.1",
@@ -37,7 +37,7 @@
"requests-oauthlib==1.3.1",
"safetensors==0.3.1",
"scipy==1.10.1",
"tokenizers==0.14.1",
"tokenizers==0.15.0",
"tqdm==4.65.0",
"charset-normalizer==3.1.0",
"opencv-python==4.8.0.74",
26 changes: 25 additions & 1 deletion src/airunner/aihandler/auto_pipeline.py
@@ -18,10 +18,34 @@ def from_pretrained(requested_action, **kwargs):
requested_action,
model_data,
pipeline_action,
category
)
if class_object is None:
return None
return class_object.from_pretrained(model, **kwargs)
if "torch_dtype" in kwargs:
del kwargs["torch_dtype"]
try:
return class_object.from_pretrained(model, **kwargs)
except Exception as e:
try_again = False
if "Checkout your internet connection" in str(e):
try_again = True
elif "To enable repo look-ups" in str(e):
try_again = True
elif "No such file or directory" in str(e):
try_again = True
elif "does not appear to have a file named config.json" in str(e):
try_again = True
elif "Entry Not Found" in str(e):
try_again = True
if try_again:
kwargs["local_files_only"] = False
kwargs["class_object"] = class_object
kwargs["model_data"] = model_data
kwargs["model"] = model
kwargs["pipeline_action"] = pipeline_action
kwargs["category"] = category
return AutoImport.from_pretrained(requested_action, **kwargs)

@staticmethod
def class_object(
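
The error handling added above retries a failed local model load by re-requesting the files from the Hugging Face Hub. A minimal sketch of that fallback pattern, assuming a generic loader callable in place of the pipeline class's from_pretrained (the matched error strings are copied from the diff; the helper name is ours):

# Sketch of the retry-with-remote-fallback pattern from the diff above.
# `loader` stands in for class_object.from_pretrained.
RETRYABLE_ERRORS = (
    "Checkout your internet connection",  # wording as emitted upstream
    "To enable repo look-ups",
    "No such file or directory",
    "does not appear to have a file named config.json",
    "Entry Not Found",
)

def load_with_fallback(loader, model, **kwargs):
    try:
        # First attempt honors whatever local_files_only the caller passed.
        return loader(model, **kwargs)
    except Exception as e:
        if not any(msg in str(e) for msg in RETRYABLE_ERRORS):
            raise
        # Cache miss or incomplete download: retry against the remote Hub.
        kwargs["local_files_only"] = False
        return loader(model, **kwargs)

Unlike this sketch, the diff re-enters AutoImport.from_pretrained and threads the class object and model data back through kwargs; the sketch only illustrates the retry decision itself.
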
36 changes: 32 additions & 4 deletions src/airunner/aihandler/engine.py
@@ -1,10 +1,13 @@
import os
import torch
import gc
import threading

from airunner.aihandler.llm import LLM
from airunner.aihandler.logger import Logger as logger
from airunner.aihandler.runner import SDRunner
from airunner.aihandler.settings_manager import SettingsManager
from airunner.aihandler.tts import TTS


class Engine:
@@ -26,14 +29,17 @@ def __init__(self, **kwargs):
self.app = kwargs.get("app", None)
self.message_var = kwargs.get("message_var", None)
self.message_handler = kwargs.get("message_handler", None)
self.llm = LLM(engine=self)
self.llm = LLM(app=self.app, engine=self)
self.sd = SDRunner(
app=self.app,
message_var=self.message_var,
message_handler=self.message_handler,
engine=self
)
self.tts = TTS()
self.settings_manager = SettingsManager()
self.tts_thread = threading.Thread(target=self.tts.run)
self.tts_thread.start()

def generator_sample(self, data: dict):
"""
@@ -43,6 +49,7 @@ def generator_sample(self, data: dict):
"""
logger.info("generator_sample called")
is_llm = self.is_llm_request(data)
is_tts = self.is_tts_request(data)
if is_llm and self.model_type != "llm":
logger.info("Switching to LLM model")
self.model_type = "llm"
@@ -55,22 +62,42 @@
self.sd.unload_model()
self.sd.unload_tokenizer()
self.clear_memory()
self.llm.move_to_device()
elif not is_llm and self.model_type != "art":
# self.llm.move_to_device()
elif is_tts:
# split on sentence enders
sentence_enders = [".", "?", "!", "\n"]
text = data["request_data"]["text"]
sentences = []
# split text into sentences
current_sentence = ""
for char in text:
current_sentence += char
if char in sentence_enders:
sentences.append(current_sentence)
current_sentence = ""
if current_sentence != "":
sentences.append(current_sentence)

for sentence in sentences:
self.tts.add_sentence(sentence, "a")
elif not is_llm and not is_tts and self.model_type != "art":
logger.info("Switching to art model")
self.model_type = "art"
self.unload_llm()

if is_llm:
logger.info("Engine calling llm.do_generate")
self.llm.do_generate(data)
else:
elif not is_tts:
logger.info("Engine calling sd.generator_sample")
self.sd.generator_sample(data)

def is_llm_request(self, data):
return "llm_request" in data

def is_tts_request(self, data):
return "tts_request" in data

def unload_llm(self):
"""
This function will either leave the LLM
@@ -91,6 +118,7 @@ def unload_llm(self):
if do_move_to_cpu:
logger.info("Moving LLM to CPU")
self.llm.move_to_cpu()

self.clear_memory()
elif do_unload_model:
logger.info("Unloading LLM")
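
The new TTS branch above accumulates characters and emits a chunk at each sentence ender before queueing it for speech. The same splitter as a standalone function (the function name and trailing-fragment comment are ours; the ender set and loop mirror the diff):

SENTENCE_ENDERS = {".", "?", "!", "\n"}

def split_sentences(text: str) -> list:
    """Split text after each sentence-ending character, keeping the ender."""
    sentences, current = [], ""
    for char in text:
        current += char
        if char in SENTENCE_ENDERS:
            sentences.append(current)
            current = ""
    if current:
        sentences.append(current)  # keep any trailing fragment
    return sentences

# Example usage:
# split_sentences("Hi. How are you? Fine") -> ["Hi.", " How are you?", " Fine"]

Splitting before enqueueing lets the TTS thread started in __init__ begin speaking the first sentence while later ones are still pending.
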
1 change: 0 additions & 1 deletion src/airunner/aihandler/enums.py
@@ -60,6 +60,5 @@ class GeneratorSection(Enum):
SUPERRESOLUTION = "superresolution"
UPSCALE = "upscale"
VID2VID = "vid2vid"
TXT2GIF = "txt2gif"
TXT2VID = "txt2vid"
PROMPT_BUILDER = "prompt_builder"
17 changes: 16 additions & 1 deletion src/airunner/aihandler/llm.py
@@ -6,6 +6,21 @@ class LLM(TransformerRunner):
def clear_conversation(self):
if self.generator.name == "casuallm":
self.chain.clear()

def do_generate(self, data):
self.process_data(data)
self.handle_request()
self.requested_generator_name = data["request_data"]["generator_name"]
prompt = data["request_data"]["prompt"]
model_path = data["request_data"]["model_path"]
self.generate(
app=self.app,
endpoint=data["request_data"]["generator_name"],
prompt=prompt,
model=model_path,
stream=data["request_data"]["stream"],
images=[data["request_data"]["image"]],
)

def generate(self, **kwargs):
if self.generator.name == "casuallm":
@@ -23,7 +38,6 @@ def generate(self, **kwargs):

answers = []
for res in out:
print("DECODING RESULT")
answer = self.processor.decode(
res,
skip_special_tokens=True
@@ -32,3 +46,4 @@
return answers
else:
logger.error(f"Failed to call generator for {self.generator.name}")
# self.llm_api.request(**kwargs)
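
The new do_generate reads everything it needs from a nested request_data dict. A hypothetical payload illustrating the keys the method accesses (key names are taken from the diff; all values are made up):

# Illustrative request for LLM.do_generate; values are placeholders.
data = {
    "request_data": {
        "generator_name": "casuallm",     # forwarded as the endpoint
        "prompt": "Describe this image.",
        "model_path": "/models/example",  # hypothetical local path
        "stream": False,
        "image": None,                    # wrapped as images=[...] by do_generate
    }
}

Since do_generate always builds images=[data["request_data"]["image"]], callers appear to be expected to supply an image key even for text-only requests.
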
2 changes: 1 addition & 1 deletion src/airunner/aihandler/mixins/compel_mixin.py
@@ -59,7 +59,7 @@ def clear_prompt_embeds(self):
def load_prompt_embeds(self):
logger.info("Loading prompt embeds")
self.compel_proc = None
self.clear_memory()
self.engine.clear_memory()
prompt = self.prompt
negative_prompt = self.negative_prompt if self.negative_prompt else ""

2 changes: 1 addition & 1 deletion src/airunner/aihandler/mixins/memory_efficient_mixin.py
@@ -263,7 +263,7 @@ def apply_tome(self):
self.remove_tome_sd()

def apply_tome_sd(self):
logger.info("Applying ToMe SD weight merging with ratio {self.tome_sd_ratio}")
logger.info(f"Applying ToMe SD weight merging with ratio {self.tome_sd_ratio}")
tomesd.apply_patch(self.pipe, ratio=self.tome_sd_ratio)
self.tome_sd_applied = True
self.tome_ratio = self.tome_sd_ratio
2 changes: 1 addition & 1 deletion src/airunner/aihandler/mixins/scheduler_mixin.py
@@ -54,7 +54,7 @@ def clear_scheduler(self):
self.current_scheduler_name = None

def load_scheduler(self, force_scheduler_name=None, config=None):
if self.use_kandinsky:
if self.use_kandinsky or self.is_sd_xl_turbo:
return None

import diffusers