# utils.py
import logging
import math
import os

import torch
from PIL import Image, ImageDraw, ImageFont


def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Cross-attention guidance loss: encourage each object token's attention
    mass to fall inside that object's bounding box.

    attn_maps_mid and attn_maps_up[0] hold cross-attention maps of shape
    (batch, H*W, num_text_tokens); bboxes holds normalised [x0, y0, x1, y1]
    boxes per object; object_positions holds the prompt token indices of
    each object phrase (see Pharse2idx).
    """
    loss = 0
    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0.0, device="cuda" if torch.cuda.is_available() else "cpu")

    for attn_map in attn_maps_mid:
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))  # spatial resolution of this attention map
        for obj_idx in range(object_number):
            obj_loss = 0
            # Binary mask that is 1 inside every box of this object.
            mask = torch.zeros(H, W, device=attn_map.device)
            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = (int(obj_box[0] * W), int(obj_box[1] * H),
                                              int(obj_box[2] * W), int(obj_box[3] * H))
                mask[y_min:y_max, x_min:x_max] = 1
            for obj_position in object_positions[obj_idx]:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                # Fraction of this token's attention mass that lands in the box.
                activation_value = ((ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
                                    / ca_map_obj.reshape(b, -1).sum(dim=-1))
                obj_loss += torch.mean((1 - activation_value) ** 2)
            loss += obj_loss / len(object_positions[obj_idx])

    # Same computation over the first group of up-block attention maps.
    for attn_map in attn_maps_up[0]:
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        for obj_idx in range(object_number):
            obj_loss = 0
            mask = torch.zeros(H, W, device=attn_map.device)
            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = (int(obj_box[0] * W), int(obj_box[1] * H),
                                              int(obj_box[2] * W), int(obj_box[3] * H))
                mask[y_min:y_max, x_min:x_max] = 1
            for obj_position in object_positions[obj_idx]:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                activation_value = ((ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
                                    / ca_map_obj.reshape(b, -1).sum(dim=-1))
                obj_loss += torch.mean((1 - activation_value) ** 2)
            loss += obj_loss / len(object_positions[obj_idx])

    # Average over objects and over all attention maps used.
    loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
    return loss
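
# Minimal usage sketch for compute_ca_loss. The tensors, shapes, and token
# index below are illustrative assumptions, not values from the original
# pipeline: maps are (batch, H*W, num_tokens) with H*W a perfect square.
def _example_compute_ca_loss():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    attn_mid = [torch.rand(1, 64, 77, device=device)]    # one 8x8 mid-block map
    attn_up = [[torch.rand(1, 256, 77, device=device)]]  # one 16x16 up-block map
    bboxes = [[[0.25, 0.25, 0.75, 0.75]]]                # one object, one box
    object_positions = [[2]]                             # token index of the object word
    return compute_ca_loss(attn_mid, attn_up, bboxes, object_positions)
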
def Pharse2idx(prompt, phrases):
    """Map each semicolon-separated phrase to the 1-based positions of its
    words in the prompt; the +1 offsets for the tokenizer's start-of-text
    token at index 0.

    Every word of every phrase must appear verbatim in the prompt,
    otherwise `list.index` raises a ValueError.
    """
    phrases = [x.strip() for x in phrases.split(';')]
    prompt_list = prompt.strip('.').split(' ')
    object_positions = []
    for obj in phrases:
        obj_position = []
        for word in obj.split(' '):
            obj_first_index = prompt_list.index(word) + 1
            obj_position.append(obj_first_index)
        object_positions.append(obj_position)
    return object_positions
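
# Usage sketch for Pharse2idx (the prompt and phrases are made-up examples):
def _example_pharse2idx():
    prompt = "a cat and a dog on the grass"
    phrases = "cat; dog"
    return Pharse2idx(prompt, phrases)  # -> [[2], [5]]
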
def draw_box(pil_img, bboxes, phrases, save_path):
    """Draw each object's boxes (normalised coordinates) on the image with
    its phrase as a red label, then save it. Requires ./FreeMono.ttf."""
    draw = ImageDraw.Draw(pil_img)
    font = ImageFont.truetype('./FreeMono.ttf', 25)
    phrases = [x.strip() for x in phrases.split(';')]
    width, height = pil_img.size  # boxes are normalised to the image size
    for obj_bboxes, phrase in zip(bboxes, phrases):
        for obj_bbox in obj_bboxes:
            x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3]
            draw.rectangle([int(x_0 * width), int(y_0 * height),
                            int(x_1 * width), int(y_1 * height)],
                           outline='red', width=5)
            draw.text((int(x_0 * width) + 5, int(y_0 * height) + 5),
                      phrase, font=font, fill=(255, 0, 0))
    pil_img.save(save_path)
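
# Usage sketch for draw_box on a blank canvas. The box, phrase, and output
# path are illustrative, and FreeMono.ttf must exist in the working directory
# for the font to load.
def _example_draw_box():
    img = Image.new("RGB", (512, 512), "white")
    draw_box(img, [[[0.1, 0.1, 0.6, 0.6]]], "cat", "./example_box.png")
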
def setup_logger(save_path, logger_name):
    """Create a logger that writes INFO-level messages to
    <save_path>/<logger_name>.log."""
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Create a file handler to write logs to a file named after the logger.
    file_handler = logging.FileHandler(os.path.join(save_path, f"{logger_name}.log"))
    file_handler.setLevel(logging.INFO)

    # Timestamped single-line log format.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    return logger
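
# Usage sketch for setup_logger (./logs and the logger name are illustrative;
# the directory must exist before the file handler is created):
def _example_setup_logger():
    os.makedirs("./logs", exist_ok=True)
    logger = setup_logger("./logs", "train")
    logger.info("logger initialised")
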
def load_text_inversion(text_encoder, tokenizer, placeholder_token, embedding_ckp_path):
    """Load a textual-inversion embedding: register `placeholder_token` with
    the tokenizer and copy its learned embedding into the text encoder."""
    num_added_tokens = tokenizer.add_tokens(placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

    # Resize the token embeddings as we are adding a new token to the tokenizer.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Copy the learned embedding (a checkpoint dict keyed by the placeholder
    # token) into the row of the newly added token.
    token_embeds = text_encoder.get_input_embeddings().weight.data
    learned_embedding = torch.load(embedding_ckp_path, map_location="cpu")
    token_embeds[placeholder_token_id] = learned_embedding[placeholder_token]
    return text_encoder, tokenizer
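
# Usage sketch for load_text_inversion, assuming the Hugging Face
# `transformers` CLIP classes and a checkpoint saved as a dict
# {placeholder_token: tensor}; the token and checkpoint path are hypothetical.
def _example_load_text_inversion():
    from transformers import CLIPTextModel, CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    return load_text_inversion(text_encoder, tokenizer, "<my-token>",
                               "./learned_embeds.bin")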