Skip to content

Commit

Permalink
change StableDiffusionProcessing to internally use sampler name inste…
Browse files Browse the repository at this point in the history
…ad of sampler index
  • Loading branch information
AUTOMATIC1111 committed Nov 19, 2022
1 parent d9fd452 commit cdc8020
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 54 deletions.
26 changes: 9 additions & 17 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
Expand All @@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")


sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")

return name

def setUpscalers(req: dict):
reqDict = vars(req)
Expand Down Expand Up @@ -82,14 +86,9 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])

def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)

if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")

populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
Expand All @@ -109,12 +108,6 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)

if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")


init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
Expand All @@ -123,10 +116,9 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
if mask:
mask = decode_base64_to_image(mask)


populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
Expand Down Expand Up @@ -272,7 +264,7 @@ def get_cmd_flags(self):
return vars(shared.cmd_opts)

def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]

def get_upscalers(self):
upscalers = []
Expand Down
4 changes: 2 additions & 2 deletions modules/hypernetworks/hypernetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
Expand Down Expand Up @@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
Expand Down
2 changes: 1 addition & 1 deletion modules/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class FilenameGenerator:
'width': lambda self: self.image.width,
'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
Expand Down
4 changes: 2 additions & 2 deletions modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from PIL import Image, ImageOps, ImageChops

from modules import devices
from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
Expand Down Expand Up @@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index,
sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
Expand Down
29 changes: 12 additions & 17 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import os
import sys
import warnings

import torch
import numpy as np
Expand Down Expand Up @@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):

return image

def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers

class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
if sampler_index is not None:
warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")

self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
Expand All @@ -91,7 +88,7 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_index: int = sampler_index
self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
Expand Down Expand Up @@ -210,8 +207,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
self.info = info
self.width = p.width
self.height = p.height
self.sampler_index = p.sampler_index
self.sampler = sd_samplers.samplers[p.sampler_index].name
self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
Expand Down Expand Up @@ -256,8 +252,7 @@ def js(self):
"subseed_strength": self.subseed_strength,
"width": self.width,
"height": self.height,
"sampler_index": self.sampler_index,
"sampler": self.sampler,
"sampler_name": self.sampler_name,
"cfg_scale": self.cfg_scale,
"steps": self.steps,
"batch_size": self.batch_size,
Expand Down Expand Up @@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

generation_params = {
"Steps": p.steps,
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
"Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
Expand Down Expand Up @@ -645,7 +640,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
Expand Down Expand Up @@ -706,7 +701,7 @@ def save_intermediate(image, index):

shared.state.nextjob()

self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

Expand Down Expand Up @@ -743,7 +738,7 @@ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strengt
self.image_conditioning = None

def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None

if self.image_mask is not None:
Expand Down
13 changes: 10 additions & 3 deletions modules/sd_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,23 @@
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
all_samplers_map = {x.name: x for x in all_samplers}

samplers = []
samplers_for_img2img = []


def create_sampler_with_index(list_of_configs, index, model):
config = list_of_configs[index]
def create_sampler(name, model):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]

assert config is not None, f'bad sampler name: {name}'

sampler = config.constructor(model)
sampler.config = config

return sampler


Expand Down
4 changes: 2 additions & 2 deletions modules/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from PIL import Image, PngImagePlugin

from modules import shared, devices, sd_hijack, processing, sd_models, images
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler

Expand Down Expand Up @@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
Expand Down
3 changes: 2 additions & 1 deletion modules/txt2img.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import modules.scripts
from modules import sd_samplers
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
Expand All @@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index,
sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
Expand Down
2 changes: 1 addition & 1 deletion modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(self, d=None):
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)

writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

# Make Zip
if do_make_zip:
Expand Down
4 changes: 2 additions & 2 deletions scripts/img2imgalt.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def ui(self, is_img2img):
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override
if override_sampler:
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
p.sampler_name = "Euler"
if override_prompt:
p.prompt = original_prompt
p.negative_prompt = original_negative_prompt
Expand Down Expand Up @@ -191,7 +191,7 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs

combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

sigmas = sampler.model_wrap.get_sigmas(p.steps)

Expand Down
12 changes: 6 additions & 6 deletions scripts/xy_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import modules.scripts as scripts
import gradio as gr

from modules import images
from modules import images, sd_samplers
from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers
Expand Down Expand Up @@ -60,25 +60,25 @@ def apply_order(p, x, xs):
p.prompt = prompt_tmp + p.prompt


def build_samplers_dict(p):
def build_samplers_dict():
samplers_dict = {}
for i, sampler in enumerate(get_correct_sampler(p)):
for i, sampler in enumerate(sd_samplers.all_samplers):
samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases:
samplers_dict[alias.lower()] = i
return samplers_dict


def apply_sampler(p, x, xs):
sampler_index = build_samplers_dict(p).get(x.lower(), None)
sampler_index = build_samplers_dict().get(x.lower(), None)
if sampler_index is None:
raise RuntimeError(f"Unknown sampler: {x}")

p.sampler_index = sampler_index

This comment has been minimized.

Copy link
@djdookie

djdookie Nov 19, 2022

This led to a new bug report #4860



def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p)
samplers_dict = build_samplers_dict()
for x in xs:
if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}")
Expand Down

0 comments on commit cdc8020

Please sign in to comment.