SD XL support #11757

Merged Jul 16, 2023 · 25 commits

Commits
af08121
getting SD2.1 to run on SDXL repo
AUTOMATIC1111 Jul 11, 2023
da464a3
SDXL support
AUTOMATIC1111 Jul 12, 2023
60397a7
Merge branch 'dev' into sdxl
AUTOMATIC1111 Jul 12, 2023
5cf623c
linter
AUTOMATIC1111 Jul 12, 2023
a04c955
fix importlib.machinery issue on github's autotests #yolo
AUTOMATIC1111 Jul 12, 2023
b717eb7
mute unneeded SDXL imports for tests too
AUTOMATIC1111 Jul 13, 2023
ac4ccfa
get attention optimizations to work
AUTOMATIC1111 Jul 13, 2023
21aec6f
lint
AUTOMATIC1111 Jul 13, 2023
594c8e7
fix CLIP doing the unneeded normalization
AUTOMATIC1111 Jul 13, 2023
76ebb17
lora support
AUTOMATIC1111 Jul 13, 2023
6f23da6
fix broken img2img
AUTOMATIC1111 Jul 13, 2023
b8159d0
add XL support for live previews: approx and TAESD
AUTOMATIC1111 Jul 13, 2023
e16ebc9
repair --no-half for SDXL
AUTOMATIC1111 Jul 13, 2023
ff73841
mute SDXL imports in the place where SDXL is imported for the first t…
AUTOMATIC1111 Jul 13, 2023
6c5f83b
add support for SDXL loras with te1/te2 modules
AUTOMATIC1111 Jul 13, 2023
dc39061
thank you linter
AUTOMATIC1111 Jul 13, 2023
6d8dcde
initial SDXL refiner support
AUTOMATIC1111 Jul 14, 2023
b7dbeda
linter
AUTOMATIC1111 Jul 14, 2023
abb948d
raise maximum Negative Guidance minimum sigma due to request in PR di…
AUTOMATIC1111 Jul 14, 2023
9a3f35b
repair medvram and lowvram
AUTOMATIC1111 Jul 14, 2023
92a3236
Merge branch 'dev' into sdxl
AUTOMATIC1111 Jul 14, 2023
471a5a6
add more relevant fields to caching conds
AUTOMATIC1111 Jul 14, 2023
ac2d47f
add cheap VAE approximation coeffs for SDXL
AUTOMATIC1111 Jul 14, 2023
5dee0fa
add a message about unsupported samplers
AUTOMATIC1111 Jul 14, 2023
14cf434
fix an issue in live previews that happens when you use SDXL with fp1…
AUTOMATIC1111 Jul 15, 2023
48 changes: 37 additions & 11 deletions extensions-builtin/Lora/lora.py
@@ -68,6 +68,14 @@ def match(match_list, regex_text):

return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
if 'mlp_fc1' in m[1]:
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
elif 'mlp_fc2' in m[1]:
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
else:
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

return key
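
For illustration, a standalone sketch of the te2 mapping added above. The helper name is hypothetical; in this file the logic lives inside convert_diffusers_name_to_compvis:

import re

def convert_te2_key(key):
    # condensed version of the lora_te2 branch above
    m = re.match(r"lora_te2_text_model_encoder_layers_(\d+)_(.+)", key)
    if not m:
        return key
    layer, rest = m.group(1), m.group(2)
    if 'mlp_fc1' in rest:
        rest = rest.replace('mlp_fc1', 'mlp_c_fc')
    elif 'mlp_fc2' in rest:
        rest = rest.replace('mlp_fc2', 'mlp_c_proj')
    else:
        rest = rest.replace('self_attn', 'attn')
    return f"1_model_transformer_resblocks_{layer}_{rest}"

print(convert_te2_key("lora_te2_text_model_encoder_layers_0_mlp_fc1"))
# -> 1_model_transformer_resblocks_0_mlp_c_fc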


@@ -142,10 +150,20 @@ def __init__(self):
def assign_lora_names_to_compvis_modules(sd_model):
lora_layer_mapping = {}

for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
lora_name = name.replace(".", "_")
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
if shared.sd_model.is_sdxl:
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
if not hasattr(embedder, 'wrapped'):
continue

for name, module in embedder.wrapped.named_modules():
lora_name = f'{i}_{name.replace(".", "_")}'
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
else:
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
lora_name = name.replace(".", "_")
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name

for name, module in shared.sd_model.model.named_modules():
lora_name = name.replace(".", "_")
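
The prefixing scheme above can be seen with a toy model; the class names here are made up stand-ins for SDXL's two text encoders:

import torch.nn as nn

class Wrapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Embedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.wrapped = Wrapped()

embedders = [Embedder(), Embedder()]  # stand-ins for the two text encoders
mapping = {}
for i, embedder in enumerate(embedders):
    for name, module in embedder.wrapped.named_modules():
        mapping[f'{i}_{name.replace(".", "_")}'] = module

print(sorted(mapping))  # ['0_', '0_proj', '1_', '1_proj']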
@@ -168,19 +186,27 @@ def load_lora(name, lora_on_disk):
keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

for key_diffusers, weight in sd.items():
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
for key_lora, weight in sd.items():
key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)

key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

if sd_module is None:
m = re_x_proj.match(key)
if m:
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

# SDXL loras seem to already have correct compvis keys, so we only need to replace "lora_unet" with "diffusion_model"
if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

if sd_module is None:
keys_failed_to_match[key_diffusers] = key
keys_failed_to_match[key_lora] = key
continue

lora_module = lora.modules.get(key, None)
@@ -203,9 +229,9 @@ def load_lora(name, lora_on_disk):
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")

with torch.no_grad():
module.weight.copy_(weight)
@@ -217,7 +243,7 @@ def load_lora(name, lora_on_disk):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")

if keys_failed_to_match:
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
2 changes: 1 addition & 1 deletion modules/hypernetworks/hypernetwork.py
@@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
return context_k, context_v


def attention_CrossAttention_forward(self, x, context=None, mask=None):
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads

q = self.to_q(x)
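
The added **kwargs matters because sgm's attention code passes extra keyword arguments (such as additional_tokens) into CrossAttention.forward. A toy sketch of why the old signature breaks:

class OldForward:
    def forward(self, x, context=None, mask=None):
        return x

class NewForward:
    def forward(self, x, context=None, mask=None, **kwargs):
        return x  # unknown extras from the sgm code path are accepted and ignored

NewForward().forward(0, additional_tokens=None)   # ok
# OldForward().forward(0, additional_tokens=None) # TypeError: unexpected keyword argument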
4 changes: 4 additions & 0 deletions modules/launch_utils.py
@@ -235,11 +235,13 @@ def prepare_environment():
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -297,6 +299,7 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
@@ -321,6 +324,7 @@
exit(0)



def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")
53 changes: 39 additions & 14 deletions modules/lowvram.py
@@ -53,31 +53,56 @@ def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)

# for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model

# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
to_remain_in_cpu = [
(sd_model, 'first_stage_model'),
(sd_model, 'depth_model'),
(sd_model, 'embedder'),
(sd_model, 'model'),
]

is_sdxl = hasattr(sd_model, 'conditioner')
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

if is_sdxl:
to_remain_in_cpu.append((sd_model, 'conditioner'))
elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
else:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))

# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
stored = []
for obj, field in to_remain_in_cpu:
module = getattr(obj, field, None)
stored.append(module)
setattr(obj, field, None)

# send the model to GPU.
sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored

# put modules back; the modules will stay on CPU.
for (obj, field), module in zip(to_remain_in_cpu, stored):
setattr(obj, field, module)

# register hooks for the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
if is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)

sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
del sd_model.cond_stage_model.transformer
if hasattr(sd_model, 'cond_stage_model'):
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
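
The overall pattern above (detach the big submodules, move the rest to the GPU, reattach, then move each piece on demand via forward pre-hooks) in a self-contained toy form; field names here are illustrative:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.unet = nn.Linear(4, 4)
        self.vae = nn.Linear(4, 4)

model = Toy()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

to_remain_in_cpu = [(model, 'vae')]

stored = []
for obj, field in to_remain_in_cpu:
    stored.append(getattr(obj, field, None))
    setattr(obj, field, None)          # detach so .to(device) skips it

model.to(device)                       # only the attached parts move

for (obj, field), module in zip(to_remain_in_cpu, stored):
    setattr(obj, field, module)        # reattach; these stay on CPU

def send_me_to_gpu(module, _inputs):   # moved lazily, right before forward
    module.to(device)

model.vae.register_forward_pre_hook(send_me_to_gpu)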
25 changes: 25 additions & 0 deletions modules/paths.py
@@ -5,6 +5,21 @@
import modules.safe # noqa: F401


def mute_sdxl_imports():
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""

class Dummy:
pass

module = Dummy()
module.LPIPS = None
sys.modules['taming.modules.losses.lpips'] = module

module = Dummy()
module.StableDataModuleFromConfig = None
sys.modules['sgm.data'] = module


# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)

@@ -18,8 +33,11 @@

assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

mute_sdxl_imports()

path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +53,13 @@
d = os.path.abspath(d)
if "atstart" in options:
sys.path.insert(0, d)
elif "sgm" in options:
# Stable Diffusion XL repo has a scripts dir with __init__.py in it, which ruins every extension's scripts dir, so we
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.

sys.path.insert(0, d)
import sgm # noqa: F401
sys.path.pop(0)
else:
sys.path.append(d)
paths[what] = d
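
The sys.modules stubbing used in mute_sdxl_imports can be demonstrated in isolation; the module name below is hypothetical:

import sys
from types import SimpleNamespace

# pre-seed the import cache so a later `import` returns our stub
sys.modules['heavy_optional_dep'] = SimpleNamespace(LPIPS=None)

import heavy_optional_dep
assert heavy_optional_dep.LPIPS is None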
27 changes: 21 additions & 6 deletions modules/processing.py
@@ -330,23 +330,39 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):

caches is a list with items described above.
"""

cached_params = (
required_prompts,
steps,
opts.CLIP_stop_at_last_layers,
shared.sd_model.sd_checkpoint_info,
extra_network_data,
opts.sdxl_crop_left,
opts.sdxl_crop_top,
self.width,
self.height,
)

for cache in caches:
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
if cache[0] is not None and cached_params == cache[0]:
return cache[1]

cache = caches[0]

with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)

cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
cache[0] = cached_params
return cache[1]
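
The reason width, height, and the SDXL crop options joined the cache key: SDXL embeds them in its conditioning, so conds cached under old values must be recomputed. The single-slot, tuple-key cache itself is simple; a standalone sketch:

def get_with_caching(function, cache, *params):
    cached_params = tuple(params)       # everything the result depends on
    if cache[0] is not None and cache[0] == cached_params:
        return cache[1]                 # hit: params unchanged since last call
    cache[1] = function(*params)
    cache[0] = cached_params
    return cache[1]

cache = [None, None]
print(get_with_caching(lambda a, b: a + b, cache, 2, 3))  # computed: 5
print(get_with_caching(lambda a, b: a + b, cache, 2, 3))  # cached: 5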

def setup_conds(self):
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -523,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see


def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
x = model.decode_first_stage(x)
x = model.decode_first_stage(x.to(devices.dtype_vae))

return x
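
The change replaces autocast with an explicit cast of the latent to the VAE's dtype, which sidesteps the fp16 mismatch noted in the commit list. In isolation, with a toy decoder standing in for the VAE:

import torch

dtype_vae = torch.float32              # stand-in for devices.dtype_vae
decoder = torch.nn.Linear(4, 4).to(dtype_vae)

latent = torch.randn(1, 4, dtype=torch.float16)
out = decoder(latent.to(dtype_vae))    # dtypes now match; no autocast needed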
