Skip to content

Commit

Permalink
Merge pull request #1 from AlvardBarseghyan/nightowls_custom_cosine
Browse files Browse the repository at this point in the history
Nightowls custom cosine
  • Loading branch information
AlvardBarseghyan authored Feb 8, 2023
2 parents ae39d63 + 2a642a1 commit 56b688a
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
*.jpeg
*.png
*.ipynb
*.txt
47 changes: 47 additions & 0 deletions cosine_sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import torch
from pl_train import LightningMAE

from dataset import MAEDataset
from torchmetrics.functional import pairwise_cosine_similarity as cos_dist

import models_mae


def cosine_distance_torch(x1, x2=None, eps=1e-8):
x2 = x1 if x2 is None else x2
w1 = x1.norm(p=2, dim=1, keepdim=True)
w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
return torch.mm(x1, x2.t()) / (w1 * w2.t()) #.clamp(min=eps)


BATCH_SIZE = 1
arch='mae_vit_large_patch16'
model_mae = getattr(models_mae, arch)()

chkpt_dir = '/mnt/2tb/alla/mae/mae_contastive/lightning_logs/version_12/checkpoints/epoch=30-step=31.ckpt'
chkpt_dir_old = '/mnt/2tb/hrant/checkpoints/mae_models/mae_visualize_vit_large.pth'
checkpoint = torch.load(chkpt_dir_old, map_location='cpu')
msg = model_mae.load_state_dict(checkpoint['model'], strict=False)
model_mae = LightningMAE.load_from_checkpoint(chkpt_dir, model=model_mae)

model_mae.eval()

root = '/mnt/2tb/hrant/FAIR1M/fair1m_1000/train1000/'
path_ann = os.path.join(root, 'few_shot_8.json')
path_imgs = os.path.join(root, 'images')
dataset = MAEDataset(path_ann, path_imgs, resize_image=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

dl = next(iter(dataloader))
img = torch.einsum('nhwc->nchw', dl['image'])
img_enc = model_mae(img.float())
img_enc = img_enc.reshape(-1, img_enc.shape[-1])

cos_torchmetrics = cos_dist(img_enc, img_enc)
cos_custom = cosine_distance_torch(img_enc)

print((cos_torchmetrics.reshape(-1) != cos_custom.reshape(-1)).sum())
ind = cos_torchmetrics != cos_custom
print(cos_torchmetrics[ind] , cos_custom[ind])
print((cos_torchmetrics.reshape(-1).abs() - cos_custom.reshape(-1).abs()).sum())
12 changes: 6 additions & 6 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,17 @@ def __getitem__(self, idx):

img_boxes = [x['bbox'] for x in self.anns['annotations'] if x['image_id'] == img_id] # [x1, y1, w, h]
img_labels = [x['category_id'] for x in self.anns['annotations'] if x['image_id'] == img_id] # label
img_segmentation = [x['segmentation'][0] for x in self.anns['annotations'] if x['image_id'] == img_id]
# img_segmentation = [x['segmentation'][0] for x in self.anns['annotations'] if x['image_id'] == img_id]

x_scale = IMAGE_SIZE / image.shape[2]
y_scale = IMAGE_SIZE / image.shape[1]

black_image = np.zeros((IMAGE_SIZE, IMAGE_SIZE))

for box, label, seg in zip(img_boxes, img_labels, img_segmentation):
seg = self.scale_box(seg, (x_scale, y_scale))
pts = np.array([[seg[0], seg[1]], [seg[2], seg[3]], [seg[4], seg[5]], [seg[6], seg[7]]])
black_image = cv2.fillPoly(black_image, [pts], (label, 0))
for box, label in zip(img_boxes, img_labels):
box = self.scale_box(box, (x_scale, y_scale))
pts = np.array([[box[0], box[1]], [box[0] + box[2], box[1] + box[3]]])
black_image = cv2.rectangle(black_image, pts[0], pts[1], (label, 0), -1)

# black_image = cv2.resize(black_image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_AREA)

Expand All @@ -142,7 +142,7 @@ def __getitem__(self, idx):

target = {}
target['image'] = image
# target['black_image'] = black_image
target['black_image'] = black_image
target['file_name'] = img_path
# target['boxes'] = np.array(img_boxes)
# target['labels'] = np.array(img_labels)
Expand Down
23 changes: 13 additions & 10 deletions models_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,28 +176,31 @@ def forward_encoder(self, x, mask_ratio):

# mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
# print('encoder, mask tokens shape:', mask_tokens.detach().numpy())
# x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
# x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
x_ = x[:, 1:, :] # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# print('encoder:', x.shape)
# return x[:,idx], mask, ids_restore
return x, mask, ids_restore

def forward_decoder(self, x, ids_restore):
# embed tokens
idx = torch.arange(140, 160)
x[:, idx] = torch.ones((x.shape[0], idx.shape[0], x.shape[-1])) * 1000
x = self.decoder_embed(x)

# append mask tokens to sequence
print("decoder:", x.shape, ids_restore.shape)
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
print('decoder, mask tokens shape:', mask_tokens.shape)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
print('')
# mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
# print('decoder, mask tokens shape:', mask_tokens.shape)
# x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
# x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed

print('self.decoder_pos_embed', self.decoder_pos_embed)
# idx = torch.arange(140, 160)
# x[:, idx] = torch.ones((x.shape[0], idx.shape[0], x.shape[-1])) * 1000
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
Expand Down
50 changes: 31 additions & 19 deletions pl_train.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import os
import torch
import torch.nn as nn
import torchmetrics
import pytorch_lightning as pl

from dataset import MAEDataset
import models_mae

BATCH_SIZE = 40
EPOCHS = 1 # 100
BATCH_SIZE = 24
EPOCHS = 500
DEVICE = 'cpu'
continue_from_checkpoint = False


def cosine_distance_torch(x1, x2=None):
x2 = x1 if x2 is None else x2
w1 = x1.norm(p=2, dim=1, keepdim=True)
w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
return 1 - torch.mm(x1, x2.t()) / (w1 * w2.t())


class ContrastiveLoss(nn.Module):
def __init__(self, num_classes=5, margin=1.0) -> None:
super().__init__()
Expand All @@ -19,12 +27,9 @@ def __init__(self, num_classes=5, margin=1.0) -> None:

def forward(self, img_enc_1, labels, img_enc_2=None):
if not img_enc_2:
cos_dist = 1 - torchmetrics.functional.pairwise_cosine_similarity(img_enc_1)
cos_dist = cosine_distance_torch(img_enc_1)
else:
cos_dist = 1 - torchmetrics.functional.pairwise_cosine_similarity(img_enc_1, img_enc_2)

print(cos_dist.__dir__())
print(cos_dist.grad)
cos_dist = cosine_distance_torch(img_enc_1, img_enc_2)

# d = 0 means y1 and y2 are supposed to be same
# d = 1 means y1 and y2 are supposed to be different
Expand Down Expand Up @@ -65,8 +70,9 @@ def __init__(self, model, l1=0.5, lr=1e-4, num_classes=5, margin=1):
self.model_mae = model
self.l1 = l1
self.lr = lr
self.min_loss = 10
self.criterion = ContrastiveLoss(num_classes=num_classes, margin=margin)
# self.save_hyperparameters(ignore=['model'])
# self.save_hyperparameters()

def training_step(self, batch, batch_idx):

Expand All @@ -80,6 +86,10 @@ def training_step(self, batch, batch_idx):
self.log('train_loss', total_loss)
print(f'Iter: {batch_idx}, pos_loss: {loss[0].item()}, neg_loss = {self.l1} * {loss[1].item()}, loss: {total_loss.item()}')

if self.min_loss > total_loss:
self.min_loss = total_loss.item()
torch.save(self.model_mae.state_dict(), "/mnt/2tb/alla/mae/mae_contastive/nightowls/best_model.pth")

return total_loss


Expand All @@ -95,10 +105,10 @@ def forward(self, img):
if __name__ == '__main__':

#### init dataset ####
root = '/mnt/2tb/hrant/FAIR1M/fair1m_1000/train1000/'
path_ann = os.path.join(root, 'few_shot_8.json')
path_imgs = os.path.join(root, 'images')
dataset = MAEDataset(path_ann, path_imgs, resize_image=True)

path_ann = './annotations/few_shot_8_nightowls.json'
path_imgs = '/home/ani/nightowls_stage_3/'
dataset = MAEDataset(path_ann, path_imgs, intersection_threshold=0.01, resize_image=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

#### init model ####
Expand All @@ -108,18 +118,20 @@ def forward(self, img):
if continue_from_checkpoint:
# chkpt_dir = 'best_model.pth'
chkpt_dir = '/mnt/2tb/alla/mae/mae_contastive/lightning_logs/version_12/checkpoints/epoch=30-step=31.ckpt'
checkpoint = torch.load(chkpt_dir, map_location='cpu')
checkpoint = torch.load(chkpt_dir, map_location=DEVICE)
msg = model_mae.load_state_dict(checkpoint, strict=False)

else:
# chkpt_dir = '/mnt/2tb/hrant/checkpoints/mae_models/mae_visualize_vit_large.pth'
chkpt_dir = '/mnt/2tb/hrant/checkpoints/mae_models/mae_visualize_vit_large_ganloss.pth'
checkpoint = torch.load(chkpt_dir, map_location='cuda')
checkpoint = torch.load(chkpt_dir, map_location=DEVICE)
msg = model_mae.load_state_dict(checkpoint['model'], strict=False)
# chkpt_dir = '/mnt/2tb/alla/mae/mae_contastive/custom_cosine_sim/lightning_logs/version_5/checkpoints/epoch=15-step=16.ckpt'
# model_mae = LightningMAE.load_from_checkpoint(chkpt_dir, model=model_mae)
# model_mae = model_mae.model_mae


model = LightningMAE(model_mae, l1=1)
trainer = pl.Trainer(limit_predict_batches=BATCH_SIZE, max_epochs=EPOCHS, log_every_n_steps=1,\
default_root_dir="/mnt/2tb/alla/mae/mae_contastive/") #, accelerator='gpu',\
# devices=1, )
model = LightningMAE(model_mae, num_classes=3, lr=0.0001, l1=1)
trainer = pl.Trainer(logger=True, enable_checkpointing=True, limit_predict_batches=BATCH_SIZE, max_epochs=EPOCHS, log_every_n_steps=1, \
default_root_dir="/mnt/2tb/alla/mae/mae_contastive/nightowls", accelerator=DEVICE, devices=1, )
trainer.fit(model=model, train_dataloaders=dataloader)
91 changes: 91 additions & 0 deletions to_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
from tqdm import tqdm
import json
from PIL import Image
import pandas as pd


nightowls_dir = '/home/ani/nightowls_stage_3/'

def coco_dict():
coco_format = {}
coco_format['images'] = []
coco_format['annotations'] = []
coco_format['categories'] = [
{
'id': 1,
'name': 'pedestrian',
'supercategory': 'pedestrian'
},
{
'id': 2,
'name': 'motorbike driver',
'supercategory': 'motorbike driver'
},
{
'id': 3,
'name': 'motorbike driver',
'supercategory': 'motorbike driver'
}
]

return coco_format


def nightowls_annotations(label_filename):
annotations = pd.read_csv(label_filename, header=None, sep=' ')
img_boxes = annotations.iloc[:, 1:].values # [x1, y1, x2, y2]
img_labels = annotations.iloc[:, 0].values # label
return img_boxes, img_labels


def create_coco(path, img_dir, label_dir):
image_names = os.listdir(os.path.join(path, img_dir)) # returns list of img names without absolute path
image_names = [x for x in image_names if '.png' in x]
max_img_size, min_img_size = -1, 100000000
nightowls_coco_format = coco_dict()

for img_id, img_name in tqdm(enumerate(image_names), total=len(image_names)):
img_filename = os.path.join(path, img_dir, img_name)
# print(img_filename)
# if '58c58167bc260130acfebf96' in img_filename:
label_filename = os.path.join(path, label_dir, img_name.replace('.png', '.txt'))
img = Image.open(img_filename)

width, height = img.size
max_img_size = max(max_img_size, height, width)
min_img_size = min(min_img_size, height, width)
tmp_img_dct = {
'file_name': img_filename,
'height': height,
'width': width,
'id': img_id
}

nightowls_coco_format['images'].append(tmp_img_dct)

img_boxes, img_labels = nightowls_annotations(label_filename)

bbox_id = 0
for boxes, label in zip(img_boxes, img_labels):
bbox_width, bbox_height = boxes[2] - boxes[0], boxes[3] - boxes[1]
# print(bbox_width, bbox_height)
tmp_annotation_dct = {
'image_id': img_id,
'category_id': int(label),
'bbox': [int(boxes[0]), int(boxes[1]), int(bbox_width), int(bbox_height)],
'id': bbox_id,
'iscrowd': 0,
'area': int(bbox_width * bbox_height)
}
nightowls_coco_format['annotations'].append(tmp_annotation_dct)
bbox_id += 1

return nightowls_coco_format, max_img_size, min_img_size


if __name__ == "__main__":
nightowls_train, max_, min_ = create_coco(nightowls_dir, './', './')

with open('./annotations/few_shot_8_nightowls.json', 'w') as no:
json.dump(nightowls_train, no)

0 comments on commit 56b688a

Please sign in to comment.