diff --git a/README.md b/README.md index 7fe4a28..20eef66 100644 --- a/README.md +++ b/README.md @@ -117,11 +117,11 @@ How hard ref_img guidance works: - momentum_hist: (float), momentum of memorized difference history - the larger, approving current history - the smaller, approving former histories - - momentum_hist_init: (float), init value of history, aka. the genesis + - momentum_hist_init: (categorical), init value of history, aka. the genesis (experimental) - `zero`: use the first denoised latent - `rand_init`: just use the init latent noise - `rand_new`: use a new guassian noise - - momentum_sign: (categorical), momentum direction to apply correction + - momentum_sign: (categorical), momentum direction to apply correction (experimental) - `pos`: correct by direction of history momentum, affirming the history - `neg`: correct by opposite direction of history momentum, denying the history - `rand`: random choose from above at each sampling step diff --git a/scripts/sonar.py b/scripts/sonar.py index 1fa4e4d..ce9a459 100644 --- a/scripts/sonar.py +++ b/scripts/sonar.py @@ -1,13 +1,11 @@ import random from PIL import Image from PIL.Image import Image as PILImage -from typing import List, Tuple from pprint import pprint as pp import inspect import gradio as gr import torch -from torch import Tensor import numpy as np from tqdm.auto import trange @@ -22,6 +20,9 @@ from k_diffusion.sampling import to_d, get_ancestral_step from ldm.models.diffusion.ddpm import LatentDiffusion +from typing import List, Tuple, Union, Literal +from torch import Tensor + if 'global const': DEFAULT_ENABLE = False DEFAULT_UPSCALE = False @@ -248,26 +249,12 @@ def sample_naive_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal ref_min_step = settings['ref_min_step'] ref_max_step = settings['ref_max_step'] - # memorize delta momentum - if momentum_hist_init == 'zero': history_d = 0 - elif momentum_hist_init == 'rand_init': history_d = x - elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x) - else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}') + # init hist momentum memory + history_d = init_hist_d(momentum_hist_init, x) # prepare ref_img latent if ref_img is not None: - x_ref = torch.from_numpy(np.asarray(ref_img)).moveaxis(2, 0) # [C=3, H, W] - x_ref = (x_ref / 255) * 2 - 1 - x_ref = x_ref.unsqueeze(dim=0).expand(x.shape[0], -1, -1, -1) # [B, C=3, H, W] - x_ref = x_ref.to(sd_model.first_stage_model.device) - - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): - latent_ref = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x_ref)) # [B, C=4, H=64, W=64] - - avg_s = latent_ref.mean(dim=[2, 3], keepdim=True) - std_s = latent_ref.std (dim=[2, 3], keepdim=True) - ref_img_norm = (latent_ref - avg_s) / std_s - ref_img_norm = ref_img_norm.to(x_ref.dtype) + ref_img_norm = init_ref_img(sd_model, ref_img, x.shape[0]) s_in = x.new_ones([x.shape[0]]) n_steps = len(sigmas) - 1 @@ -284,50 +271,13 @@ def sample_naive_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal # momentum step if momentum < 1.0: - # decide correct direction - sign = momentum_sign - if sign == 'rand': sign = random.choice(['pos', 'neg']) - - # correct current `d` with momentum - p = 1.0 - momentum - if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d - elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d - else: raise ValueError(f'unknown momentum sign {sign}') - - # Euler method with momentum - x = x + momentum_d * dt - - # update momentum history - q = 1.0 - momentum_hist - if (isinstance(history_d, int) and history_d == 0): - history_d = momentum_d - else: - if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d - elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d - else: raise ValueError(f'unknown momentum sign {sign}') + x, history_d = momentum_step(x, d, dt, momentum, momentum_hist, momentum_sign, history_d) else: - # Euler method original - x = x + d * dt + x = x + d * dt # Euler method original # guidance step if ref_img is not None and ref_hgf and ref_min_step <= i <= ref_max_step: - # TODO: make scheduling for hgf? - if ref_meth == 'euler': - # rescale `ref_img` to match distribution - avg_t = denoised.mean(dim=[1, 2, 3], keepdim=True) - std_t = denoised.std (dim=[1, 2, 3], keepdim=True) - ref_img_shift = ref_img_norm * std_t + avg_t - - d = to_d(x, sigmas[i], ref_img_shift) - dt = (sigmas[i + 1] - sigmas[i]) * ref_hgf - x = x + d * dt - if ref_meth == 'linear': - # rescale `ref_img` to match distribution - avg_t = x.mean(dim=[1, 2, 3], keepdim=True) - std_t = x.std (dim=[1, 2, 3], keepdim=True) - ref_img_shift = ref_img_norm * std_t + avg_t - - x = (1 - ref_hgf) * x + ref_hgf * ref_img_shift + x = guidance_step(x, ref_meth, ref_img_norm, ref_hgf, denoised, sigmas, i) # noise step alike ancestral if i <= n_steps - 1: @@ -342,10 +292,7 @@ def sample_euler_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal momentum_hist = settings['momentum_hist'] momentum_hist_init = settings['momentum_hist_init'] - if momentum_hist_init == 'zero': history_d = 0 - elif momentum_hist_init == 'rand_init': history_d = x - elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x) - else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}') + history_d = init_hist_d(momentum_hist_init, x) s_in = x.new_ones([x.shape[0]]) n_steps = len(sigmas) - 1 @@ -361,35 +308,10 @@ def sample_euler_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_args={}, cal d = to_d(x, sigma_hat, denoised) dt = sigmas[i + 1] - sigma_hat - if 'momentum step': - # decide correct direction - sign = momentum_sign - action_pool = ['pos', 'neg'] - if sign == 'rand' : sign = random.choice(action_pool) - elif sign == 'pos_neg': sign = action_pool[int(i < n_steps // 2)] - elif sign == 'neg_pos': sign = action_pool[int(i > n_steps // 2)] - else: pass - - # correct current `d` with momentum - p = 1.0 - momentum - if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d - elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d - else: raise ValueError(f'unknown momentum sign {sign}') - - # Euler method with momentum - x = x + momentum_d * dt - - # update momentum history - q = 1.0 - momentum_hist - if (isinstance(history_d, int) and history_d == 0): - history_d = momentum_d - else: - if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d - elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d - else: raise ValueError(f'unknown momentum sign {sign}') + if momentum < 1.0: + x, history_d = momentum_step(x, d, dt, momentum, momentum_hist, momentum_sign, history_d) else: - # Euler method original - x = x + d * dt + x = x + d * dt # Euler method original return x @@ -404,10 +326,7 @@ def sample_euler_ancestral_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_ar show_featmap(x, 'before sample') # 记录梯度历史的惯性 - if momentum_hist_init == 'zero': history_d = 0 - elif momentum_hist_init == 'rand_init': history_d = x - elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x) - else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}') + history_d = init_hist_d(momentum_hist_init, x) s_in = x.new_ones([x.shape[0]]) # [B=1] n_steps = len(sigmas) - 1 @@ -429,35 +348,10 @@ def sample_euler_ancestral_ex(model:CFGDenoiser, x:Tensor, sigmas:List, extra_ar # ancestral scheduling (down) dt = sigma_down - sigmas[i] # scalar, 沿着梯度移动的步长 - if 'momentum step': - # decide correct direction - sign = momentum_sign - action_pool = ['pos', 'neg'] - if sign == 'rand' : sign = random.choice(action_pool) - elif sign == 'pos_neg': sign = action_pool[int(i < n_steps // 2)] - elif sign == 'neg_pos': sign = action_pool[int(i > n_steps // 2)] - else: pass - - # correct current `d` with momentum - p = 1.0 - momentum - if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d - elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d - else: raise ValueError(f'unknown momentum sign {sign}') - - # Euler method with momentum - x = x + momentum_d * dt - - # update momentum history - q = 1.0 - momentum_hist - if (isinstance(history_d, int) and history_d == 0): - history_d = momentum_d - else: - if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d - elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d - else: raise ValueError(f'unknown momentum sign {sign}') + if momentum < 1.0: + x, history_d = momentum_step(x, d, dt, momentum, momentum_hist, momentum_sign, history_d) else: - # Euler method original - x = x + d * dt + x = x + d * dt # Euler method original show_featmap(x, f'x + d x dt (step {i}); sigma_down={sigma_down:.4f}') # 作画内容逐渐显露 # ancestral scheduling (up) @@ -581,6 +475,77 @@ def sample_img2img(self, p:StableDiffusionProcessing, x:Tensor, noise:Tensor, # ↑↑↑ the above is modified from 'modules/sd_samplers.py' ↑↑↑ +def init_hist_d(momentum_hist_init:str, x:Tensor) -> Union[Literal[0], Tensor]: + # memorize delta momentum + if momentum_hist_init == 'zero': history_d = 0 + elif momentum_hist_init == 'rand_init': history_d = x + elif momentum_hist_init == 'rand_new': history_d = torch.randn_like(x) + else: raise ValueError(f'unknown momentum_hist_init: {momentum_hist_init}') + return history_d + +def momentum_step(x:Tensor, d:Tensor, dt:Tensor, momentum:float, momentum_hist:float, momentum_sign:str, history_d): + # decide correct direction + sign = momentum_sign + if sign == 'rand': sign = random.choice(['pos', 'neg']) + + # correct current `d` with momentum + p = 1.0 - momentum + if sign == 'pos': momentum_d = (1.0 - p) * d + p * history_d + elif sign == 'neg': momentum_d = (1.0 + p) * d - p * history_d + else: raise ValueError(f'unknown momentum sign {sign}') + + # Euler method with momentum + x = x + momentum_d * dt + + # update momentum history + q = 1.0 - momentum_hist + if (isinstance(history_d, int) and history_d == 0): + history_d = momentum_d + else: + if sign == 'pos': history_d = (1.0 - q) * history_d + q * momentum_d + elif sign == 'neg': history_d = (1.0 + q) * history_d - q * momentum_d + else: raise ValueError(f'unknown momentum sign {sign}') + + return x, history_d + +def init_ref_img(sd_model:LatentDiffusion, ref_img:PILImage, B:int) -> Tensor: + x_ref = torch.from_numpy(np.asarray(ref_img)).moveaxis(2, 0) # [C=3, H, W] + x_ref = (x_ref / 255) * 2 - 1 + x_ref = x_ref.unsqueeze(dim=0).expand(B, -1, -1, -1) # [B, C=3, H, W] + x_ref = x_ref.to(sd_model.first_stage_model.device) + + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + latent_ref = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x_ref)) # [B, C=4, H=64, W=64] + + avg_s = latent_ref.mean(dim=[2, 3], keepdim=True) + std_s = latent_ref.std (dim=[2, 3], keepdim=True) + ref_img_norm = (latent_ref - avg_s) / std_s + ref_img_norm = ref_img_norm.to(x_ref.dtype) + + return ref_img_norm + +def guidance_step(x:Tensor, ref_meth:str, ref_img_norm:Tensor, ref_hgf:float, denoised:Tensor, sigmas, i): + # TODO: make scheduling for hgf? + if ref_meth == 'euler': + # rescale `ref_img` to match distribution + avg_t = denoised.mean(dim=[1, 2, 3], keepdim=True) + std_t = denoised.std (dim=[1, 2, 3], keepdim=True) + ref_img_shift = ref_img_norm * std_t + avg_t + + d = to_d(x, sigmas[i], ref_img_shift) + dt = (sigmas[i + 1] - sigmas[i]) * ref_hgf + x = x + d * dt + + elif ref_meth == 'linear': + # rescale `ref_img` to match distribution + avg_t = x.mean(dim=[1, 2, 3], keepdim=True) + std_t = x.std (dim=[1, 2, 3], keepdim=True) + ref_img_shift = ref_img_norm * std_t + avg_t + x = (1 - ref_hgf) * x + ref_hgf * ref_img_shift + + return x + + def get_upscale_resolution(p:StableDiffusionProcessing, upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int) -> Tuple[bool, Tuple[int, int]]: if upscale_meth == 'None': return False, (p.width, p.height) @@ -633,14 +598,14 @@ def ui(self, is_img2img): with gr.Group() as tab_momentum: with gr.Row(variant='compact', visible=not DEFAULT_GRID_SEARCH) as tab_m: - momentum = gr.Slider(label='Momentum (current)', minimum=0.75, maximum=1.0, value=lambda: DEFAULT_MOMENTUM) - momentum_hist = gr.Slider(label='Momentum (history)', minimum=0.0, maximum=1.0, value=lambda: DEFAULT_MOMENTUM_HIST) + momentum = gr.Slider(label='Momentum current', minimum=0.75, maximum=1.0, value=lambda: DEFAULT_MOMENTUM) + momentum_hist = gr.Slider(label='Momentum history', minimum=0.0, maximum=1.0, value=lambda: DEFAULT_MOMENTUM_HIST) with gr.Row(variant='compact', visible=DEFAULT_GRID_SEARCH) as tab_m_gs: - momentum_gs = gr.Text(label='Momentum (current) search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_GS) - momentum_hist_gs = gr.Text(label='Momentum (history) search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_HIST_GS) + momentum_gs = gr.Text(label='Momentum current search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_GS) + momentum_hist_gs = gr.Text(label='Momentum history search list', max_lines=1, value=lambda: DEFAULT_MOMENTUM_HIST_GS) with gr.Row(variant='compact'): - momentum_sign = gr.Radio(label='Momentum sign', value=lambda: DEFAULT_MOMENTUM_SIGN, choices=CHOICE_MOMENTUM_SIGN) - momentum_hist_init = gr.Radio(label='Momentum history init', value=lambda: DEFAULT_MOMENTUM_HIST_INIT, choices=CHOICE_MOMENTUM_HIST_INIT) + momentum_sign = gr.Radio(label='Momentum sign (experimental)', value=lambda: DEFAULT_MOMENTUM_SIGN, choices=CHOICE_MOMENTUM_SIGN) + momentum_hist_init = gr.Radio(label='Momentum history init (experimental)', value=lambda: DEFAULT_MOMENTUM_HIST_INIT, choices=CHOICE_MOMENTUM_HIST_INIT) use_grid_search.change(fn=lambda x: [gr_show(not x), gr_show(x)], inputs=use_grid_search, outputs=[tab_m, tab_m_gs]) with gr.Group(visible=False) as tab_file: