Skip to content

Commit

Permalink
fix #8, support switching between AlwaysVisible and Script
Browse files Browse the repository at this point in the history
  • Loading branch information
Kahsolt committed Mar 9, 2023
1 parent 3327ccc commit ff28cfd
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 52 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# meta
.vscode/
__pycache__/

# local config
config.json
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ to get image latents with higher quality (~perhaps!), and just pray again for go

⚪ Features

- 2023/03/09: switch between two morphs (as AlwaysVisible for working with other scripts, as Script for supporting auto grid search)
- 2023/03/08: add grid search (free your hands!!)
- 2023/01/28: add upcale (issue #3)
- 2023/01/12: remove gradient-related functionality due to webui code change
Expand Down
208 changes: 156 additions & 52 deletions scripts/sonar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pathlib import Path
import random
import json
from contextlib import nullcontext
from PIL import Image
from PIL.Image import Image as PILImage
from pprint import pprint as pp
Expand All @@ -23,7 +26,31 @@
from typing import List, Tuple, Union, Literal
from torch import Tensor

# local config
SD_WEBUI_PATH = Path.cwd()
SONAR_PATH = SD_WEBUI_PATH / 'extensions' / 'stable-diffusion-webui-sonar'
CONFIG_FILE = SONAR_PATH / 'config.json'

def load_cfg():
with open(CONFIG_FILE, 'r', encoding='utf-8') as fh:
return json.load(fh)
def save_cfg():
global cfg
with open(CONFIG_FILE, 'w', encoding='utf-8') as fh:
json.dump(cfg, fh, ensure_ascii=False, indent=2)

if CONFIG_FILE.exists():
cfg = load_cfg()
else:
cfg = { 'is_scripts': True } ; save_cfg()

if 'global const':
IS_SCRIPTS: bool = cfg.get('is_scripts', True)
LABEL_MORPH = {
True: 'Move to AlwaysVisible panel at next webui reload',
False: 'Move to Script dropdown at next webui reload',
}

DEFAULT_ENABLE = False
DEFAULT_UPSCALE = False
DEFAULT_GRID_SEARCH = False
Expand Down Expand Up @@ -588,62 +615,91 @@ def describe(self):
return "Wrapped samplers with tricks to optimize prompt condition and image latent for better image quality"

def show(self, is_img2img):
return True
if IS_SCRIPTS: return True
else: return scripts.AlwaysVisible

def ui(self, is_img2img):
with gr.Row(variant='compact').style(equal_height=True):
sampler = gr.Radio(label='Base Sampler', value=lambda: DEFAULT_SAMPLER, choices=CHOICE_SAMPLER)
use_upscale = gr.Checkbox(label='Upscaling', value=lambda: DEFAULT_UPSCALE)
use_grid_search = gr.Checkbox(label='Grid search', value=lambda: DEFAULT_GRID_SEARCH)

with gr.Group() as tab_momentum:
with gr.Row(variant='compact', visible=not DEFAULT_GRID_SEARCH) as tab_m:
momentum = gr.Slider(label='Momentum current', minimum=0.75, maximum=1.0, value=lambda: DEFAULT_MOMENTUM)
momentum_hist = gr.Slider(label='Momentum history', minimum=0.0, maximum=1.0, value=lambda: DEFAULT_MOMENTUM_HIST)
with gr.Row(variant='compact', visible=DEFAULT_GRID_SEARCH) as tab_m_gs:
momentum_gs = gr.Text(label='Momentum current search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_GS)
momentum_hist_gs = gr.Text(label='Momentum history search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_HIST_GS)
with gr.Row(variant='compact'):
momentum_sign = gr.Radio(label='Momentum sign (experimental)', value=lambda: DEFAULT_MOMENTUM_SIGN, choices=CHOICE_MOMENTUM_SIGN)
momentum_hist_init = gr.Radio(label='Momentum history init (experimental)', value=lambda: DEFAULT_MOMENTUM_HIST_INIT, choices=CHOICE_MOMENTUM_HIST_INIT)
use_grid_search.change(fn=lambda x: [gr_show(not x), gr_show(x)], inputs=use_grid_search, outputs=[tab_m, tab_m_gs])

with gr.Group(visible=False) as tab_file:
with gr.Row(variant='compact'):
ref_meth = gr.Radio(label='Ref guide step method', value=lambda: DEFAULT_REF_METH, choices=CHOICE_REF_METH)
ref_hgf = gr.Slider(label='Ref guide factor', value=lambda: DEFAULT_REF_HGF, minimum=0, maximum=1, step=0.01)
ref_min_step = gr.Number(label='Ref start step', value=lambda: DEFAULT_REF_MIN_STEP)
ref_max_step = gr.Number(label='Ref stop step', value=lambda: DEFAULT_REF_MAX_STEP)
with gr.Row(variant='compact'):
ref_img = gr.Image(label='Reference image file', image_mode=None, type='pil')

def swith_sampler(sampler:str):
SHOW_TABS = {
# (show_momt, show_file)
'Euler a': (True, False),
'Euler': (True, False),
'Naive': (True, True),
}
show_momt, show_file = SHOW_TABS[sampler]
with gr.Accordion(label='Sonar sampler', open=False) if not IS_SCRIPTS else nullcontext():
if not IS_SCRIPTS:
is_enable = gr.Checkbox(label='Enable', value=lambda: DEFAULT_ENABLE)

with gr.Row(variant='compact').style(equal_height=True):
sampler = gr.Radio(label='Base Sampler', value=lambda: DEFAULT_SAMPLER, choices=CHOICE_SAMPLER)
use_upscale = gr.Checkbox(label='Upscaling', value=lambda: DEFAULT_UPSCALE)

if IS_SCRIPTS:
use_grid_search = gr.Checkbox(label='Grid search', value=lambda: DEFAULT_GRID_SEARCH)

with gr.Group() as tab_momentum:
with gr.Row(variant='compact', visible=not DEFAULT_GRID_SEARCH) as tab_m:
momentum = gr.Slider(label='Momentum current', minimum=0.75, maximum=1.0, value=lambda: DEFAULT_MOMENTUM)
momentum_hist = gr.Slider(label='Momentum history', minimum=0.0, maximum=1.0, value=lambda: DEFAULT_MOMENTUM_HIST)

if IS_SCRIPTS:
with gr.Row(variant='compact', visible=DEFAULT_GRID_SEARCH) as tab_m_gs:
momentum_gs = gr.Text(label='Momentum current search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_GS)
momentum_hist_gs = gr.Text(label='Momentum history search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_HIST_GS)

use_grid_search.change(fn=lambda x: [gr_show(not x), gr_show(x)], inputs=use_grid_search, outputs=[tab_m, tab_m_gs])

with gr.Row(variant='compact'):
momentum_sign = gr.Radio(label='Momentum sign (experimental)', value=lambda: DEFAULT_MOMENTUM_SIGN, choices=CHOICE_MOMENTUM_SIGN)
momentum_hist_init = gr.Radio(label='Momentum history init (experimental)', value=lambda: DEFAULT_MOMENTUM_HIST_INIT, choices=CHOICE_MOMENTUM_HIST_INIT)

with gr.Group(visible=False) as tab_file:
with gr.Row(variant='compact'):
ref_meth = gr.Radio(label='Ref guide step method', value=lambda: DEFAULT_REF_METH, choices=CHOICE_REF_METH)
ref_hgf = gr.Slider(label='Ref guide factor', value=lambda: DEFAULT_REF_HGF, minimum=0, maximum=1, step=0.01)
ref_min_step = gr.Number(label='Ref start step', value=lambda: DEFAULT_REF_MIN_STEP)
ref_max_step = gr.Number(label='Ref stop step', value=lambda: DEFAULT_REF_MAX_STEP)

with gr.Row(variant='compact'):
ref_img = gr.Image(label='Reference image file', image_mode=None, type='pil')

def swith_sampler(sampler:str):
SHOW_TABS = {
# (show_momt, show_file)
'Euler a': (True, False),
'Euler': (True, False),
'Naive': (True, True),
}
show_momt, show_file = SHOW_TABS[sampler]
return [
gr_show(show_momt),
gr_show(show_file),
]
sampler.change(swith_sampler, inputs=[sampler], outputs=[tab_momentum, tab_file])

with gr.Row(variant='compact', visible=False) as tab_upscale:
upscale_meth = gr.Dropdown(label='Upscaler', value=lambda: DEFAULT_UPSCALE_METH, choices=CHOICE_UPSCALER)
upscale_ratio = gr.Slider (label='Upscale ratio', value=lambda: DEFAULT_UPSCALE_RATIO, minimum=1.0, maximum=4.0, step=0.1)
upscale_width = gr.Slider (label='Upscale width', value=lambda: DEFAULT_UPSCALE_W, minimum=0, maximum=2048, step=8)
upscale_height = gr.Slider (label='Upscale height', value=lambda: DEFAULT_UPSCALE_H, minimum=0, maximum=2048, step=8)
use_upscale.change(fn=lambda x: gr_show(x), inputs=use_upscale, outputs=tab_upscale, show_progress=False)

with gr.Row():
switch_morph = gr.Checkbox(label=LABEL_MORPH[IS_SCRIPTS])
def switch_morph_fn():
global IS_SCRIPTS
IS_SCRIPTS = not IS_SCRIPTS
cfg['is_scripts'] = IS_SCRIPTS ; save_cfg()
switch_morph.change(fn=switch_morph_fn, show_progress=False)

if IS_SCRIPTS:
return [
use_upscale, use_grid_search, sampler,
momentum, momentum_hist, momentum_gs, momentum_hist_gs, momentum_hist_init, momentum_sign,
ref_meth, ref_hgf, ref_min_step, ref_max_step, ref_img,
upscale_meth, upscale_ratio, upscale_width, upscale_height
]
else:
return [
gr_show(show_momt),
gr_show(show_file),
is_enable,
use_upscale, sampler,
momentum, momentum_hist, momentum_hist_init, momentum_sign,
ref_meth, ref_hgf, ref_min_step, ref_max_step, ref_img,
upscale_meth, upscale_ratio, upscale_width, upscale_height
]
sampler.change(swith_sampler, inputs=[sampler], outputs=[tab_momentum, tab_file])

with gr.Row(variant='compact', visible=False) as tab_upscale:
upscale_meth = gr.Dropdown(label='Upscaler', value=lambda: DEFAULT_UPSCALE_METH, choices=CHOICE_UPSCALER)
upscale_ratio = gr.Slider (label='Upscale ratio', value=lambda: DEFAULT_UPSCALE_RATIO, minimum=1.0, maximum=4.0, step=0.1)
upscale_width = gr.Slider (label='Upscale width', value=lambda: DEFAULT_UPSCALE_W, minimum=0, maximum=2048, step=8)
upscale_height = gr.Slider (label='Upscale height', value=lambda: DEFAULT_UPSCALE_H, minimum=0, maximum=2048, step=8)
use_upscale.change(fn=lambda x: gr_show(x), inputs=use_upscale, outputs=tab_upscale, show_progress=False)

return [
use_upscale, use_grid_search, sampler,
momentum, momentum_hist, momentum_gs, momentum_hist_gs, momentum_hist_init, momentum_sign,
ref_meth, ref_hgf, ref_min_step, ref_max_step, ref_img,
upscale_meth, upscale_ratio, upscale_width, upscale_height
]

def run(self, p:StableDiffusionProcessing,
use_upscale:bool, use_grid_search:bool, sampler:str,
Expand Down Expand Up @@ -731,3 +787,51 @@ def save_image_hijack(params:ImageSaveParams):
p.sample = self.sample_saved
remove_callbacks_for_function(save_image_hijack)
return Processed(p, imgs, p.seed, info)

def process(self, p:StableDiffusionProcessing,
is_enable:bool,
use_upscale:bool, sampler:str,
momentum:float, momentum_hist:float, momentum_hist_init:str, momentum_sign:str,
ref_meth:str, ref_hgf:float, ref_min_step:float, ref_max_step:float, ref_img:PILImage,
upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int
):
if not is_enable: return

# type convert
if ref_img is not None: ref_img = resize_image(1, ref_img, p.width, p.height)

# save settings to global
settings['sampler'] = sampler
settings['momentum'] = momentum
settings['momentum_hist'] = momentum_hist
settings['momentum_hist_init'] = momentum_hist_init
settings['momentum_sign'] = momentum_sign
settings['ref_meth'] = ref_meth
settings['ref_min_step'] = int(ref_min_step) if ref_min_step > 1 else round(ref_min_step * p.steps)
settings['ref_max_step'] = int(ref_max_step) if ref_max_step > 1 else round(ref_max_step * p.steps)
settings['ref_hgf'] = ref_hgf / 10.0
settings['ref_img'] = ref_img

#pp(settings)

if use_upscale:
need_upscale, (tgt_w, tgt_h) = get_upscale_resolution(p, upscale_meth, upscale_ratio, upscale_width, upscale_height)
if need_upscale: print(f'>> upscale: ({p.width}, {p.height}) => ({tgt_w}, {tgt_h})')

def save_image_hijack(params:ImageSaveParams):
if use_upscale and need_upscale:
params.image = resize_image(1, params.image, tgt_w, tgt_h, upscaler_name=upscale_meth)
on_before_image_saved(save_image_hijack)
self.save_image_hijack = save_image_hijack

self.sample_saved = p.sample
if isinstance(p, StableDiffusionProcessingTxt2Img):
p.sample = lambda *args, **kwargs: StableDiffusionProcessingTxt2Img_sample(p, *args, **kwargs)
elif isinstance(p, StableDiffusionProcessingImg2Img):
p.sample = lambda *args, **kwargs: StableDiffusionProcessingImg2Img_sample(p, *args, **kwargs)

def postprocess(self, p:StableDiffusionProcessing, processed:Processed, is_enable:bool, *args):
if not is_enable: return

p.sample = self.sample_saved
remove_callbacks_for_function(self.save_image_hijack)

0 comments on commit ff28cfd

Please sign in to comment.