Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding DINO's contrastive sampling #45

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
with torch.cuda.amp.autocast(enabled=args.amp):
if need_tgt_for_training:
outputs, mask_dict = model(samples, dn_args=(targets, args.scalar, args.label_noise_scale,
args.box_noise_scale, args.num_patterns))
args.box_noise_scale, args.num_patterns, args.contrastive))
loss_dict = criterion(outputs, targets, mask_dict)
else:
outputs = model(samples)
Expand Down Expand Up @@ -82,6 +82,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
scaler.update()
else:
# original backward function
optimizer.zero_grad()
losses.backward()
if max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
Expand Down
29 changes: 23 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import random
import time
from pathlib import Path
from os import path
import os, sys
from typing import Optional


from util.logger import setup_logger

import numpy as np
Expand All @@ -24,6 +24,7 @@
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_DABDETR, build_dab_deformable_detr, build_dab_deformable_detr_deformable_encoder_only
from models import build_dab_dino_deformable_detr
from util.utils import clean_state_dict


Expand All @@ -39,6 +40,12 @@ def get_args_parser():
help="label noise ratio to flip")
parser.add_argument('--box_noise_scale', default=0.4, type=float,
help="box noise scale to shift and scale")
parser.add_argument('--contrastive', action="store_true",
help="use contrastive training.")
parser.add_argument('--use_mqs', action="store_true",
help="use mixed query selection from DINO.")
parser.add_argument('--use_lft', action="store_true",
help="use look forward twice from DINO.")

# about lr
parser.add_argument('--lr', default=1e-4, type=float,
Expand All @@ -50,14 +57,15 @@ def get_args_parser():
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--lr_drop', default=40, type=int)
parser.add_argument('--override_resumed_lr_drop', default=False, action='store_true')
parser.add_argument('--drop_lr_now', action="store_true", help="load checkpoint and drop for 12epoch setting")
parser.add_argument('--save_checkpoint_interval', default=10, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')

# Model parameters
parser.add_argument('--modelname', '-m', type=str, required=True, choices=['dn_dab_detr', 'dn_dab_deformable_detr',
'dn_dab_deformable_detr_deformable_encoder_only'])
'dn_dab_deformable_detr_deformable_encoder_only', 'dn_dab_dino_deformable_detr'])
parser.add_argument('--frozen_weights', type=str, default=None,
help="Path to the pretrained model. If set, only the mask head will be trained")

Expand Down Expand Up @@ -94,6 +102,8 @@ def get_args_parser():
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=300, type=int,
help="Number of query slots")
parser.add_argument('--num_results', default=300, type=int,
help="Number of detection results")
parser.add_argument('--pre_norm', action='store_true',
help="Using pre-norm in the Transformer blocks.")
parser.add_argument('--num_select', default=300, type=int,
Expand Down Expand Up @@ -170,7 +180,7 @@ def get_args_parser():
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--debug', action='store_true',
help="For debug only. It will perform only a few steps during trainig and val.")
parser.add_argument('--find_unused_params', action='store_true')
parser.add_argument('--find_unused_params', default=False, action='store_true')

parser.add_argument('--save_results', action='store_true',
help="For eval only. Save the outputs for all images.")
Expand All @@ -196,6 +206,8 @@ def build_model_main(args):
model, criterion, postprocessors = build_dab_deformable_detr(args)
elif args.modelname.lower() == 'dn_dab_deformable_detr_deformable_encoder_only':
model, criterion, postprocessors = build_dab_deformable_detr_deformable_encoder_only(args)
elif args.modelname.lower() == 'dn_dab_dino_deformable_detr':
model, criterion, postprocessors = build_dab_dino_deformable_detr(args)
else:
raise NotImplementedError

Expand All @@ -222,8 +234,8 @@ def main(args):
logger.info('local_rank: {}'.format(args.local_rank))
logger.info("args: " + str(args) + '\n')

if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
#if args.frozen_weights is not None:
# assert args.masks, "Frozen training is meant for segmentation only"
print(args)

device = torch.device(args.device)
Expand Down Expand Up @@ -293,7 +305,7 @@ def main(args):
model_without_ddp.detr.load_state_dict(checkpoint['model'])

output_dir = Path(args.output_dir)
if args.resume:
if args.resume and (args.resume.startswith('https') or path.exists(args.resume)):
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
Expand All @@ -303,6 +315,11 @@ def main(args):
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if args.override_resumed_lr_drop:
print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.')
lr_scheduler.step_size = args.lr_drop
lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
lr_scheduler.step(lr_scheduler.last_epoch)
args.start_epoch = checkpoint['epoch'] + 1

if args.drop_lr_now:
Expand Down
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .DN_DAB_DETR import build_DABDETR
from .dn_dab_deformable_detr import build_dab_deformable_detr
from .dn_dab_deformable_detr_deformable_encoder_only import build_dab_deformable_detr_deformable_encoder_only
from .dn_dab_dino_deformable_detr import build_dab_dino_deformable_detr
14 changes: 14 additions & 0 deletions models/dn_dab_dino_deformable_detr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ------------------------------------------------------------------------
# DAB-DETR
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DModified from eformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

from .dab_deformable_detr import build_dab_dino_deformable_detr
142 changes: 142 additions & 0 deletions models/dn_dab_dino_deformable_detr/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# ------------------------------------------------------------------------
# DAB-DETR
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from util.misc import NestedTensor, is_main_process

from .position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.

Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other models than torchvision.models.resnet[18,34,50,101]
produce nans.
"""

def __init__(self, n, eps=1e-5):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
self.eps = eps

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]

super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = self.eps
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias


class BackboneBase(nn.Module):

def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool):
super().__init__()
for name, parameter in backbone.named_parameters():
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)
if return_interm_layers:
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
self.strides = [8, 16, 32]
self.num_channels = [512, 1024, 2048]
else:
return_layers = {'layer4': "0"}
self.strides = [32]
self.num_channels = [2048]
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out


class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
norm_layer = FrozenBatchNorm2d
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=norm_layer)
assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded"
super().__init__(backbone, train_backbone, return_interm_layers)
if dilation:
self.strides[-1] = self.strides[-1] // 2


class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
self.strides = backbone.strides
self.num_channels = backbone.num_channels

def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in sorted(xs.items()):
out.append(x)

# position encoding
for x in out:
pos.append(self[1](x).to(x.tensors.dtype))

return out, pos


def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks or (args.num_feature_levels > 1)
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
return model
Loading