Skip to content

Commit

Permalink
code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Kahsolt committed Mar 9, 2023
1 parent 2af6d23 commit 3327ccc
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 131 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ How hard ref_img guidance works:
- momentum_hist: (float), momentum of memorized difference history
- the larger, approving current history
- the smaller, approving former histories
- momentum_hist_init: (float), init value of history, aka. the genesis
- momentum_hist_init: (categorical), init value of history, aka. the genesis (experimental)
- `zero`: use the first denoised latent
- `rand_init`: just use the init latent noise
- `rand_new`: use a new gaussian noise
- momentum_sign: (categorical), momentum direction to apply correction
- momentum_sign: (categorical), momentum direction to apply correction (experimental)
- `pos`: correct by direction of history momentum, affirming the history
- `neg`: correct by opposite direction of history momentum, denying the history
- `rand`: randomly choose from the above at each sampling step
Expand Down
223 changes: 94 additions & 129 deletions scripts/sonar.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import random
from PIL import Image
from PIL.Image import Image as PILImage
from typing import List, Tuple
from pprint import pprint as pp
import inspect

import gradio as gr
import torch
from torch import Tensor
import numpy as np
from tqdm.auto import trange

Expand All @@ -22,6 +20,9 @@
from k_diffusion.sampling import to_d, get_ancestral_step
from ldm.models.diffusion.ddpm import LatentDiffusion

from typing import List, Tuple, Union, Literal
from torch import Tensor

if 'global const':
DEFAULT_ENABLE = False
DEFAULT_UPSCALE = False
Expand Down Expand Up @@ -248,26 +249,12 @@ def sample_naive_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal
ref_min_step = settings['ref_min_step']
ref_max_step = settings['ref_max_step']

# memorize delta momentum
if momentum_hist_init == 'zero': history_d = 0
elif momentum_hist_init == 'rand_init': history_d = x
elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x)
else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}')
# init hist momentum memory
history_d = init_hist_d(momentum_hist_init, x)

# prepare ref_img latent
if ref_img is not None:
x_ref = torch.from_numpy(np.asarray(ref_img)).moveaxis(2, 0) # [C=3, H, W]
x_ref = (x_ref / 255) * 2 - 1
x_ref = x_ref.unsqueeze(dim=0).expand(x.shape[0], -1, -1, -1) # [B, C=3, H, W]
x_ref = x_ref.to(sd_model.first_stage_model.device)

with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
latent_ref = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x_ref)) # [B, C=4, H=64, W=64]

avg_s = latent_ref.mean(dim=[2, 3], keepdim=True)
std_s = latent_ref.std (dim=[2, 3], keepdim=True)
ref_img_norm = (latent_ref - avg_s) / std_s
ref_img_norm = ref_img_norm.to(x_ref.dtype)
ref_img_norm = init_ref_img(sd_model, ref_img, x.shape[0])

s_in = x.new_ones([x.shape[0]])
n_steps = len(sigmas) - 1
Expand All @@ -284,50 +271,13 @@ def sample_naive_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal

# momentum step
if momentum < 1.0:
# decide correct direction
sign = momentum_sign
if sign == 'rand': sign = random.choice(['pos', 'neg'])

# correct current `d` with momentum
p = 1.0 - momentum
if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d
elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d
else: raise ValueError(f'unknown momentum sign {sign}')

# Euler method with momentum
x = x + momentum_d * dt

# update momentum history
q = 1.0 - momentum_hist
if (isinstance(history_d, int) and history_d == 0):
history_d = momentum_d
else:
if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d
elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d
else: raise ValueError(f'unknown momentum sign {sign}')
x, history_d = momentum_step(x, d, dt, momentum, momentum_hist, momentum_sign, history_d)
else:
# Euler method original
x = x + d * dt
x = x + d * dt # Euler method original

# guidance step
if ref_img is not None and ref_hgf and ref_min_step <= i <= ref_max_step:
# TODO: make scheduling for hgf?
if ref_meth == 'euler':
# rescale `ref_img` to match distribution
avg_t = denoised.mean(dim=[1, 2, 3], keepdim=True)
std_t = denoised.std (dim=[1, 2, 3], keepdim=True)
ref_img_shift = ref_img_norm * std_t + avg_t

d = to_d(x, sigmas[i], ref_img_shift)
dt = (sigmas[i + 1] - sigmas[i]) * ref_hgf
x = x + d * dt
if ref_meth == 'linear':
# rescale `ref_img` to match distribution
avg_t = x.mean(dim=[1, 2, 3], keepdim=True)
std_t = x.std (dim=[1, 2, 3], keepdim=True)
ref_img_shift = ref_img_norm * std_t + avg_t

x = (1 - ref_hgf) * x + ref_hgf * ref_img_shift
x = guidance_step(x, ref_meth, ref_img_norm, ref_hgf, denoised, sigmas, i)

# noise step alike ancestral
if i <= n_steps - 1:
Expand All @@ -342,10 +292,7 @@ def sample_euler_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal
momentum_hist = settings['momentum_hist']
momentum_hist_init = settings['momentum_hist_init']

if momentum_hist_init == 'zero': history_d = 0
elif momentum_hist_init == 'rand_init': history_d = x
elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x)
else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}')
history_d = init_hist_d(momentum_hist_init, x)

s_in = x.new_ones([x.shape[0]])
n_steps = len(sigmas) - 1
Expand All @@ -361,35 +308,10 @@ def sample_euler_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat

if 'momentum step':
# decide correct direction
sign = momentum_sign
action_pool = ['pos', 'neg']
if sign == 'rand' : sign = random.choice(action_pool)
elif sign == 'pos_neg': sign = action_pool[int(i < n_steps // 2)]
elif sign == 'neg_pos': sign = action_pool[int(i > n_steps // 2)]
else: pass

# correct current `d` with momentum
p = 1.0 - momentum
if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d
elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d
else: raise ValueError(f'unknown momentum sign {sign}')

# Euler method with momentum
x = x + momentum_d * dt

# update momentum history
q = 1.0 - momentum_hist
if (isinstance(history_d, int) and history_d == 0):
history_d = momentum_d
else:
if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d
elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d
else: raise ValueError(f'unknown momentum sign {sign}')
if momentum < 1.0:
x, history_d = momentum_step(x, d, dt, momentum, momentum_hist, momentum_sign, history_d)
else:
# Euler method original
x = x + d * dt
x = x + d * dt # Euler method original

return x

Expand All @@ -404,10 +326,7 @@ def sample_euler_ancestral_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_ar
show_featmap(x, 'before sample')

# 记录梯度历史的惯性
if momentum_hist_init == 'zero': history_d = 0
elif momentum_hist_init == 'rand_init': history_d = x
elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x)
else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}')
history_d = init_hist_d(momentum_hist_init, x)

s_in = x.new_ones([x.shape[0]]) # [B=1]
n_steps = len(sigmas) - 1
Expand All @@ -429,35 +348,10 @@ def sample_euler_ancestral_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_ar

# ancestral scheduling (down)
dt = sigma_down - sigmas[i] # scalar, 沿着梯度移动的步长
if 'momentum step':
# decide correct direction
sign = momentum_sign
action_pool = ['pos', 'neg']
if sign == 'rand' : sign = random.choice(action_pool)
elif sign == 'pos_neg': sign = action_pool[int(i < n_steps // 2)]
elif sign == 'neg_pos': sign = action_pool[int(i > n_steps // 2)]
else: pass

# correct current `d` with momentum
p = 1.0 - momentum
if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d
elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d
else: raise ValueError(f'unknown momentum sign {sign}')

# Euler method with momentum
x = x + momentum_d * dt

# update momentum history
q = 1.0 - momentum_hist
if (isinstance(history_d, int) and history_d == 0):
history_d = momentum_d
else:
if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d
elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d
else: raise ValueError(f'unknown momentum sign {sign}')
if momentum < 1.0:
x, history_d = momentum_step(x, d, dt, momentum, momentum_hist, momentum_sign, history_d)
else:
# Euler method original
x = x + d * dt
x = x + d * dt # Euler method original
show_featmap(x, f'x + d x dt (step {i}); sigma_down={sigma_down:.4f}') # 作画内容逐渐显露

# ancestral scheduling (up)
Expand Down Expand Up @@ -581,6 +475,77 @@ def sample_img2img(self, p:StableDiffusionProcessing, x:Tensor, noise:Tensor,
# ↑↑↑ the above is modified from 'modules/sd_samplers.py' ↑↑↑


def init_hist_d(momentum_hist_init:str, x:Tensor) -> Union[Literal[0], Tensor]:
    """Produce the starting value of the momentum history.

    Args:
        momentum_hist_init: initialisation mode; one of 'zero',
            'rand_init' or 'rand_new'.
        x: latent tensor handed to the sampler.

    Returns:
        The integer ``0`` for 'zero', ``x`` itself (same object, no copy)
        for 'rand_init', or a new ``torch.randn_like(x)`` sample for
        'rand_new'.

    Raises:
        ValueError: if ``momentum_hist_init`` is none of the above.
    """
    if momentum_hist_init == 'zero':
        # Integer sentinel; `momentum_step` detects it with an isinstance check.
        return 0
    if momentum_hist_init == 'rand_init':
        return x
    if momentum_hist_init == 'rand_new':
        # Only this mode draws from the torch RNG.
        return torch.randn_like(x)
    raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}')

def momentum_step(x:Tensor, d:Tensor, dt:Tensor, momentum:float, momentum_hist:float, momentum_sign:str, history_d):
    """Advance `x` by one Euler step whose direction is blended with a history term.

    Args:
        x: current latent.
        d: derivative computed for the current step.
        dt: step size multiplied onto the blended derivative.
        momentum: weight kept on the current `d`; `1.0 - momentum` is the
            weight given to `history_d`.
        momentum_hist: weight kept on the old history when it is refreshed;
            `1.0 - momentum_hist` is the weight given to the new blended
            derivative.
        momentum_sign: 'pos', 'neg', or 'rand' (picks 'pos'/'neg' at random
            on every call).
        history_d: running history; either the integer 0 (not yet
            populated) or a value that broadcasts against `d`.

    Returns:
        Tuple `(x, history_d)` with the updated latent and updated history.

    Raises:
        ValueError: if the resolved sign is neither 'pos' nor 'neg'.
    """
    # Resolve the direction once; everything below branches on it.
    direction = random.choice(['pos', 'neg']) if momentum_sign == 'rand' else momentum_sign
    if direction != 'pos' and direction != 'neg':
        raise ValueError(f'unknown momentum sign {direction}')
    affirm = direction == 'pos'

    # Blend the current derivative with the history ('pos' moves towards
    # the history, 'neg' moves away from it).
    p = 1.0 - momentum
    if affirm:
        momentum_d = (1.0 - p) * d + p * history_d
    else:
        momentum_d = (1.0 + p) * d - p * history_d

    # Euler update using the blended derivative.
    x = x + momentum_d * dt

    # An integer-zero history means "not yet populated": adopt the first
    # blended derivative as the history verbatim.
    if isinstance(history_d, int) and history_d == 0:
        return x, momentum_d

    # Otherwise fold the blended derivative into the running history.
    q = 1.0 - momentum_hist
    if affirm:
        updated_hist = (1.0 - q) * history_d + q * momentum_d
    else:
        updated_hist = (1.0 + q) * history_d - q * momentum_d
    return x, updated_hist

def init_ref_img(sd_model:LatentDiffusion, ref_img:PILImage, B:int) -> Tensor:
    """Encode a reference image to a latent and normalise it per channel.

    Args:
        sd_model: model providing `first_stage_model`, `encode_first_stage`
            and `get_first_stage_encoding`.
        ref_img: reference image; it is indexed as [H, W, C], so a
            channel-last layout is assumed (presumably RGB, confirm at the
            call site).
        B: batch size the image is expanded to.

    Returns:
        The encoded latent with its per-sample, per-channel spatial mean
        removed and divided by the spatial standard deviation, cast to the
        dtype of the pixel tensor.
    """
    # Pixels: [H, W, C] -> [C, H, W], then rescaled from 0..255 to -1..1.
    pixels = torch.from_numpy(np.asarray(ref_img)).moveaxis(2, 0)
    pixels = (pixels / 255) * 2 - 1
    # Add a batch axis and broadcast it to B samples: [B, C, H, W].
    pixels = pixels.unsqueeze(dim=0).expand(B, -1, -1, -1)
    pixels = pixels.to(sd_model.first_stage_model.device)

    # Pick the autocast policy from the project's `devices` module.
    cast_ctx = devices.without_autocast() if devices.unet_needs_upcast else devices.autocast()
    with cast_ctx:
        encoded = sd_model.encode_first_stage(pixels)
        latent = sd_model.get_first_stage_encoding(encoded)    # [B, C=4, H, W] per the original note

    # Standardise over the two trailing (spatial) axes.
    spatial = [2, 3]
    mean_s = latent.mean(dim=spatial, keepdim=True)
    std_s = latent.std(dim=spatial, keepdim=True)
    normalised = (latent - mean_s) / std_s

    return normalised.to(pixels.dtype)

def guidance_step(x:Tensor, ref_meth:str, ref_img_norm:Tensor, ref_hgf:float, denoised:Tensor, sigmas, i):
    """Pull the latent `x` towards a normalised reference latent.

    Args:
        x: current latent.
        ref_meth: 'euler' or 'linear'; any other value leaves `x` untouched.
        ref_img_norm: normalised reference latent.
        ref_hgf: guidance strength; scales the step in 'euler' mode and is
            the interpolation weight in 'linear' mode.
        denoised: denoised prediction of the current step (only read in
            'euler' mode).
        sigmas: noise schedule, indexed at `i` and `i + 1` (only read in
            'euler' mode).
        i: index of the current sampling step.

    Returns:
        The guided latent.
    """
    # TODO: make scheduling for hgf?
    # The reference is rescaled to the statistics of `denoised` in 'euler'
    # mode and of `x` itself in 'linear' mode.
    if ref_meth == 'euler':
        stats_src = denoised
    elif ref_meth == 'linear':
        stats_src = x
    else:
        return x

    # Rescale `ref_img` so it matches the per-sample mean/std of the source.
    reduce_dims = [1, 2, 3]
    avg_t = stats_src.mean(dim=reduce_dims, keepdim=True)
    std_t = stats_src.std(dim=reduce_dims, keepdim=True)
    ref_img_shift = ref_img_norm * std_t + avg_t

    if ref_meth == 'linear':
        # Plain linear interpolation between `x` and the shifted reference.
        return (1 - ref_hgf) * x + ref_hgf * ref_img_shift

    # 'euler': one extra Euler step that treats the shifted reference as
    # the denoised target, with the step size scaled by `ref_hgf`.
    d = to_d(x, sigmas[i], ref_img_shift)
    dt = (sigmas[i + 1] - sigmas[i]) * ref_hgf
    return x + d * dt


def get_upscale_resolution(p:StableDiffusionProcessing, upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int) -> Tuple[bool, Tuple[int, int]]:
if upscale_meth == 'None':
return False, (p.width, p.height)
Expand Down Expand Up @@ -633,14 +598,14 @@ def ui(self, is_img2img):

with gr.Group() as tab_momentum:
with gr.Row(variant='compact', visible=not DEFAULT_GRID_SEARCH) as tab_m:
momentum = gr.Slider(label='Momentum (current)', minimum=0.75, maximum=1.0, value=lambda: DEFAULT_MOMENTUM)
momentum_hist = gr.Slider(label='Momentum (history)', minimum=0.0, maximum=1.0, value=lambda: DEFAULT_MOMENTUM_HIST)
momentum = gr.Slider(label='Momentum current', minimum=0.75, maximum=1.0, value=lambda: DEFAULT_MOMENTUM)
momentum_hist = gr.Slider(label='Momentum history', minimum=0.0, maximum=1.0, value=lambda: DEFAULT_MOMENTUM_HIST)
with gr.Row(variant='compact', visible=DEFAULT_GRID_SEARCH) as tab_m_gs:
momentum_gs = gr.Text(label='Momentum (current) search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_GS)
momentum_hist_gs = gr.Text(label='Momentum (history) search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_HIST_GS)
momentum_gs = gr.Text(label='Momentum current search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_GS)
momentum_hist_gs = gr.Text(label='Momentum history search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_HIST_GS)
with gr.Row(variant='compact'):
momentum_sign = gr.Radio(label='Momentum sign', value=lambda: DEFAULT_MOMENTUM_SIGN, choices=CHOICE_MOMENTUM_SIGN)
momentum_hist_init = gr.Radio(label='Momentum history init', value=lambda: DEFAULT_MOMENTUM_HIST_INIT, choices=CHOICE_MOMENTUM_HIST_INIT)
momentum_sign = gr.Radio(label='Momentum sign (experimental)', value=lambda: DEFAULT_MOMENTUM_SIGN, choices=CHOICE_MOMENTUM_SIGN)
momentum_hist_init = gr.Radio(label='Momentum history init (experimental)', value=lambda: DEFAULT_MOMENTUM_HIST_INIT, choices=CHOICE_MOMENTUM_HIST_INIT)
use_grid_search.change(fn=lambda x: [gr_show(not x), gr_show(x)], inputs=use_grid_search, outputs=[tab_m, tab_m_gs])

with gr.Group(visible=False) as tab_file:
Expand Down

0 comments on commit 3327ccc

Please sign in to comment.