Commit Salience-DETR code
xiuqhou committed Mar 26, 2024
1 parent 4683c93 commit b649d01
Showing 97 changed files with 27,739 additions and 0 deletions.
114 changes: 114 additions & 0 deletions configs/salience_detr/salience_detr_resnet50_800_1333.py
@@ -0,0 +1,114 @@
from torch import nn
# from torchvision.ops import FrozenBatchNorm2d

from models.backbones.resnet import ResNetBackbone
from models.bricks.misc import FrozenBatchNorm2d
from models.bricks.position_encoding import PositionEmbeddingSine
from models.bricks.post_process import PostProcess
from models.bricks.salience_transformer import (
SalienceTransformer,
SalienceTransformerDecoder,
SalienceTransformerDecoderLayer,
SalienceTransformerEncoder,
SalienceTransformerEncoderLayer,
)
from models.bricks.set_criterion import HybridSetCriterion
from models.detectors.salience_detr import SalienceCriterion, SalienceDETR
from models.matcher.hungarian_matcher import HungarianMatcher
from models.necks.channel_mapper import ChannelMapper
from models.necks.repnet import RepVGGPluXNetwork

# most commonly changed parameters
embed_dim = 256
num_classes = 91
num_queries = 900
num_feature_levels = 4
transformer_enc_layers = 6
transformer_dec_layers = 6
num_heads = 8
dim_feedforward = 2048

# instantiate model components
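# sine positional encoding; embed_dim // 2 features per spatial axis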
position_embedding = PositionEmbeddingSine(embed_dim // 2, temperature=10000, normalize=True, offset=-0.5)

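# ResNet-50 with frozen BatchNorm; the earliest stage (index 0) is frozen and
# stages 1-3 provide the multi-scale features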
backbone = ResNetBackbone(
"resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,)
)

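# project the backbone's multi-scale features to embed_dim, adding an extra
# downsampled level to reach num_feature_levels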
neck = ChannelMapper(
in_channels=backbone.num_channels,
out_channels=embed_dim,
num_outs=num_feature_levels,
)

transformer = SalienceTransformer(
encoder=SalienceTransformerEncoder(
encoder_layer=SalienceTransformerEncoderLayer(
embed_dim=embed_dim,
n_heads=num_heads,
dropout=0.0,
activation=nn.ReLU(inplace=True),
n_levels=num_feature_levels,
n_points=4,
d_ffn=dim_feedforward,
),
num_layers=transformer_enc_layers,
),
neck=RepVGGPluXNetwork(
in_channels_list=neck.num_channels,
out_channels_list=neck.num_channels,
norm_layer=nn.BatchNorm2d,
activation=nn.SiLU,
groups=4,
),
decoder=SalienceTransformerDecoder(
decoder_layer=SalienceTransformerDecoderLayer(
embed_dim=embed_dim,
n_heads=num_heads,
dropout=0.0,
activation=nn.ReLU(inplace=True),
n_levels=num_feature_levels,
n_points=4,
d_ffn=dim_feedforward,
),
num_layers=transformer_dec_layers,
num_classes=num_classes,
),
num_classes=num_classes,
num_feature_levels=num_feature_levels,
two_stage_num_proposals=num_queries,
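    # salience-guided token filtering: the fraction of encoder tokens kept per
    # feature level, and per encoder layer (hierarchical filtering)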
level_filter_ratio=(0.4, 0.8, 1.0, 1.0),
layer_filter_ratio=(1.0, 0.8, 0.6, 0.6, 0.4, 0.2),
)

matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)

weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2}
weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2})
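# replicate every weight for each of the first (transformer_dec_layers - 1)
# decoder layers, yielding auxiliary-loss keys such as "loss_class_0"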
weight_dict.update({
k + f"_{i}": v
for i in range(transformer_dec_layers - 1)
for k, v in weight_dict.items()
})
weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2})
weight_dict.update({"loss_salience": 2})

criterion = HybridSetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0)
foreground_criterion = SalienceCriterion(noise_scale=0.0, alpha=0.25, gamma=2.0)
postprocessor = PostProcess(select_box_nums_for_evaluation=300)

# combine the above components to instantiate the model
model = SalienceDETR(
backbone=backbone,
neck=neck,
position_embedding=position_embedding,
transformer=transformer,
criterion=criterion,
focus_criterion=foreground_criterion,
postprocessor=postprocessor,
num_classes=num_classes,
num_queries=num_queries,
aux_loss=True,
min_size=800,
max_size=1333,
)
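
A config file like this is plain Python, so its top-level names can be read by importing it as a module. A minimal sketch using importlib directly (an assumption for illustration, not necessarily the repo's own config loader):

import importlib.util

def load_config(path: str):
    # import a Python config file as a module and return it
    spec = importlib.util.spec_from_file_location("config", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

cfg = load_config("configs/salience_detr/salience_detr_resnet50_800_1333.py")
model = cfg.model  # the SalienceDETR instance assembled above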
47 changes: 47 additions & 0 deletions configs/train_config.py
@@ -0,0 +1,47 @@
from torch import optim

from datasets.coco import CocoDetection
from transforms import presets
from optimizer import param_dict

# Commonly changed training configurations
num_epochs = 12 # train epochs
batch_size = 2 # total_batch_size = #GPU x batch_size
num_workers = 4 # workers for pytorch DataLoader
pin_memory = True  # whether to pin memory for the PyTorch DataLoader
print_freq = 50 # frequency to print logs
starting_epoch = 0
max_norm = 0.1 # clip gradient norm

output_dir = None  # path to save checkpoints; if None, defaults to checkpoints/{model_name}
find_unused_parameters = False # useful for debugging distributed training

# define dataset for train
coco_path = "data/coco" # /PATH/TO/YOUR/COCODIR
train_transform = presets.detr # see transforms/presets to choose a transform
train_dataset = CocoDetection(
img_folder=f"{coco_path}/train2017",
ann_file=f"{coco_path}/annotations/instances_train2017.json",
transforms=train_transform,
train=True,
)
test_dataset = CocoDetection(
img_folder=f"{coco_path}/val2017",
ann_file=f"{coco_path}/annotations/instances_val2017.json",
transforms=None, # the eval_transform is integrated in the model
)

# model config to train
model_path = "configs/salience_detr/salience_detr_resnet50_800_1333.py"

# specify a checkpoint folder to resume from, or a pretrained ".pth" file to finetune, for example:
# checkpoints/salience_detr_resnet50_800_1333/train/2024-03-22-09_38_50
# checkpoints/salience_detr_resnet50_800_1333/train/2024-03-22-09_38_50/best_ap.pth
resume_from_checkpoint = None

learning_rate = 1e-4 # initial learning rate
optimizer = optim.AdamW(lr=learning_rate, weight_decay=1e-4, betas=(0.9, 0.999))
lr_scheduler = optim.lr_scheduler.MultiStepLR(milestones=[10], gamma=0.1)
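
# NOTE: the optimizer and lr_scheduler above are declared without model
# parameters; presumably the training script binds the param_dicts defined
# below when it actually instantiates them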

# This defines parameter groups with different learning rates
param_dicts = param_dict.finetune_backbone_and_linear_projection(lr=learning_rate)
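
For orientation, a hypothetical sketch of how these values typically combine in a standard PyTorch training step (train_one_epoch and the loss-dict interface are assumptions, not the repo's trainer):

import torch

def train_one_epoch(model, optimizer, data_loader, max_norm=0.1):
    model.train()
    for images, targets in data_loader:
        # assumes the model returns a dict of named losses in training mode
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        # clip the gradient norm, matching max_norm above
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()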
161 changes: 161 additions & 0 deletions datasets/coco.py
@@ -0,0 +1,161 @@
import os

import albumentations as A
import cv2
import numpy as np
import torchvision

from transforms import v2 as T
from transforms.convert_coco_polys_to_mask import ConvertCocoPolysToMask
from util import datapoints
from util.misc import deepcopy


class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(
self,
img_folder,
ann_file,
transforms=None,
train=False,
):
super(CocoDetection, self).__init__(img_folder, ann_file)
self.prepare = ConvertCocoPolysToMask()
self._transforms = transforms
self._transforms = self.update_dataset(self._transforms)
self.train = train

if train:
self._coco_remove_images_without_annotations()

def update_dataset(self, transform):
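        # recursively walk composed transforms and hand each one a reference to
        # this dataset (for transforms that implement update_dataset)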
if isinstance(transform, (T.Compose, A.Compose)):
processed_transforms = []
for trans in transform.transforms:
trans = self.update_dataset(trans)
processed_transforms.append(trans)
return type(transform)(processed_transforms)
if hasattr(transform, "update_dataset"):
transform.update_dataset(self)
return transform

def load_image(self, image_name):
        # After comparing the loading speed of PIL, torchvision, and cv2,
        # cv2 was chosen as the default backend for loading images;
        # uncomment one of the lines below to switch backends.

# image = Image.open(os.path.join(self.root, path)).convert('RGB')
# image = torchvision.io.read_image(os.path.join(self.root, path))

        # disable OpenCV threading to avoid deadlocks with DataLoader workers
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

# image = cv2.imread(os.path.join(self.root, image_name))
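        # np.fromfile + imdecode, unlike cv2.imread, also handles file paths
        # containing non-ASCII characters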
image = cv2.imdecode(np.fromfile(os.path.join(self.root, image_name), dtype=np.uint8), -1)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
return image

def get_image_id(self, item: int):
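        # `indices` exists only when images without valid annotations were
        # filtered out (see _coco_remove_images_without_annotations)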
if hasattr(self, "indices"):
item = self.indices[item]
image_id = self.ids[item]
return image_id

def load_image_and_target(self, item: int):
image_id = self.get_image_id(item)
# load images and annotations
image_name = self.coco.loadImgs([image_id])[0]["file_name"]
image = self.load_image(image_name)
target = self.coco.loadAnns(self.coco.getAnnIds([image_id]))
target = dict(image_id=image_id, annotations=target)
image, target = self.prepare((image, target))
return image, target

def data_augmentation(self, image, target):
        # wrap image and boxes as datapoints so the transforms update them jointly
image = datapoints.Image(image)
bounding_boxes = datapoints.BoundingBox(
target["boxes"],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=image.shape[-2:],
)
labels = target["labels"]
if self._transforms is not None:
image, bounding_boxes, labels = self._transforms(image, bounding_boxes, labels)

return image.data, bounding_boxes.data, labels

def __getitem__(self, item):
image, target = self.load_image_and_target(item)
image, target["boxes"], target["labels"] = self.data_augmentation(image, target)

return deepcopy(image), deepcopy(target)

def __len__(self):
return len(self.indices) if hasattr(self, "indices") else len(self.ids)

def _coco_remove_images_without_annotations(self, cat_list=None):
def _has_only_empty_bbox(anno):
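            # bbox is [x, y, w, h]; a width or height <= 1 counts as empty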
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _count_visible_keypoints(anno):
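            # COCO keypoints are flat [x, y, v] triplets; v > 0 means labeled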
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

min_keypoints_per_image = 10

def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
            # keypoint tasks use a slightly different criterion for considering
            # an annotation valid
if "keypoints" not in anno[0]:
return True
            # for keypoint detection tasks, only consider images valid if they
            # contain at least min_keypoints_per_image visible keypoints
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False

ids = []
for ds_idx, img_id in enumerate(self.ids):
ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
anno = self.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)

self.indices = ids


class Object365Detection(CocoDetection):
def load_image_and_target(self, item: int):
image_id = self.get_image_id(item)
# load images and annotations
image_name = self.coco.loadImgs([image_id])[0]["file_name"]
        # NOTE: Objects365 file_name entries contain extra leading directories;
        # keep only the last two path components
        image_name = os.path.join(*image_name.split(os.sep)[-2:])
if self.train:
image_name = os.path.join("images/train", image_name)
else:
image_name = os.path.join("images/val", image_name)
image = self.load_image(image_name)
target = self.coco.loadAnns(self.coco.getAnnIds([image_id]))
target = dict(image_id=image_id, annotations=target)
image, target = self.prepare((image, target))
return image, target

def __getitem__(self, item):
try:
image, target = self.load_image_and_target(item)
        except Exception:
            # if a sample fails to load, fall back to the next one
            item += 1
image, target = self.load_image_and_target(item)
image, target["boxes"], target["labels"] = self.data_augmentation(image, target)

return deepcopy(image), deepcopy(target)