Skip to content

Commit

Permalink
Add warning when meet emb name conflicting
Browse files Browse the repository at this point in the history
Choose standalone embedding (in /embeddings folder) first
  • Loading branch information
KohakuBlueleaf committed Oct 10, 2023
1 parent 2282eb8 commit 81e94de
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 32 deletions.
33 changes: 33 additions & 0 deletions extensions-builtin/Lora/lora_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import sys
import copy
import logging


class ColoredFormatter(logging.Formatter):
COLORS = {
"DEBUG": "\033[0;36m", # CYAN
"INFO": "\033[0;32m", # GREEN
"WARNING": "\033[0;33m", # YELLOW
"ERROR": "\033[0;31m", # RED
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
"RESET": "\033[0m", # RESET COLOR
}

def format(self, record):
colored_record = copy.copy(record)
levelname = colored_record.levelname
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
return super().format(colored_record)


logger = logging.getLogger("lora")
logger.propagate = False


if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
)
logger.addHandler(handler)
80 changes: 48 additions & 32 deletions extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding

from lora_logger import logger

module_types = [
network_lora.ModuleTypeLora(),
network_hada.ModuleTypeHada(),
Expand Down Expand Up @@ -206,7 +208,40 @@ def load_network(name, network_on_disk):

net.modules[key] = net_module

net.bundle_embeddings = bundle_embeddings
embeddings = {}
for emb_name, data in bundle_embeddings.items():
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")

embedding = Embedding(vec, emb_name)
embedding.vectors = vectors
embedding.shape = shape
embedding.loaded = None
embeddings[emb_name] = embedding

net.bundle_embeddings = embeddings

if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
Expand All @@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
for emb_name in net.bundle_embeddings:
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
for emb_name, embedding in net.bundle_embeddings.items():
if embedding.loaded:
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

loaded_networks.clear()

Expand Down Expand Up @@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net)

for emb_name, data in net.bundle_embeddings.items():
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")

embedding = Embedding(vec, emb_name)
embedding.vectors = vectors
embedding.shape = shape
for emb_name, embedding in net.bundle_embeddings.items():
if embedding.loaded is None and emb_name in emb_db.word_embeddings:
logger.warning(
f'Skip bundle embedding: "{emb_name}"'
' as it was already loaded from embeddings folder'
)
continue

embedding.loaded = False
if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
embedding.loaded = True
emb_db.register_embedding(embedding, shared.sd_model)
else:
emb_db.skipped_embeddings[name] = embedding
Expand Down

0 comments on commit 81e94de

Please sign in to comment.