Commit
Showing 97 changed files with 27,739 additions and 0 deletions.
114 changes: 114 additions & 0 deletions
configs/salience_detr/salience_detr_resnet50_800_1333.py
@@ -0,0 +1,114 @@
from torch import nn
# from torchvision.ops import FrozenBatchNorm2d

from models.backbones.resnet import ResNetBackbone
from models.bricks.misc import FrozenBatchNorm2d
from models.bricks.position_encoding import PositionEmbeddingSine
from models.bricks.post_process import PostProcess
from models.bricks.salience_transformer import (
    SalienceTransformer,
    SalienceTransformerDecoder,
    SalienceTransformerDecoderLayer,
    SalienceTransformerEncoder,
    SalienceTransformerEncoderLayer,
)
from models.bricks.set_criterion import HybridSetCriterion
from models.detectors.salience_detr import SalienceCriterion, SalienceDETR
from models.matcher.hungarian_matcher import HungarianMatcher
from models.necks.channel_mapper import ChannelMapper
from models.necks.repnet import RepVGGPluXNetwork

# most commonly changed parameters
embed_dim = 256
num_classes = 91
num_queries = 900
num_feature_levels = 4
transformer_enc_layers = 6
transformer_dec_layers = 6
num_heads = 8
dim_feedforward = 2048

# instantiate model components
position_embedding = PositionEmbeddingSine(embed_dim // 2, temperature=10000, normalize=True, offset=-0.5)

backbone = ResNetBackbone(
    "resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,)
)

neck = ChannelMapper(
    in_channels=backbone.num_channels,
    out_channels=embed_dim,
    num_outs=num_feature_levels,
)

transformer = SalienceTransformer(
    encoder=SalienceTransformerEncoder(
        encoder_layer=SalienceTransformerEncoderLayer(
            embed_dim=embed_dim,
            n_heads=num_heads,
            dropout=0.0,
            activation=nn.ReLU(inplace=True),
            n_levels=num_feature_levels,
            n_points=4,
            d_ffn=dim_feedforward,
        ),
        num_layers=transformer_enc_layers,
    ),
    neck=RepVGGPluXNetwork(
        in_channels_list=neck.num_channels,
        out_channels_list=neck.num_channels,
        norm_layer=nn.BatchNorm2d,
        activation=nn.SiLU,
        groups=4,
    ),
    decoder=SalienceTransformerDecoder(
        decoder_layer=SalienceTransformerDecoderLayer(
            embed_dim=embed_dim,
            n_heads=num_heads,
            dropout=0.0,
            activation=nn.ReLU(inplace=True),
            n_levels=num_feature_levels,
            n_points=4,
            d_ffn=dim_feedforward,
        ),
        num_layers=transformer_dec_layers,
        num_classes=num_classes,
    ),
    num_classes=num_classes,
    num_feature_levels=num_feature_levels,
    two_stage_num_proposals=num_queries,
    level_filter_ratio=(0.4, 0.8, 1.0, 1.0),  # fraction of tokens kept per feature level
    layer_filter_ratio=(1.0, 0.8, 0.6, 0.6, 0.4, 0.2),  # fraction of tokens kept per encoder layer
)

matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)

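# loss weights below cover: the base losses, their denoising ("_dn") variants,
# copies for each auxiliary decoder layer ("_0" ... "_4"), encoder ("_enc")
# losses, and the salience (foreground supervision) loss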
weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} | ||
weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) | ||
weight_dict.update({ | ||
k + f"_{i}": v | ||
for i in range(transformer_dec_layers - 1) | ||
for k, v in weight_dict.items() | ||
}) | ||
weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) | ||
weight_dict.update({"loss_salience": 2}) | ||
|
||
criterion = HybridSetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0)
foreground_criterion = SalienceCriterion(noise_scale=0.0, alpha=0.25, gamma=2.0)
postprocessor = PostProcess(select_box_nums_for_evaluation=300)

# combine above components to instantiate the model
model = SalienceDETR(
    backbone=backbone,
    neck=neck,
    position_embedding=position_embedding,
    transformer=transformer,
    criterion=criterion,
    focus_criterion=foreground_criterion,
    postprocessor=postprocessor,
    num_classes=num_classes,
    num_queries=num_queries,
    aux_loss=True,
    min_size=800,
    max_size=1333,
)
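Since this config is an ordinary Python module that builds `model` at import time, it can be loaded by file path. A minimal sketch of that idea, assuming a hypothetical `load_config` helper rather than the repo's actual loader:

# Hypothetical sketch: execute a .py config and read its `model` attribute.
import importlib.util

def load_config(path, name="config"):
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # running the module instantiates the components
    return module

cfg = load_config("configs/salience_detr/salience_detr_resnet50_800_1333.py")
model = cfg.model  # the SalienceDETR instance assembled above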
47 changes: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
from torch import optim

from datasets.coco import CocoDetection
from transforms import presets
from optimizer import param_dict

# Commonly changed training configurations
num_epochs = 12  # number of training epochs
batch_size = 2  # per-GPU batch size; total_batch_size = #GPUs x batch_size
num_workers = 4  # workers for the PyTorch DataLoader
pin_memory = True  # whether to pin memory for the PyTorch DataLoader
print_freq = 50  # frequency at which to print logs
starting_epoch = 0
max_norm = 0.1  # gradient clipping norm

output_dir = None  # path to save checkpoints; if None, defaults to checkpoints/{model_name}
find_unused_parameters = False  # useful for debugging distributed training

# define the training dataset
coco_path = "data/coco"  # /PATH/TO/YOUR/COCODIR
train_transform = presets.detr  # see transforms/presets to choose a transform
train_dataset = CocoDetection(
    img_folder=f"{coco_path}/train2017",
    ann_file=f"{coco_path}/annotations/instances_train2017.json",
    transforms=train_transform,
    train=True,
)
test_dataset = CocoDetection(
    img_folder=f"{coco_path}/val2017",
    ann_file=f"{coco_path}/annotations/instances_val2017.json",
    transforms=None,  # the eval transform is integrated into the model
)

# model config to train
model_path = "configs/salience_detr/salience_detr_resnet50_800_1333.py"

# specify a checkpoint folder to resume from, or a pretrained ".pth" to finetune, for example:
# checkpoints/salience_detr_resnet50_800_1333/train/2024-03-22-09_38_50
# checkpoints/salience_detr_resnet50_800_1333/train/2024-03-22-09_38_50/best_ap.pth
resume_from_checkpoint = None

learning_rate = 1e-4  # initial learning rate
optimizer = optim.AdamW(lr=learning_rate, weight_decay=1e-4, betas=(0.9, 0.999))
lr_scheduler = optim.lr_scheduler.MultiStepLR(milestones=[10], gamma=0.1)

# this defines parameter groups with different learning rates
param_dicts = param_dict.finetune_backbone_and_linear_projection(lr=learning_rate)
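Note that `optim.AdamW` and `MultiStepLR` are written above without their required first arguments (`params` and `optimizer`). That only works if the config loader defers these calls and completes them at train time. A minimal sketch of the idea, using `functools.partial` as a stand-in for whatever mechanism the repo actually uses:

# Illustrative sketch only, not the repo's actual config machinery.
import functools
from torch import optim

# a lazy config could record each call as a partial...
optimizer_fn = functools.partial(optim.AdamW, lr=1e-4, weight_decay=1e-4, betas=(0.9, 0.999))
scheduler_fn = functools.partial(optim.lr_scheduler.MultiStepLR, milestones=[10], gamma=0.1)

# ...and the trainer later supplies the missing first arguments:
# optimizer = optimizer_fn(param_dicts)   # parameter groups built above
# lr_scheduler = scheduler_fn(optimizer)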
161 changes: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
import os

import albumentations as A
import cv2
import numpy as np
import torchvision

from transforms import v2 as T
from transforms.convert_coco_polys_to_mask import ConvertCocoPolysToMask
from util import datapoints
from util.misc import deepcopy


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(
        self,
        img_folder,
        ann_file,
        transforms=None,
        train=False,
    ):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.prepare = ConvertCocoPolysToMask()
        self._transforms = transforms
        # let transforms that carry dataset state register this dataset
        self._transforms = self.update_dataset(self._transforms)
        self.train = train

        if train:
            self._coco_remove_images_without_annotations()

    def update_dataset(self, transform):
        # recursively walk composed transforms and rebuild the composition
        if isinstance(transform, (T.Compose, A.Compose)):
            processed_transforms = []
            for trans in transform.transforms:
                trans = self.update_dataset(trans)
                processed_transforms.append(trans)
            return type(transform)(processed_transforms)
        if hasattr(transform, "update_dataset"):
            transform.update_dataset(self)
        return transform

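    # Example (hypothetical, not from the repo): a transform that needs access
    # to the dataset can expose its own update_dataset hook, which the walk
    # above calls exactly once during __init__:
    #
    #     class NeedsDataset:
    #         def update_dataset(self, dataset):
    #             self.dataset = dataset  # e.g., to sample extra images later
    #         def __call__(self, image, boxes, labels):
    #             return image, boxes, labels
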
    def load_image(self, image_name):
        # after comparing the speed of PIL, torchvision and cv2,
        # cv2 was chosen as the default backend to load images;
        # uncomment the following lines to switch among them:
        # image = Image.open(os.path.join(self.root, image_name)).convert("RGB")
        # image = torchvision.io.read_image(os.path.join(self.root, image_name))

        # avoid deadlocks between DataLoader workers and OpenCV threading
        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)

        # np.fromfile + imdecode instead of cv2.imread (e.g., to tolerate non-ASCII paths):
        # image = cv2.imread(os.path.join(self.root, image_name))
        image = cv2.imdecode(np.fromfile(os.path.join(self.root, image_name), dtype=np.uint8), -1)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)  # HWC BGR -> CHW RGB
        return image

    def get_image_id(self, item: int):
        # "indices" exists once images without annotations have been filtered out
        if hasattr(self, "indices"):
            item = self.indices[item]
        image_id = self.ids[item]
        return image_id

    def load_image_and_target(self, item: int):
        image_id = self.get_image_id(item)
        # load images and annotations
        image_name = self.coco.loadImgs([image_id])[0]["file_name"]
        image = self.load_image(image_name)
        target = self.coco.loadAnns(self.coco.getAnnIds([image_id]))
        target = dict(image_id=image_id, annotations=target)
        image, target = self.prepare((image, target))
        return image, target

    def data_augmentation(self, image, target):
        # wrap inputs into datapoints so transforms can dispatch on type
        image = datapoints.Image(image)
        bounding_boxes = datapoints.BoundingBox(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=image.shape[-2:],
        )
        labels = target["labels"]
        if self._transforms is not None:
            image, bounding_boxes, labels = self._transforms(image, bounding_boxes, labels)

        return image.data, bounding_boxes.data, labels

    def __getitem__(self, item):
        image, target = self.load_image_and_target(item)
        image, target["boxes"], target["labels"] = self.data_augmentation(image, target)

        return deepcopy(image), deepcopy(target)

    def __len__(self):
        return len(self.indices) if hasattr(self, "indices") else len(self.ids)

    def _coco_remove_images_without_annotations(self, cat_list=None):
        def _has_only_empty_bbox(anno):
            # a box counts as empty if its width or height is <= 1 pixel
            return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

        def _count_visible_keypoints(anno):
            return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

        min_keypoints_per_image = 10

        def _has_valid_annotation(anno):
            # if it's empty, there is no annotation
            if len(anno) == 0:
                return False
            # if all boxes have close to zero area, there is no annotation
            if _has_only_empty_bbox(anno):
                return False
            # keypoint tasks have a slightly different criterion for
            # considering an annotation valid
            if "keypoints" not in anno[0]:
                return True
            # for keypoint detection tasks, only consider images valid if they
            # contain at least min_keypoints_per_image visible keypoints
            if _count_visible_keypoints(anno) >= min_keypoints_per_image:
                return True
            return False

        ids = []
        for ds_idx, img_id in enumerate(self.ids):
            ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if cat_list:
                anno = [obj for obj in anno if obj["category_id"] in cat_list]
            if _has_valid_annotation(anno):
                ids.append(ds_idx)

        self.indices = ids


class Object365Detection(CocoDetection):
    def load_image_and_target(self, item: int):
        image_id = self.get_image_id(item)
        # load images and annotations
        image_name = self.coco.loadImgs([image_id])[0]["file_name"]
        # NOTE: only for Objects365, whose file_name carries extra path prefixes
        image_name = os.path.join(*image_name.split(os.sep)[-2:])
        if self.train:
            image_name = os.path.join("images/train", image_name)
        else:
            image_name = os.path.join("images/val", image_name)
        image = self.load_image(image_name)
        target = self.coco.loadAnns(self.coco.getAnnIds([image_id]))
        target = dict(image_id=image_id, annotations=target)
        image, target = self.prepare((image, target))
        return image, target

    def __getitem__(self, item):
        try:
            image, target = self.load_image_and_target(item)
        except Exception:
            # fall back to the next sample if this one fails to load
            item = (item + 1) % len(self)
            image, target = self.load_image_and_target(item)
        image, target["boxes"], target["labels"] = self.data_augmentation(image, target)

        return deepcopy(image), deepcopy(target)
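As a quick sanity check, the dataset plugs into a standard PyTorch DataLoader. A hedged usage sketch; the paths and `collate_fn` here are assumptions, since the repo's training loop may batch variable-sized images differently:

# Hypothetical usage sketch; paths and collate_fn are illustrative assumptions.
from torch.utils.data import DataLoader

dataset = CocoDetection(
    img_folder="data/coco/train2017",
    ann_file="data/coco/annotations/instances_train2017.json",
    transforms=None,
    train=True,  # also filters out images without valid annotations
)

def collate_fn(batch):
    # keep variable-sized images as tuples instead of stacking into one tensor
    return tuple(zip(*batch))

loader = DataLoader(dataset, batch_size=2, num_workers=4, collate_fn=collate_fn)
images, targets = next(iter(loader))  # images: tuple of CHW arrays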