-
Notifications
You must be signed in to change notification settings - Fork 17
/
image_chooser_preview.py
162 lines (139 loc) · 6.98 KB
/
image_chooser_preview.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from server import PromptServer
from nodes import PreviewImage
from comfy.model_management import InterruptProcessingException
from .image_chooser_server import MessageHolder, Cancelled
import torch
import random
class PreviewAndChoose(PreviewImage):
    """Preview images and (depending on ``mode``) pause until the user picks some.

    Outputs the picked images / latents / masks re-batched, plus the picked
    indices as a comma-separated string. When a SEGS input is connected, only
    the picked SEGS items are output instead.
    """
    RETURN_TYPES = ("IMAGE","LATENT","MASK","STRING","SEGS")
    RETURN_NAMES = ("images","latents","masks","selected","segs")
    FUNCTION = "func"
    CATEGORY = "image_chooser"
    INPUT_IS_LIST=True
    OUTPUT_NODE = False
    # node id -> last value returned by IS_CHANGED
    last_ic = {}

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mode" : (["Always pause", "Repeat last selection", "Only pause if batch", "Progress first pick", "Pass through", "Take First n", "Take Last n"],{}),
                "count": ("INT", { "default": 1, "min": 1, "max": 999, "step": 1 }),
            },
            "optional": {"images": ("IMAGE", ), "latents": ("LATENT", ), "masks": ("MASK", ), "segs":("SEGS", ) },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "id":"UNIQUE_ID"},
        }

    @classmethod
    def IS_CHANGED(cls, id, **kwargs):
        """Return the change token for node ``id[0]``.

        A fresh random value is stored and returned, except in
        "Repeat last selection" mode when a value already exists for this node;
        then the stored value is returned unchanged (presumably so ComfyUI
        treats the node as unchanged and reuses its previous outputs).
        Inputs arrive as lists because INPUT_IS_LIST is True.
        """
        mode = kwargs.get("mode", ["Always pause"])
        if mode[0] != "Repeat last selection" or id[0] not in cls.last_ic:
            cls.last_ic[id[0]] = random.random()
        return cls.last_ic[id[0]]

    def func(self, id, **kwargs):
        """Preview the inputs, wait for a selection if the mode requires it, and return the picks.

        Args:
            id: single-element list holding this node's UNIQUE_ID.
            **kwargs: list-wrapped node inputs (mode, count, images, latents,
                masks, segs, prompt, extra_pnginfo); mode/count may be absent
                in subclasses, in which case defaults are used.

        Returns:
            A 5-tuple matching RETURN_TYPES. In SEGS mode only the last slot is
            populated.

        Raises:
            InterruptProcessingException: if the user cancels while waiting.
        """
        # mode doesn't exist in subclass
        self.count = int(kwargs.pop('count', [1,])[0])
        mode = kwargs.pop('mode', ["Always pause",])[0]
        if mode == "Repeat last selection":
            print("Here despite 'Repeat last selection' - treat as 'Always pause'")
            mode = "Always pause"
        if mode == "Always pause":
            # pretend it was Repeat last so that the prompt matches if that is selected next time.
            # UGH
            kwargs['prompt'][0][id[0]]['inputs']['mode'] = "Repeat last selection"

        id = id[0]
        if id not in MessageHolder.stash:
            MessageHolder.stash[id] = {}
        my_stash = MessageHolder.stash[id]
        DOING_SEGS = 'segs' in kwargs

        # enable stashing. If images is not supplied, we are operating in read-from-stash mode
        if 'images' in kwargs:
            my_stash['images'] = kwargs['images']
            my_stash['latents'] = kwargs.get('latents', None)
            my_stash['masks'] = kwargs.get('masks', None)
        else:
            kwargs['images'] = my_stash.get('images', None)
            kwargs['latents'] = my_stash.get('latents', None)
            kwargs['masks'] = my_stash.get('masks', None)
        if kwargs['images'] is None:
            # one value per entry in RETURN_TYPES
            return (None, None, None, "", None)

        # convert list inputs to single batches
        images_list = kwargs.pop('images')
        if DOING_SEGS:
            # keep one (batch-dim-stripped) image per list entry
            images_in = [i[0, ...] for i in images_list]
        else:
            images_in = torch.cat(images_list)
        latents_in = kwargs.pop('latents', None)
        masks_list = kwargs.pop('masks', None)
        masks_in = torch.cat(masks_list) if masks_list is not None else None
        segs_in = kwargs.pop('segs', None)
        latent_samples_in = torch.cat([l['samples'] for l in latents_in]) if latents_in is not None else None
        self.batch = len(images_in) if DOING_SEGS else images_in.shape[0]

        # any other parameters shouldn't be lists any more...
        for key in kwargs:
            kwargs[key] = kwargs[key][0]

        # call PreviewImage base, then send the images to the front end for viewing
        ret = self.save_images(images=images_in, **kwargs)
        PromptServer.instance.send_sync("early-image-handler", {"id": id, "urls": ret['ui']['images']})

        # wait for selection
        try:
            is_block_condition = (mode == "Always pause" or mode == "Progress first pick" or self.batch > 1)
            is_blocking_mode = (mode not in ["Pass through", "Take First n", "Take Last n"])
            selections = MessageHolder.waitForMessage(id, asList=True) if (is_blocking_mode and is_block_condition) else [0]
        except Cancelled:
            raise InterruptProcessingException()

        if DOING_SEGS:
            # honour the non-blocking modes for SEGS too, not just selection [0]
            chosen = self._chosen_indices(selections, mode)
            segs_out = (segs_in[0][0], [segs_in[0][1][i] for i in chosen])
            return (None, None, None, None, segs_out)
        return self.batch_up_selections(images_in=images_in, latent_samples_in=latent_samples_in, masks_in=masks_in, selections=selections, mode=mode)

    def _chosen_indices(self, selections, mode):
        """Return the batch indices to output for ``mode`` (uses self.batch and self.count)."""
        if mode == "Pass through":
            return range(0, self.batch)
        if mode == "Take First n":
            return range(0, min(self.count, self.batch))
        if mode == "Take Last n":
            return range(max(self.batch - self.count, 0), self.batch)
        # user selections: negative entries (e.g. a -1 divider) are not image indices
        return [x for x in selections if x >= 0]

    def tensor_bundle(self, tensor_in:torch.Tensor, picks):
        """Gather rows ``picks`` (wrapped modulo the batch size) of ``tensor_in`` along dim 0.

        Returns a new tensor of shape (len(picks), *tensor_in.shape[1:]), or
        None if ``tensor_in`` is None or ``picks`` is empty.
        """
        if tensor_in is None or not len(picks):
            return None
        batch = tensor_in.shape[0]
        index = torch.tensor([x % batch for x in picks], dtype=torch.long, device=tensor_in.device)
        # index_select copies, so the input tensor is never touched
        return tensor_in.index_select(0, index)

    def latent_bundle(self, latent_samples_in:torch.Tensor, picks):
        """Wrap the picked latent samples in a LATENT dict, or return None if there is nothing to pick."""
        if latent_samples_in is None or not len(picks):
            return None
        return {"samples": self.tensor_bundle(latent_samples_in, picks)}

    def batch_up_selections(self, images_in:torch.Tensor, latent_samples_in:torch.Tensor, masks_in:torch.Tensor, selections, mode):
        """Build the node outputs (images, latents, masks, selected-string, segs=None) for ``mode``."""
        chosen = self._chosen_indices(selections, mode)
        return (
            self.tensor_bundle(images_in, chosen),
            self.latent_bundle(latent_samples_in, chosen),
            self.tensor_bundle(masks_in, chosen),
            ",".join(str(x) for x in chosen),
            None,
        )
class SimpleChooser(PreviewAndChoose):
    """Stripped-down chooser: images (plus optional latents) in, picked images and latents out.

    It declares no mode/count inputs, so the parent's func falls back to its
    defaults for those.
    """
    CATEGORY = "image_chooser"
    FUNCTION = "func"
    RETURN_TYPES = ("IMAGE","LATENT",)
    RETURN_NAMES = ("images","latents",)
    INPUT_IS_LIST = True
    OUTPUT_NODE = False
    # own IS_CHANGED store, separate from the parent class's dict
    last_ic = {}

    @classmethod
    def INPUT_TYPES(s):
        hidden_inputs = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "id": "UNIQUE_ID"}
        return {
            "required": {"images": ("IMAGE",)},
            "optional": {"latents": ("LATENT",)},
            "hidden": hidden_inputs,
        }

    def func(self, **kwargs):
        """Delegate to the parent chooser and keep only its (images, latents) outputs."""
        images_out, latents_out, *_unused = super().func(**kwargs)
        return (images_out, latents_out)
class PreviewAndChooseDouble(PreviewAndChoose):
    """Chooser whose selection is split into two latent batches: "positive" and "negative".

    The selection list received from the UI is split at its first -1 entry.
    """
    RETURN_TYPES = ("LATENT","LATENT",)
    RETURN_NAMES = ("positive","negative",)

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "latents": ("LATENT", ),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "id":"UNIQUE_ID"},
        }

    def batch_up_selections(self, images_in, latent_samples_in, masks_in, selections:list, mode):
        """Return (positive, negative) latent batches.

        Indices before the first -1 in ``selections`` form the positive batch;
        indices after it form the negative batch. If no -1 divider is present,
        every selection is positive and the negative output is None. An empty
        group also yields None.
        """
        divider = selections.index(-1) if -1 in selections else len(selections)
        latents_out_good = self.latent_bundle(latent_samples_in, selections[:divider])
        latents_out_bad = self.latent_bundle(latent_samples_in, selections[divider+1:])
        return (latents_out_good, latents_out_bad,)