-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_sd_boxdiff.py
120 lines (99 loc) · 4.72 KB
/
run_sd_boxdiff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import pprint
from typing import List
import pyrallis
import torch
from PIL import Image
from config import RunConfig
from pipeline.sdxl_pipeline_boxdiff import BoxDiffPipeline
from utils import ptp_utils, vis_utils
from utils.ptp_utils import AttentionStore
import numpy as np
from utils.drawer import draw_rectangle, DashedImageDraw
import warnings
# Ignore every warning of category UserWarning (and its subclasses) for the
# remainder of the process; this runs at import time, before main() executes.
warnings.filterwarnings("ignore", category=UserWarning)
def load_model(config: RunConfig):
    """Load the BoxDiff pipeline for the Stable Diffusion variant chosen in ``config``.

    ``config.sd_xl`` takes precedence over ``config.sd_2_1``; when neither flag
    is set, Stable Diffusion v1.4 is loaded. The pipeline is placed on the first
    CUDA device when one is available, otherwise on the CPU.

    Args:
        config: Run configuration; only ``sd_xl`` and ``sd_2_1`` are read here.

    Returns:
        The loaded ``BoxDiffPipeline``, already moved to the selected device.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
    # Half precision only pays off on the GPU: many CPU kernels either lack a
    # float16 implementation or are very slow in it, so use float32 on CPU.
    dtype = torch.float16 if use_cuda else torch.float32

    if config.sd_xl:
        stable_diffusion_version = "stabilityai/stable-diffusion-xl-base-1.0"
    elif config.sd_2_1:
        stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
    else:
        stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
    # If the Hugging Face Hub is not reachable from your server, point
    # stable_diffusion_version at a locally prepared copy instead, e.g.:
    # stable_diffusion_version = "../../packages/huggingface/hub/stable-diffusion-v1-4"

    print(f"Loading model from {stable_diffusion_version}")
    stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version, torch_dtype=dtype).to(device)
    return stable
def get_indices_to_alter(stable, prompt: str) -> List[int]:
    """Interactively ask the user which prompt tokens should be altered.

    Prints a ``{position: decoded token}`` mapping for ``prompt`` and then reads
    a comma-separated list of positions from stdin. The first and last
    positions of the tokenizer output are left out of the mapping (presumably
    the BOS/EOS special tokens -- depends on the tokenizer).

    Args:
        stable: Pipeline whose ``tokenizer`` is used to tokenize ``prompt``.
        prompt: The text prompt.

    Returns:
        The positions entered by the user, in the order given.

    Raises:
        ValueError: If a non-blank entry is not an integer.
        KeyError: If an entered position is not one of the listed tokens.
    """
    # Tokenize once and reuse the ids; calling the tokenizer inside the
    # comprehension filter would re-tokenize the prompt for every token.
    input_ids = stable.tokenizer(prompt)['input_ids']
    last_position = len(input_ids) - 1
    token_idx_to_word = {idx: stable.tokenizer.decode(t)
                         for idx, t in enumerate(input_ids)
                         if 0 < idx < last_position}
    pprint.pprint(token_idx_to_word)

    raw_indices = input("Please enter the a comma-separated list indices of the tokens you wish to "
                        "alter (e.g., 2,5): ")
    # Blank entries (e.g. from "2,5," or "2, ,5") are skipped instead of
    # crashing int(); int() itself tolerates surrounding whitespace.
    token_indices = [int(part) for part in raw_indices.split(",") if part.strip()]
    print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
    return token_indices
def run_on_prompt(prompt: "str | List[str]",
                  model: "BoxDiffPipeline",
                  controller: "AttentionStore",
                  token_indices: List[int],
                  seed: "torch.Generator",
                  config: "RunConfig",
                  height: int = 512,
                  width: int = 512) -> "Image.Image":
    """Generate a single image for ``prompt`` with the BoxDiff pipeline.

    The sampling and attention-guidance hyper-parameters are forwarded from
    ``config``, and the whole ``config`` object is passed through as well.

    Args:
        prompt: The text prompt (``main`` passes a single string).
        model: The pipeline to call.
        controller: Forwarded to the pipeline as ``attention_store``.
        token_indices: Forwarded as ``indices_to_alter``.
        seed: Random generator, forwarded as ``generator``.
        config: Run configuration supplying the remaining pipeline arguments.
        height: Output height in pixels (defaults to 512).
        width: Output width in pixels (defaults to 512).
            NOTE(review): ``config.bbox`` is presumably given in output-pixel
            coordinates -- confirm before changing the size.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    outputs = model(prompt=prompt,
                    height=height,
                    width=width,
                    attention_store=controller,
                    indices_to_alter=token_indices,
                    attention_res=config.attention_res,
                    guidance_scale=config.guidance_scale,
                    generator=seed,
                    num_inference_steps=config.n_inference_steps,
                    max_iter_to_alter=config.max_iter_to_alter,
                    run_standard_sd=config.run_standard_sd,
                    thresholds=config.thresholds,
                    scale_factor=config.scale_factor,
                    scale_range=config.scale_range,
                    smooth_attentions=config.smooth_attentions,
                    sigma=config.sigma,
                    kernel_size=config.kernel_size,
                    sd_2_1=config.sd_2_1,
                    bbox=config.bbox,
                    weight_loss=config.weight_loss,
                    config=config)
    return outputs.images[0]
def _draw_bbox_canvas(size, bboxes, colors):
    """Return a light-grey (220) RGB canvas of PIL ``size`` with each box drawn dashed.

    Args:
        size: ``(width, height)`` as reported by ``PIL.Image.size``.
        bboxes: Sequence of ``(x1, y1, x2, y2)`` boxes.
        colors: Outline colour per box, indexed in step with ``bboxes``
            (raises IndexError if there are fewer colours than boxes).
    """
    width, height = size
    # numpy image arrays are (rows, cols, channels) == (height, width, 3),
    # whereas PIL's size is (width, height) -- the axes must be swapped.
    canvas = Image.fromarray(np.full((height, width, 3), 220, dtype=np.uint8))
    draw = DashedImageDraw(canvas)
    for i, (x1, y1, x2, y2) in enumerate(bboxes):
        draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=colors[i], width=5)
    return canvas


@pyrallis.wrap()
def main(config: RunConfig):
    """Run BoxDiff generation for every seed in ``config.seeds``.

    Loads the pipeline, resolves the token indices (interactively when
    ``config.token_indices`` is None) and the boxes (drawn interactively when
    none are configured), then for each seed saves
    ``<output_path>/<prompt[:100]>/<seed>.png`` plus ``<seed>_canvas.png``
    showing the boxes. Finally a grid of all images is saved to
    ``<output_path>/<prompt>.png``.
    """
    stable = load_model(config)
    token_indices = (get_indices_to_alter(stable, config.prompt)
                     if config.token_indices is None else config.token_indices)

    # No boxes configured (empty list, or an empty first box): let the user draw them.
    if not config.bbox or len(config.bbox[0]) == 0:
        config.bbox = draw_rectangle()

    # Create generators on the same kind of device load_model() chose; a
    # hard-coded 'cuda' generator fails on CPU-only machines.
    generator_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    prompt_output_path = config.output_path / config.prompt[:100]

    images = []
    for seed in config.seeds:
        print(f"Current seed is : {seed}")
        generator = torch.Generator(generator_device).manual_seed(seed)
        controller = AttentionStore()
        image = run_on_prompt(prompt=config.prompt,
                              model=stable,
                              controller=controller,
                              token_indices=token_indices,
                              seed=generator,
                              config=config)
        prompt_output_path.mkdir(exist_ok=True, parents=True)
        image.save(prompt_output_path / f'{seed}.png')
        images.append(image)

        canvas = _draw_bbox_canvas(image.size, config.bbox, config.color)
        canvas.save(prompt_output_path / f'{seed}_canvas.png')

    # Save a grid of results across all seeds (nothing to grid without seeds).
    if images:
        joined_image = vis_utils.get_image_grid(images)
        joined_image.save(config.output_path / f'{config.prompt}.png')
if __name__ == '__main__':
    # main() is called without arguments: the @pyrallis.wrap() decorator
    # supplies the RunConfig (presumably parsed from command-line arguments).
    main()