From 33856def7c93a45d6096437db572700c62756d44 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Tue, 27 Feb 2024 10:15:39 -0500 Subject: [PATCH] Next: Switch SDXLPromptInvocationBase to read TI names as model keys rather than model_name --- invokeai/app/invocations/compel.py | 8 +++----- invokeai/app/util/ti_utils.py | 8 +++++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 50f53225137..1586065b7d2 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -193,11 +193,9 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_model = context.models.load_by_attrs( - model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion - ).model - assert isinstance(ti_model, TextualInversionModelRaw) - ti_list.append((name, ti_model)) + loaded_model = context.models.load(key=name).model + assert isinstance(loaded_model, TextualInversionModelRaw) + ti_list.append((name, loaded_model)) except UnknownModelException: # print(e) # import traceback diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py index a66a832b42a..7e25d07a828 100644 --- a/invokeai/app/util/ti_utils.py +++ b/invokeai/app/util/ti_utils.py @@ -1,8 +1,10 @@ import re +from typing import List -def extract_ti_triggers_from_prompt(prompt: str) -> list[str]: - ti_triggers = [] + +def extract_ti_triggers_from_prompt(prompt: str) -> List[str]: + ti_triggers: List[str] = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): - ti_triggers.append(trigger) + ti_triggers.append(str(trigger)) return ti_triggers