30 changes: 16 additions & 14 deletions invokeai/app/invocations/compel.py
@@ -173,7 +173,7 @@ def _lora_loader()


class SDXLPromptInvocationBase:
-def run_clip_raw(self, context, clip_field, prompt, get_pooled):
+def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@@ -210,8 +210,8 @@ def _lora_loader()
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')

-with ModelPatcher.apply_lora_text_encoder(
-    text_encoder_info.context.model, _lora_loader()
+with ModelPatcher.apply_lora(
+    text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@@ -247,7 +247,7 @@ def _lora_loader()

return c, c_pooled, None

-def run_clip_compel(self, context, clip_field, prompt, get_pooled):
+def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@@ -284,8 +284,8 @@ def _lora_loader()
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')

-with ModelPatcher.apply_lora_text_encoder(
-    text_encoder_info.context.model, _lora_loader()
+with ModelPatcher.apply_lora(
+    text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@@ -357,11 +357,11 @@ class Config(InvocationConfig):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
-c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
+c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
if self.style.strip() == "":
-c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
+c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
else:
-c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
+c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")

original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@@ -415,7 +415,8 @@ class Config(InvocationConfig):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
-c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
+# TODO: if a LoRA for the refiner ever appears, use the proper prefix here
+c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")

original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@@ -467,11 +468,11 @@ class Config(InvocationConfig):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
-c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
+c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
if self.style.strip() == "":
-c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
+c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
else:
-c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
+c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")

original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@@ -525,7 +526,8 @@ class Config(InvocationConfig):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
-c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
+# TODO: if a LoRA for the refiner ever appears, use the proper prefix here
+c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")

original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
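The compel.py changes swap the text-encoder-specific patcher call for the generic ModelPatcher.apply_lora and thread a key prefix through, so each SDXL text encoder receives only its own LoRA layers ("lora_te1_" for the first encoder, "lora_te2_" for the second). Below is a minimal sketch of the prefix-filtering idea, assuming kohya-style key naming; the helper name and filtering logic are illustrative, not the actual ModelPatcher internals.

```python
from typing import Any, Dict

def select_layers_for_prefix(lora_layers: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep only the layers whose keys start with `prefix`, stripping the prefix
    so the remainder can be matched against the submodel's module names."""
    return {
        key[len(prefix):]: layer
        for key, layer in lora_layers.items()
        if key.startswith(prefix)
    }

# Hypothetical SDXL LoRA layer keys (kohya-style prefixes assumed):
layers = {
    "lora_te1_text_model_encoder_layers_0_mlp_fc1": "layer-a",
    "lora_te2_text_model_encoder_layers_0_mlp_fc1": "layer-b",
    "lora_unet_down_blocks_0_attentions_0_proj_in": "layer-c",
}
assert select_layers_for_prefix(layers, "lora_te2_") == {
    "text_model_encoder_layers_0_mlp_fc1": "layer-b"
}
```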
2 changes: 1 addition & 1 deletion invokeai/app/invocations/latent.py
@@ -14,7 +14,7 @@
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings

-from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
97 changes: 97 additions & 0 deletions invokeai/app/invocations/model.py
@@ -262,6 +262,103 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
return output


class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""

# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"

unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
# fmt: on


class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"

lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")

unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")

class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}

def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")

base_model = self.lora.base_model
lora_name = self.lora.model_name

if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")

if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')

if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')

if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')

output = SDXLLoraLoaderOutput()

if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)

if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)

if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)

return output


class VAEModelField(BaseModel):
"""Vae model field"""

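The model.py hunk is entirely new code: SDXLLoraLoaderInvocation mirrors the existing LoraLoaderInvocation but carries three submodel fields (unet, clip, clip2), deep-copies each one before appending a LoraInfo, and rejects duplicate applications. A hedged sketch of chaining two of these nodes follows; the helper, the ids, and the direct invoke() calls are illustrative only — in practice the nodes are wired through the node graph, not constructed by hand.

```python
from invokeai.app.invocations.model import SDXLLoraLoaderInvocation

def stack_two_sdxl_loras(context, unet_field, clip_field, clip2_field, lora_a, lora_b):
    """Illustrative only: apply two LoRAs in sequence by feeding the first
    node's outputs into the second node's inputs."""
    first = SDXLLoraLoaderInvocation(
        id="sdxl_lora_1", lora=lora_a, weight=0.6,
        unet=unet_field, clip=clip_field, clip2=clip2_field,
    )
    out1 = first.invoke(context)

    # invoke() deep-copies each field before appending its LoraInfo, so the
    # second loader extends out1 without mutating it; the duplicate-name check
    # raises if the same LoRA were applied twice to the same submodel.
    second = SDXLLoraLoaderInvocation(
        id="sdxl_lora_2", lora=lora_b, weight=0.4,
        unet=out1.unet, clip=out1.clip, clip2=out1.clip2,
    )
    return second.invoke(context)  # its loras lists now hold both entries
```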
26 changes: 23 additions & 3 deletions invokeai/app/invocations/sdxl.py
@@ -5,7 +5,7 @@

from pydantic import Field, validator

-from ...backend.model_management import ModelType, SubModelType
+from ...backend.model_management import ModelType, SubModelType, ModelPatcher
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext

@@ -293,10 +293,20 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

num_inference_steps = self.steps

+def _lora_loader():
+    for lora in self.unet.loras:
+        lora_info = context.services.model_manager.get_model(
+            **lora.dict(exclude={"weight"}),
+            context=context,
+        )
+        yield (lora_info.context.model, lora.weight)
+        del lora_info
+    return

unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True
cross_attention_kwargs = None
-with unet_info as unet:
+with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps

@@ -543,9 +553,19 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
context=context,
)

+def _lora_loader():
+    for lora in self.unet.loras:
+        lora_info = context.services.model_manager.get_model(
+            **lora.dict(exclude={"weight"}),
+            context=context,
+        )
+        yield (lora_info.context.model, lora.weight)
+        del lora_info
+    return

do_classifier_free_guidance = True
cross_attention_kwargs = None
-with unet_info as unet:
+with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet.device)
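Both SDXL denoising invocations gain the same local _lora_loader generator and wrap the existing unet_info context manager with ModelPatcher.apply_lora_unet. The generator keeps loading lazy: each LoRA is fetched from the model manager only when the patcher iterates to it, and the local reference is dropped right after. A minimal sketch of that pattern, with get_model standing in (as an assumption) for context.services.model_manager.get_model:

```python
from typing import Any, Callable, Iterable, Iterator, Tuple

def lora_loader(loras: Iterable[Any], get_model: Callable[[Any], Any]) -> Iterator[Tuple[Any, float]]:
    """Yield (model, weight) pairs one at a time so each LoRA is loaded only
    when the patcher asks for it, not all up front."""
    for lora in loras:
        lora_info = get_model(lora)                  # lazy fetch
        yield (lora_info.context.model, lora.weight)
        del lora_info                                # drop the reference between iterations

# Consumed inside the patcher's context manager, roughly as the diff does:
#   with ModelPatcher.apply_lora_unet(unet_info.context.model, lora_loader(...)), unet_info as unet:
#       ...
```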
1 change: 1 addition & 0 deletions invokeai/backend/model_management/__init__.py
@@ -13,3 +13,4 @@
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod
from .lora import ModelPatcher