Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Game of life filter for salience maps #16

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dataclasses
import functools
import torch
import torchvision
import sys
import textwrap

Expand Down Expand Up @@ -197,7 +198,14 @@ def visit_composite_prompt(

index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

aux_cond_delta += salient_blend(cond_delta, salient_cond_deltas)
from lib_neutral_prompt import dev
dev.reload(dev)
aux_cond_delta += dev.salient_blend(cond_delta, salient_cond_deltas)

if isinstance(len(that.children) >= 2 and that.children[1], neutral_prompt_parser.CompositePrompt):
if len(dev.mask_images) == shared.state.sampling_steps:
dev.mask_images[0].save(rf'D:\sd\images-test\saliency-masks\{uuid.uuid4()}.gif', save_all=True, append_images=dev.mask_images[1:], duration=100, loop=0)

return aux_cond_delta


Expand All @@ -211,26 +219,68 @@ def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> t
return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2


mask_images = []
import uuid


def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor:
"""
Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights.
Salience maps are calculated to identify regions of interest.
The blended result combines `normal` and vector information in salient regions.
Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights.
Salience maps are calculated to identify regions of interest.
The blended result combines `normal` and vector information in salient regions.
"""

salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, weight in vectors]
salience_maps = [get_salience(normal)] + [get_salience(vector, specific=True) for vector, weight in vectors]
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)

result = torch.zeros_like(normal)
for mask_i, (vector, weight) in enumerate(vectors, start=1):
vector_mask = ((mask == mask_i).float())
vector_mask = (mask == mask_i).float()

blur = torchvision.transforms.GaussianBlur(3, 1.)

# vector_mask = blur(life(vector_mask))
# vector_mask = life(vector_mask)
for _ in range(2):
vector_mask = life(vector_mask, lambda board, neighbors: (board == 1) & (neighbors >= board.size(0) * 6))

display_mask = vector_mask[:3] + vector_mask[3] / 3
display_mask = torch.nn.functional.interpolate(display_mask.unsqueeze(0), scale_factor=8, mode='nearest-exact').squeeze(0)

if shared.state.sampling_step == 0 and len(mask_images) >= 2:
mask_images.clear()

mask_images.append(torchvision.transforms.functional.to_pil_image(display_mask))
result += weight * vector_mask * (vector - normal)

return result


def get_salience(vector: torch.Tensor) -> torch.Tensor:
return torch.softmax(torch.abs(vector).flatten(), dim=0).reshape_as(vector)
def life(board: torch.Tensor, rules = None):
if rules is None:
rules = lambda board, neighbors: (neighbors == board.size(0) * 3) | ((board == 1) & (neighbors >= board.size(0) * 3) & (neighbors <= board.size(0) * 4))
kernel = torch.tensor(
[[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]]] * board.size(0),
dtype=board.dtype,
device=board.device,
)
padded_board = torch.concatenate([board.clone(), board[:-1].clone()], dim=0)
padded_board = torch.nn.functional.pad(padded_board, (1, 1, 1, 1, 0, 0), mode='constant', value=0)
neighbors = torch.nn.functional.conv3d(
padded_board.unsqueeze(0).unsqueeze(0),
kernel.unsqueeze(0).unsqueeze(0),
padding=0,
).squeeze(0).squeeze(0)
return rules(board, neighbors - board).float()


def get_salience(vector: torch.Tensor, specific: bool = False) -> torch.Tensor:
k = 1
if specific:
k = 20
return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)


sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
Expand Down
110 changes: 110 additions & 0 deletions lib_neutral_prompt/dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
from modules import shared
import importlib
import torch
import torchvision
import sys
import textwrap
from typing import List, Tuple


mask_images = []


def reload(self):
global mask_images
images = mask_images
importlib.reload(self)
self.mask_images = images


def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
if (normal == 0).all():
if shared.state.sampling_step <= 0:
warn_projection_not_found()

return vector

return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2


def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor:
"""
Blends the `normal` tensor with `vectors` in salient regions, weighting contributions by their weights.
Salience maps are calculated to identify regions of interest.
The blended result combines `normal` and vector information in salient regions.
"""

salience_maps = [get_salience(normal)] + [get_salience(vector, specific=True) for vector, weight in vectors]
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)

result = torch.zeros_like(normal)
for mask_i, (vector, weight) in enumerate(vectors, start=1):
vector_mask = (mask == mask_i).float()

blur = torchvision.transforms.GaussianBlur(3, 1.)

# vector_mask = blur(life(vector_mask))
# vector_mask = life(vector_mask)
for _ in range(6):
vector_mask = life(vector_mask, lambda board, neighbors: (board == 1) & (neighbors >= board.size(0) * 5))

for _ in range(2):
vector_mask = life(vector_mask, thickify_rules)

display_mask = vector_mask[:3] * 2/3 + vector_mask[3] / 3
display_mask = torch.nn.functional.interpolate(display_mask.unsqueeze(0), scale_factor=8, mode='nearest-exact').squeeze(0)

if shared.state.sampling_step == 0 and len(mask_images) >= 2:
mask_images.clear()

mask_images.append(torchvision.transforms.functional.to_pil_image(display_mask))
result += weight * vector_mask * (vector - normal)

return result


def thickify_rules(board, neighbors):
population = board + neighbors
return (board == 1) | (population >= 4)


def life(board: torch.Tensor, rules = None):
if rules is None:
rules = lambda board, neighbors: (neighbors == board.size(0) * 3) | ((board == 1) & (neighbors >= board.size(0) * 3) & (neighbors <= board.size(0) * 4))
kernel = torch.tensor(
[[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]]] * board.size(0),
dtype=board.dtype,
device=board.device,
)
padded_board = torch.concatenate([board.clone(), board[:-1].clone()], dim=0)
padded_board = torch.nn.functional.pad(padded_board, (1, 1, 1, 1, 0, 0), mode='constant', value=0)
neighbors = torch.nn.functional.conv3d(
padded_board.unsqueeze(0).unsqueeze(0),
kernel.unsqueeze(0).unsqueeze(0),
padding=0,
).squeeze(0).squeeze(0)
return rules(board, neighbors - board).float()


def get_salience(vector: torch.Tensor, specific: bool = False) -> torch.Tensor:
k = 1
if specific:
k = 20
return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)


def warn_projection_not_found():
console_warn('''
Could not find a projection for one or more AND_PERP prompts
These prompts will NOT be made perpendicular
''')


def console_warn(message):
if not global_state.verbose:
return

print(f'\n[sd-webui-neutral-prompt extension]{textwrap.dedent(message)}', file=sys.stderr)