Skip to content

Commit

Permalink
Next: Switch SDXLPromptInvocationBase to read TI names as model keys …
Browse files Browse the repository at this point in the history
…rather than model_name
  • Loading branch information
brandonrising committed Feb 27, 2024
1 parent 4418c11 commit 33856de
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 3 additions & 5 deletions invokeai/app/invocations/compel.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,9 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_model = context.models.load_by_attrs(
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
).model
assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model))
loaded_model = context.models.load(key=name).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
except UnknownModelException:
# print(e)
# import traceback
Expand Down
8 changes: 5 additions & 3 deletions invokeai/app/util/ti_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re

from typing import List

def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
ti_triggers = []

def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:

Check failure on line 6 in invokeai/app/util/ti_utils.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

invokeai/app/util/ti_utils.py:1:1: I001 Import block is un-sorted or un-formatted
ti_triggers: List[str] = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
ti_triggers.append(trigger)
ti_triggers.append(str(trigger))
return ti_triggers

0 comments on commit 33856de

Please sign in to comment.