From cbef2c1422a642f0c1835e853eb4b0ddd14005da Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 3 Jun 2023 11:03:54 -0400 Subject: [PATCH 1/3] life --- lib_neutral_prompt/cfg_denoiser_hijack.py | 42 +++++++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index 4968f6a..520c15b 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -1,3 +1,4 @@ +import torchvision from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser from modules import script_callbacks, sd_samplers, shared from typing import Tuple, List @@ -213,24 +214,51 @@ def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> t def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor: """ - Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights. - Salience maps are calculated to identify regions of interest. - The blended result combines `normal` and vector information in salient regions. + Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights. + Salience maps are calculated to identify regions of interest. + The blended result combines `normal` and vector information in salient regions. """ - salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, weight in vectors] + salience_maps = [get_salience(normal)] + [get_salience(vector, specific=False) for vector, weight in vectors] mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0) result = torch.zeros_like(normal) for mask_i, (vector, weight) in enumerate(vectors, start=1): - vector_mask = ((mask == mask_i).float()) + vector_mask = (mask == mask_i).float() + + blur = torchvision.transforms.GaussianBlur(3, 1.) + + # vector_mask = blur(life(vector_mask)) + # vector_mask = life(vector_mask) + vector_mask = life(vector_mask, lambda board, neighbors: ((board == 1) & (neighbors >= 3)).float()) + result += weight * vector_mask * (vector - normal) return result -def get_salience(vector: torch.Tensor) -> torch.Tensor: - return torch.softmax(torch.abs(vector).flatten(), dim=0).reshape_as(vector) +def life(board: torch.Tensor, rules = lambda board, neighbors: ((neighbors == 3) | ((board == 1) & (neighbors >= 3) & (neighbors <= 4))).float()): + # todo: maybe try to increase kernel size? + kernel = torch.tensor( + [[1, 1, 1], + [1, 0, 1], + [1, 1, 1]], + dtype=board.dtype, + device=board.device, + ) + neighbors = torch.nn.functional.conv2d( + board.unsqueeze(0), + kernel.repeat(board.size(0), 1, 1).unsqueeze(0), + padding=1 + ).squeeze(0) + return rules(board, neighbors) + + +def get_salience(vector: torch.Tensor, specific: bool = False) -> torch.Tensor: + k = 1 + if specific: + k = 50 + return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector) sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get( From 8848e62d9c6e1f798acdfae1189f90174e00f55c Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 3 Jun 2023 12:58:22 -0400 Subject: [PATCH 2/3] 3d conv --- lib_neutral_prompt/cfg_denoiser_hijack.py | 31 +++++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index 520c15b..8501584 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -219,7 +219,7 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float] The blended result combines `normal` and vector information in salient regions. """ - salience_maps = [get_salience(normal)] + [get_salience(vector, specific=False) for vector, weight in vectors] + salience_maps = [get_salience(normal)] + [get_salience(vector, specific=True) for vector, weight in vectors] mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0) result = torch.zeros_like(normal) @@ -230,34 +230,37 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float] # vector_mask = blur(life(vector_mask)) # vector_mask = life(vector_mask) - vector_mask = life(vector_mask, lambda board, neighbors: ((board == 1) & (neighbors >= 3)).float()) + vector_mask = life(vector_mask, lambda board, neighbors: (board == 1) & (neighbors >= board.size(0) * 9 / 2)) result += weight * vector_mask * (vector - normal) return result -def life(board: torch.Tensor, rules = lambda board, neighbors: ((neighbors == 3) | ((board == 1) & (neighbors >= 3) & (neighbors <= 4))).float()): - # todo: maybe try to increase kernel size? +def life(board: torch.Tensor, rules = None): + if rules is None: + rules = lambda board, neighbors: (neighbors == board.size(0) * 3) | ((board == 1) & (neighbors >= board.size(0) * 3) & (neighbors <= board.size(0) * 4)) kernel = torch.tensor( - [[1, 1, 1], - [1, 0, 1], - [1, 1, 1]], + [[[1, 1, 1], + [1, 1, 1], + [1, 1, 1]]] * board.size(0), dtype=board.dtype, device=board.device, ) - neighbors = torch.nn.functional.conv2d( - board.unsqueeze(0), - kernel.repeat(board.size(0), 1, 1).unsqueeze(0), - padding=1 - ).squeeze(0) - return rules(board, neighbors) + padded_board = torch.concatenate([board.clone(), board[:-1].clone()], dim=0) + padded_board = torch.nn.functional.pad(padded_board, (1, 1, 1, 1, 0, 0), mode='constant', value=0) + neighbors = torch.nn.functional.conv3d( + padded_board.unsqueeze(0).unsqueeze(0), + kernel.unsqueeze(0).unsqueeze(0), + padding=0, + ).squeeze(0).squeeze(0) + return rules(board, neighbors - board).float() def get_salience(vector: torch.Tensor, specific: bool = False) -> torch.Tensor: k = 1 if specific: - k = 50 + k = 30 return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector) From c0566907d4d215357cb0425acc394c1cc3688d76 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 4 Jun 2023 04:12:34 -0400 Subject: [PATCH 3/3] backup --- lib_neutral_prompt/cfg_denoiser_hijack.py | 27 +++++- lib_neutral_prompt/dev.py | 110 ++++++++++++++++++++++ 2 files changed, 133 insertions(+), 4 deletions(-) create mode 100644 lib_neutral_prompt/dev.py diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index 8501584..db530bc 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -1,10 +1,10 @@ -import torchvision from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser from modules import script_callbacks, sd_samplers, shared from typing import Tuple, List import dataclasses import functools import torch +import torchvision import sys import textwrap @@ -198,7 +198,14 @@ def visit_composite_prompt( index += child.accept(neutral_prompt_parser.FlatSizeVisitor()) - aux_cond_delta += salient_blend(cond_delta, salient_cond_deltas) + from lib_neutral_prompt import dev + dev.reload(dev) + aux_cond_delta += dev.salient_blend(cond_delta, salient_cond_deltas) + + if isinstance(len(that.children) >= 2 and that.children[1], neutral_prompt_parser.CompositePrompt): + if len(dev.mask_images) == shared.state.sampling_steps: + dev.mask_images[0].save(rf'D:\sd\images-test\saliency-masks\{uuid.uuid4()}.gif', save_all=True, append_images=dev.mask_images[1:], duration=100, loop=0) + return aux_cond_delta @@ -212,6 +219,10 @@ def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> t return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2 +mask_images = [] +import uuid + + def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor: """ Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights. @@ -230,8 +241,16 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float] # vector_mask = blur(life(vector_mask)) # vector_mask = life(vector_mask) - vector_mask = life(vector_mask, lambda board, neighbors: (board == 1) & (neighbors >= board.size(0) * 9 / 2)) + for _ in range(2): + vector_mask = life(vector_mask, lambda board, neighbors: (board == 1) & (neighbors >= board.size(0) * 6)) + + display_mask = vector_mask[:3] + vector_mask[3] / 3 + display_mask = torch.nn.functional.interpolate(display_mask.unsqueeze(0), scale_factor=8, mode='nearest-exact').squeeze(0) + + if shared.state.sampling_step == 0 and len(mask_images) >= 2: + mask_images.clear() + mask_images.append(torchvision.transforms.functional.to_pil_image(display_mask)) result += weight * vector_mask * (vector - normal) return result @@ -260,7 +279,7 @@ def life(board: torch.Tensor, rules = None): def get_salience(vector: torch.Tensor, specific: bool = False) -> torch.Tensor: k = 1 if specific: - k = 30 + k = 20 return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector) diff --git a/lib_neutral_prompt/dev.py b/lib_neutral_prompt/dev.py new file mode 100644 index 0000000..4d1473a --- /dev/null +++ b/lib_neutral_prompt/dev.py @@ -0,0 +1,110 @@ +from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser +from modules import shared +import importlib +import torch +import torchvision +import sys +import textwrap +from typing import List, Tuple + + +mask_images = [] + + +def reload(self): + global mask_images + images = mask_images + importlib.reload(self) + self.mask_images = images + + +def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor: + if (normal == 0).all(): + if shared.state.sampling_step <= 0: + warn_projection_not_found() + + return vector + + return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2 + + +def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor: + """ + Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights. + Salience maps are calculated to identify regions of interest. + The blended result combines `normal` and vector information in salient regions. + """ + + salience_maps = [get_salience(normal)] + [get_salience(vector, specific=True) for vector, weight in vectors] + mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0) + + result = torch.zeros_like(normal) + for mask_i, (vector, weight) in enumerate(vectors, start=1): + vector_mask = (mask == mask_i).float() + + blur = torchvision.transforms.GaussianBlur(3, 1.) + + # vector_mask = blur(life(vector_mask)) + # vector_mask = life(vector_mask) + for _ in range(6): + vector_mask = life(vector_mask, lambda board, neighbors: (board == 1) & (neighbors >= board.size(0) * 5)) + + for _ in range(2): + vector_mask = life(vector_mask, thickify_rules) + + display_mask = vector_mask[:3] * 2/3 + vector_mask[3] / 3 + display_mask = torch.nn.functional.interpolate(display_mask.unsqueeze(0), scale_factor=8, mode='nearest-exact').squeeze(0) + + if shared.state.sampling_step == 0 and len(mask_images) >= 2: + mask_images.clear() + + mask_images.append(torchvision.transforms.functional.to_pil_image(display_mask)) + result += weight * vector_mask * (vector - normal) + + return result + + +def thickify_rules(board, neighbors): + population = board + neighbors + return (board == 1) | (population >= 4) + + +def life(board: torch.Tensor, rules = None): + if rules is None: + rules = lambda board, neighbors: (neighbors == board.size(0) * 3) | ((board == 1) & (neighbors >= board.size(0) * 3) & (neighbors <= board.size(0) * 4)) + kernel = torch.tensor( + [[[1, 1, 1], + [1, 1, 1], + [1, 1, 1]]] * board.size(0), + dtype=board.dtype, + device=board.device, + ) + padded_board = torch.concatenate([board.clone(), board[:-1].clone()], dim=0) + padded_board = torch.nn.functional.pad(padded_board, (1, 1, 1, 1, 0, 0), mode='constant', value=0) + neighbors = torch.nn.functional.conv3d( + padded_board.unsqueeze(0).unsqueeze(0), + kernel.unsqueeze(0).unsqueeze(0), + padding=0, + ).squeeze(0).squeeze(0) + return rules(board, neighbors - board).float() + + +def get_salience(vector: torch.Tensor, specific: bool = False) -> torch.Tensor: + k = 1 + if specific: + k = 20 + return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector) + + +def warn_projection_not_found(): + console_warn(''' + Could not find a projection for one or more AND_PERP prompts + These prompts will NOT be made perpendicular + ''') + + +def console_warn(message): + if not global_state.verbose: + return + + print(f'\n[sd-webui-neutral-prompt extension]{textwrap.dedent(message)}', file=sys.stderr)