Skip to content

Commit

Permalink
Merge branch 'dev' into multiple_loaded_models
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Aug 5, 2023
2 parents 390bffa + 0ae2767 commit 22ecb78
Show file tree
Hide file tree
Showing 45 changed files with 924 additions and 476 deletions.
2 changes: 1 addition & 1 deletion extensions-builtin/Lora/ui_edit_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def create_editor(self):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)

with gr.Column(scale=1, min_width=120):
generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
generate_random_prompt = gr.Button('Generate', size="lg", scale=1)

self.edit_notes = gr.TextArea(label='Notes', lines=4)

Expand Down
10 changes: 5 additions & 5 deletions javascript/localization.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION',
img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
extras_upscaler_1: 'SPAN',
extras_upscaler_2: 'SPAN',
setting_random_artist_categories: 'OPTION',
setting_face_restoration_model: 'OPTION',
setting_realesrgan_enabled_models: 'OPTION',
extras_upscaler_1: 'OPTION',
extras_upscaler_2: 'OPTION',
};

var re_num = /^[.\d]+$/;
Expand Down
2 changes: 2 additions & 0 deletions modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,5 @@
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
83 changes: 75 additions & 8 deletions modules/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import lru_cache

import torch
from modules import errors
from modules import errors, rng_philox

if sys.platform == "darwin":
from modules import mac_specific
Expand Down Expand Up @@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True



errors.run(enable_tf32, "Enabling TF32")

cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
cpu: torch.device = torch.device("cpu")
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False


Expand All @@ -90,23 +93,87 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input


nv_rng = None


def randn(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""

from modules.shared import opts

torch.manual_seed(seed)
manual_seed(seed)

if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)

if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)

return torch.randn(shape, device=device)


def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""

from modules.shared import opts

if opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=device)

local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(device)


def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""

from modules.shared import opts

if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)

if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=cpu).to(x.device)

return torch.randn_like(x)


def randn_without_seed(shape):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""

from modules.shared import opts

if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)

if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)

return torch.randn(shape, device=device)


def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
from modules.shared import opts

if opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return

torch.manual_seed(seed)


def autocast(disable=False):
from modules import shared

Expand Down
50 changes: 50 additions & 0 deletions modules/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,53 @@ def run(code, task):
code()
except Exception as e:
display(task, e)


def check_versions():
from packaging import version
from modules import shared

import torch
import gradio

expected_torch_version = "2.0.0"
expected_xformers_version = "0.0.20"
expected_gradio_version = "3.39.0"

if version.parse(torch.__version__) < version.parse(expected_torch_version):
print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())

if shared.xformers_available:
import xformers

if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
Use --skip-version-check commandline argument to disable this check.
""".strip())

if gradio.__version__ != expected_gradio_version:
print_error_explanation(f"""
You are running gradio {gradio.__version__}.
The program is designed to work with gradio {expected_gradio_version}.
Using a different version of gradio is extremely likely to break the program.
Reasons why you have the mismatched gradio version can be:
- you use --skip-install flag.
- you use webui.py to start the program instead of launch.py.
- an extension installs the incompatible gradio version.
Use --skip-version-check commandline argument to disable this check.
""".strip())

10 changes: 7 additions & 3 deletions modules/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


def active():
if shared.opts.disable_all_extensions == "all":
if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
return []
elif shared.opts.disable_all_extensions == "extra":
elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
Expand Down Expand Up @@ -141,8 +141,12 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return

if shared.opts.disable_all_extensions == "all":
if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
elif shared.cmd_opts.disable_extra_extensions:
print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

Expand Down
19 changes: 19 additions & 0 deletions modules/extra_networks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import os
import re
from collections import defaultdict

Expand Down Expand Up @@ -177,3 +179,20 @@ def parse_prompts(prompts):

return res, extra_data


def get_user_metadata(filename):
if filename is None:
return {}

basename, ext = os.path.splitext(filename)
metadata_filename = basename + '.json'

metadata = {}
try:
if os.path.isfile(metadata_filename):
with open(metadata_filename, "r", encoding="utf8") as file:
metadata = json.load(file)
except Exception as e:
errors.display(e, f"reading extra network user metadata from {metadata_filename}")

return metadata
3 changes: 3 additions & 0 deletions modules/generation_parameters_copypaste.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ def parse_generation_parameters(x: str):
if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler"

if "Hires checkpoint" not in res:
res["Hires checkpoint"] = "Use same checkpoint"

if "Hires prompt" not in res:
res["Hires prompt"] = ""

Expand Down
60 changes: 60 additions & 0 deletions modules/gradio_extensons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import gradio as gr

from modules import scripts

def add_classes_to_gradio_component(comp):
"""
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""

comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]

if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')


def IOComponent_init(self, *args, **kwargs):
self.webui_tooltip = kwargs.pop('tooltip', None)

if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)

scripts.script_callbacks.before_component_callback(self, **kwargs)

res = original_IOComponent_init(self, *args, **kwargs)

add_classes_to_gradio_component(self)

scripts.script_callbacks.after_component_callback(self, **kwargs)

if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)

return res


def Block_get_config(self):
config = original_Block_get_config(self)

webui_tooltip = getattr(self, 'webui_tooltip', None)
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip

return config


def BlockContext_init(self, *args, **kwargs):
res = original_BlockContext_init(self, *args, **kwargs)

add_classes_to_gradio_component(self)

return res


original_IOComponent_init = gr.components.IOComponent.__init__
original_Block_get_config = gr.blocks.Block.get_config
original_BlockContext_init = gr.blocks.BlockContext.__init__

gr.components.IOComponent.__init__ = IOComponent_init
gr.blocks.Block.get_config = Block_get_config
gr.blocks.BlockContext.__init__ = BlockContext_init
5 changes: 2 additions & 3 deletions modules/hypernetworks/hypernetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
Expand Down Expand Up @@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,


def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
from modules import images, processing

save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
Expand Down
2 changes: 1 addition & 1 deletion modules/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def resize(im, w, h):
return res


invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
Expand Down
6 changes: 2 additions & 4 deletions modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr

from modules import sd_samplers, images as imgutil
Expand Down Expand Up @@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
mask = None
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
Expand Down
Loading

0 comments on commit 22ecb78

Please sign in to comment.