[Refactor]: Refactor DETR and Deformable DETR #8763

Merged: 134 commits, Oct 20, 2022
Changes from 5 commits

Commits (134 total)
e3c0b69
[Fix] Fix UT to be compatible with pytorch 1.6 (#8707)
jbwang1997 Sep 5, 2022
c9b4782
[Refactor] Refactor anchor head and base head with boxlist (#8625)
jbwang1997 Sep 9, 2022
73eae86
fix: fix config of detr-r18
Li-Qingyun Sep 10, 2022
09d52c9
fix: modified import of MSDeformAttn in PixelDecoder of Mask2Former
Li-Qingyun Sep 10, 2022
0e5b9c5
feat: add TransformerDetector as the base detector of DETR-like detec…
Li-Qingyun Sep 10, 2022
875f2f7
refactor: refactor modules and configs of DETR
Li-Qingyun Sep 10, 2022
0e52589
refactor: refactor DETR-related modules in transformer.py
Li-Qingyun Sep 10, 2022
8714d01
refactor: refactor DETR-related modules in transformer.py
Li-Qingyun Sep 10, 2022
c087022
fix: add type comments in detr.py
Li-Qingyun Sep 11, 2022
8288978
correct trainloop in detr_r50 config
KeiChiTse Sep 12, 2022
8e2cc9c
fix: modify the parent class of DETRHead to BaseModule
Li-Qingyun Sep 16, 2022
0985043
Merge remote-tracking branch 'origin/refactor-detr-3.x' into refactor…
Li-Qingyun Sep 16, 2022
9d6bdbd
refactor: refactor modules and configs of Deformable DETR
Li-Qingyun Sep 16, 2022
2fa2680
fix: modify the usage of num_query
Li-Qingyun Sep 16, 2022
51799be
fix: modify the usage of num_query in configs
Li-Qingyun Sep 16, 2022
542d65b
refactor: replace input_proj of detr with ChannelMapper neck
Li-Qingyun Sep 19, 2022
57dd319
refactor: delete multi_apply in DETRHead.forward()
Li-Qingyun Sep 19, 2022
fd90a64
Update detr_r18_8xb2-500e_coco.py
LYMDLUT Sep 19, 2022
2b9e99a
change the name of detection_transfomer.py to base_detr.py
LYM-fire Sep 19, 2022
ce68982
refactor: modify construct binary masks section of forward_pretransfo…
Li-Qingyun Sep 19, 2022
fbb2998
Merge branch 'refactor-detr-3.x' of github.com:Li-Qingyun/mmdetection…
Li-Qingyun Sep 19, 2022
a952396
Merge remote-tracking branch 'origin/refactor-detr-3.x' into refactor…
LYM-fire Sep 19, 2022
2194168
refactor: utilize abstractmethod
Li-Qingyun Sep 19, 2022
70ba3e2
Merge branch 'refactor-detr-3.x' of github.com:Li-Qingyun/mmdetection…
Li-Qingyun Sep 19, 2022
6fe98ed
update ABCmeta to make sure reload class TransformerDetector
LYM-fire Sep 19, 2022
45e7675
some annotation
LYM-fire Sep 19, 2022
2b8ca69
some annotation
LYM-fire Sep 19, 2022
fa38351
some annotation
LYM-fire Sep 19, 2022
65edfa9
refactor: delete _init_transformer in detectors
Li-Qingyun Sep 19, 2022
3156ba6
refactor: modify args of deformable detr
Li-Qingyun Sep 19, 2022
e1bbdd2
refactor: modify about super().__init__()
Li-Qingyun Sep 19, 2022
20e9821
Update detr_head.py
KeiChiTse Sep 19, 2022
e386e89
Update detr.py
KeiChiTse Sep 19, 2022
c65cf2b
some annotation for head
LYM-fire Sep 19, 2022
45295d4
to make sure the head args the same as detector
LYM-fire Sep 19, 2022
73d47c0
to make sure the head args the same as detector
LYM-fire Sep 19, 2022
d0b5aeb
some bug
LYM-fire Sep 19, 2022
2edd0c9
fix: fix bugs of num_pred in DeformableDETRHead
Li-Qingyun Sep 20, 2022
a265c1a
add kwargs to transformer
LYM-fire Sep 20, 2022
57b9869
support MLP and sineembed position
LYM-fire Sep 20, 2022
63f634a
detele positional encodeing
LYM-fire Sep 20, 2022
fe37bd0
delete useless postnorm
LYM-fire Sep 20, 2022
aa290a6
Revert "add kwargs to transformer"
LYM-fire Sep 20, 2022
b31c011
Update detr_head.py
KeiChiTse Sep 20, 2022
af601d6
Update detr_head.py
KeiChiTse Sep 20, 2022
abcb217
Update base_detr.py
KeiChiTse Sep 20, 2022
63c6b93
Update deformable_detr.py
KeiChiTse Sep 20, 2022
3581a00
to support conditional detr with reload forward_transformer
LYM-fire Sep 20, 2022
10a7075
fix: update config files of Two-stage and Box-refine
Li-Qingyun Sep 21, 2022
6d0051e
replace all bs with batch_size in detr-related files
Li-Qingyun Sep 21, 2022
eb3a9d1
update deformable.py and transformer.py
Li-Qingyun Sep 21, 2022
b09688c
update docstring in base_detr
KeiChiTse Sep 21, 2022
8dfc9c1
update docstring in base_detr, detr
KeiChiTse Sep 21, 2022
b69da4f
doc refine
LYM-fire Sep 21, 2022
99d0a5a
Revert "doc refine"
LYM-fire Sep 21, 2022
d692f69
doc refine
LYM-fire Sep 21, 2022
4fae213
doc refine
LYM-fire Sep 21, 2022
9167465
updabase_detr, detr, and le layers/transformdoc
KeiChiTse Sep 21, 2022
e8c8ea0
update doc in transformer
KeiChiTse Sep 21, 2022
4edceca
fix doc in base_detr
KeiChiTse Sep 21, 2022
56d01a0
add origin repo link
LYM-fire Sep 22, 2022
92fbaef
add origin repo link
LYM-fire Sep 22, 2022
a866e88
refine doc
LYM-fire Sep 22, 2022
702c915
refine doc
LYM-fire Sep 22, 2022
0172ba7
refine doc
LYM-fire Sep 22, 2022
7652c68
refine doc
LYM-fire Sep 22, 2022
db04c47
refine doc
LYM-fire Sep 22, 2022
48cfe23
refine doc
LYM-fire Sep 22, 2022
18990e9
refine doc
LYM-fire Sep 22, 2022
dacbdf5
refine doc
LYM-fire Sep 22, 2022
ec04bb9
doc: add doc of the first edition of Deformable DETR
Li-Qingyun Sep 22, 2022
3e6f7b8
Merge branch 'refactor-detr-3.x' of github.com:Li-Qingyun/mmdetection…
Li-Qingyun Sep 22, 2022
b2cf331
batch_size to bs
LYM-fire Sep 22, 2022
4687bc3
refine doc
LYM-fire Sep 22, 2022
59d7eb0
refine doc
LYM-fire Sep 22, 2022
bb294c6
feat: add config comments of specific module
Li-Qingyun Sep 27, 2022
9271b44
refactor: refactor base DETR class TransformerDetector
Li-Qingyun Sep 29, 2022
ad70c44
fix: fix wrong return typehint of forward_encoder in TransformerDetector
Li-Qingyun Sep 29, 2022
b31301c
refactor: refactor DETR
Li-Qingyun Sep 29, 2022
95eff3d
refactor: refactor Deformable DETR
Li-Qingyun Sep 30, 2022
f9d7d2b
refactor: refactor forward_encoder and pre_decoder
Li-Qingyun Oct 1, 2022
ad7c03f
fix: fix bugs of new edition
Li-Qingyun Oct 1, 2022
13dde26
refactor: small modifications
Li-Qingyun Oct 6, 2022
8685649
fix: move get_reference_points to deformable_encoder
Li-Qingyun Oct 6, 2022
69b0eb0
refactor: merge init_&inter_reference to references in Deformable DETR
Li-Qingyun Oct 6, 2022
015d45b
modify docstring of get_valid_ratio in Deformable DETR
Li-Qingyun Oct 6, 2022
becb862
add some docstring
Li-Qingyun Oct 6, 2022
670052b
doc: add docstring of deformable_detr.py
Li-Qingyun Oct 7, 2022
3027c5c
doc: add docstring of deformable_detr_head.py
Li-Qingyun Oct 8, 2022
9a2b801
doc: modify docstring of deformable detr
Li-Qingyun Oct 8, 2022
0d94371
doc: add docstring of deformable_detr_head.py
Li-Qingyun Oct 8, 2022
ceb59b4
doc: modify docstring of deformable detr
Li-Qingyun Oct 8, 2022
4e301ea
Merge branch 'refactor-detr-3.x' of github.com:Li-Qingyun/mmdetection…
Li-Qingyun Oct 8, 2022
bd86b2f
doc: add docstring of base_detr.py
Li-Qingyun Oct 8, 2022
01ade61
doc: refine docstring of base_detr.py
Li-Qingyun Oct 8, 2022
1c65fd8
doc: refine docstring of base_detr.py
Li-Qingyun Oct 8, 2022
baada0d
a little change of MLP
LYM-fire Oct 8, 2022
ac1c3cb
a little change of MLP
LYM-fire Oct 8, 2022
5e9bb65
a little change of MLP
LYM-fire Oct 8, 2022
f06de57
a little change of MLP
LYM-fire Oct 8, 2022
c3bfa32
refine config
LYM-fire Oct 8, 2022
9459640
refine config
LYM-fire Oct 8, 2022
986fe6a
refine config
LYM-fire Oct 8, 2022
f3b2a62
refine doc string for detr
LYM-fire Oct 8, 2022
131da21
little refine doc string for detr.py
LYM-fire Oct 8, 2022
9070abb
tiny modification
Li-Qingyun Oct 9, 2022
55612e6
doc: refine docstring of detr.py
Li-Qingyun Oct 9, 2022
70acfdc
tiny modifications to resolve the conversations
Li-Qingyun Oct 10, 2022
1251f5c
DETRHead.predict() draft
Li-Qingyun Oct 11, 2022
3e1448c
tiny modifications to resolve conversations
Li-Qingyun Oct 11, 2022
4ddc88d
refactor: modify arg names and forward strategies of bbox_head
Li-Qingyun Oct 11, 2022
3c97233
tiny modifications to resolve the conversations
Li-Qingyun Oct 12, 2022
1f43bcf
support MLP
LYM-fire Oct 14, 2022
9a6fc3f
fix docsting of function pre_decoder
KeiChiTse Oct 14, 2022
6d89bed
fix docsting of function pre_decoder
KeiChiTse Oct 14, 2022
b0ec1c3
fix docstring
KeiChiTse Oct 15, 2022
759db62
modifications for resolving conversations
Li-Qingyun Oct 15, 2022
9ed486e
Merge branch 'refactor-detr-3.x' of https://github.com/Li-Qingyun/mmd…
Li-Qingyun Oct 15, 2022
8324244
refactor: eradicate key_padding_mask args
Li-Qingyun Oct 15, 2022
395c18e
refactor: eradicate key_padding_mask args
Li-Qingyun Oct 15, 2022
00dd7b1
Merge branch 'refactor-detr-3.x' of github.com:Li-Qingyun/mmdetection…
Li-Qingyun Oct 15, 2022
5de00a3
fix: fix bug of deformable detr and resolve some conversations
Li-Qingyun Oct 18, 2022
510e017
refactor: rename base class with DetectionTransformer and other modif…
Li-Qingyun Oct 19, 2022
c9ef49b
fix: fix config of detr
Li-Qingyun Oct 19, 2022
1ddab86
fix the bug of init
LYM-fire Oct 19, 2022
7c8a374
fix: fix init_weight of DETR and Deformable DETR
Li-Qingyun Oct 19, 2022
126d427
Merge branch 'refactor-detr-3.x' of github.com:Li-Qingyun/mmdetection…
Li-Qingyun Oct 19, 2022
4bfa8ad
Merge branch 'refactor-detr' of github.com:open-mmlab/mmdetection int…
Li-Qingyun Oct 19, 2022
1c076bd
resolve conflict
Li-Qingyun Oct 19, 2022
968a14d
fix auto-merge bug
Li-Qingyun Oct 19, 2022
9446409
fix pre-commit bug
Li-Qingyun Oct 19, 2022
314cfcb
refactor: move the position of encoder and decoder
Li-Qingyun Oct 20, 2022
afa84e2
delete Transformer in ci test
LYM-fire Oct 20, 2022
7c9abef
delete Transformer in ci test
LYM-fire Oct 20, 2022
14 changes: 9 additions & 5 deletions .circleci/test.yml
@@ -66,10 +66,13 @@ jobs:
- run: pip install "protobuf <= 3.20.1" && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake
- run:
name: Install mmdet dependencies
# numpy may be downgraded after building pycocotools, which causes `ImportError: numpy.core.multiarray failed to import`
# force reinstall pycocotools to ensure pycocotools is built under the current numpy
command: |
python -m pip install git+ssh://git@github.com/open-mmlab/mmengine.git@main
python -m pip install << parameters.mmcv >>
pip install -r requirements/tests.txt -r requirements/optional.txt
pip install --force-reinstall pycocotools
pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
pip install git+https://github.com/cocodataset/panopticapi.git
- run:
@@ -111,17 +114,18 @@ jobs:
command: |
docker build .circleci/docker -t mmdetection:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
docker run --gpus all -t -d -v /home/circleci/project:/mmdetection -v /home/circleci/mmengine:/mmengine -w /mmdetection --name mmdetection mmdetection:gpu
docker exec mmdetection apt-get install -y git
- run:
name: Install mmdet dependencies
# pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html
command: |
docker exec mmdetection pip install -e /mmengine
docker exec mmdetection pip install << parameters.mmcv >>
pip install -r requirements/tests.txt -r requirements/optional.txt
pip install pycocotools
pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
pip install git+https://github.com/cocodataset/panopticapi.git
python -c 'import mmcv; print(mmcv.__version__)'
docker exec mmdetection pip install -r requirements/tests.txt -r requirements/optional.txt
docker exec mmdetection pip install pycocotools
docker exec mmdetection pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
docker exec mmdetection pip install git+https://github.com/cocodataset/panopticapi.git
docker exec mmdetection python -c 'import mmcv; print(mmcv.__version__)'
- run:
name: Build and install
command: |
2 changes: 1 addition & 1 deletion configs/detr/detr_r18_8xb2-500e_coco.py
@@ -4,4 +4,4 @@
backbone=dict(
depth=18,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
neck=dict(in_channels=[64, 128, 256, 512]))
bbox_head=dict(in_channels=512))
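The hunk starts at line 4 of the config, so the lines above it are not shown. For context, a minimal sketch of what the full R-18 config plausibly looks like after this change (the _base_ path is an assumption, not visible in this diff):

# Hypothetical full content of configs/detr/detr_r18_8xb2-500e_coco.py.
# The R-18 variant inherits everything from the R-50 config and overrides
# only the backbone depth, its pretrained checkpoint, and the head input
# channels (ResNet-18's last stage outputs 512 channels, not 2048).
_base_ = './detr_r50_8xb2-500e_coco.py'  # assumed base file
model = dict(
    backbone=dict(
        depth=18,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
    bbox_head=dict(in_channels=512))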
2 changes: 1 addition & 1 deletion mmdet/datasets/transforms/transforms.py
@@ -1136,7 +1136,7 @@ def transform(self, results: dict) -> dict:
if patch[2] == patch[0] or patch[3] == patch[1]:
continue
overlaps = boxes.overlaps(
HorizontalBoxes(patch.reshape(-1, 4)),
HorizontalBoxes(patch.reshape(-1, 4).astype(np.float32)),
boxes).numpy().reshape(-1)
if len(overlaps) > 0 and overlaps.min() < min_iou:
continue
44 changes: 29 additions & 15 deletions mmdet/models/dense_heads/anchor_head.py
@@ -8,12 +8,14 @@
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import BaseBoxes
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, OptMultiConfig)
from ..task_modules.prior_generators import (AnchorGenerator,
anchor_inside_flags)
from ..task_modules.samplers import PseudoSampler
from ..utils import images_to_levels, multi_apply, unmap
from ..utils import (cat_boxes, get_box_tensor, images_to_levels, multi_apply,
unmap)
from .base_dense_head import BaseDenseHead


@@ -120,8 +122,9 @@ def _init_layers(self) -> None:
self.conv_cls = nn.Conv2d(self.in_channels,
self.num_base_priors * self.cls_out_channels,
1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_base_priors * 4,
1)
reg_dim = self.bbox_coder.encode_size
self.conv_reg = nn.Conv2d(self.in_channels,
self.num_base_priors * reg_dim, 1)

def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Forward feature of a single scale level.
@@ -197,7 +200,7 @@ def get_anchors(self,
return anchor_list, valid_flag_list

def _get_targets_single(self,
flat_anchors: Tensor,
flat_anchors: Union[Tensor, BaseBoxes],
valid_flags: Tensor,
gt_instances: InstanceData,
img_meta: dict,
@@ -207,8 +210,9 @@ def _get_targets_single(self,
single image.

Args:
flat_anchors (Tensor): Multi-level anchors of the image, which are
concatenated into a single tensor of shape (num_anchors, 4)
flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors
of the image, which are concatenated into a single tensor
or boxlist of shape (num_anchors, 4)
valid_flags (Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors, ).
@@ -243,7 +247,7 @@ def _get_targets_single(self,
'check the image size and anchor sizes, or set '
'``allowed_border`` to -1 to skip the condition.')
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
anchors = flat_anchors[inside_flags]

pred_instances = InstanceData(priors=anchors)
assign_result = self.assigner.assign(pred_instances, gt_instances,
@@ -254,8 +258,10 @@ def _get_targets_single(self,
gt_instances)

num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox \
else self.bbox_coder.encode_size
bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim)
bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim)

# TODO: Considering saving memory, is it necessary to be long?
labels = anchors.new_full((num_valid_anchors, ),
@@ -265,12 +271,16 @@ def _get_targets_single(self,

pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
# `bbox_coder.encode` accepts tensor or boxlist inputs and generates
# tensor targets. If regressing decoded boxes, the code will convert
# boxlist `pos_bbox_targets` to tensor.
if len(pos_inds) > 0:
if not self.reg_decoded_bbox:
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
else:
pos_bbox_targets = sampling_result.pos_gt_bboxes
pos_bbox_targets = get_box_tensor(pos_bbox_targets)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0

@@ -362,7 +372,7 @@ def get_targets(self,
concat_valid_flag_list = []
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
concat_anchor_list.append(torch.cat(anchor_list[i]))
concat_anchor_list.append(cat_boxes(anchor_list[i]))
concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

# compute targets for each image
@@ -438,15 +448,19 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=avg_factor)
# regression loss
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
target_dim = bbox_targets.size(-1)
bbox_targets = bbox_targets.reshape(-1, target_dim)
bbox_weights = bbox_weights.reshape(-1, target_dim)
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1,
self.bbox_coder.encode_size)
if self.reg_decoded_bbox:
# When the regression loss (e.g. `IouLoss`, `GIouLoss`)
# is applied directly on the decoded bounding boxes, it
# decodes the already encoded coordinates to absolute format.
anchors = anchors.reshape(-1, 4)
anchors = anchors.reshape(-1, anchors.size(-1))
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
bbox_pred = get_box_tensor(bbox_pred)
loss_bbox = self.loss_bbox(
bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
return loss_cls, loss_bbox
@@ -500,7 +514,7 @@ def loss_by_feat(
# concat all level anchors and flags to a single tensor
concat_anchor_list = []
for i in range(len(anchor_list)):
concat_anchor_list.append(torch.cat(anchor_list[i]))
concat_anchor_list.append(cat_boxes(anchor_list[i]))
all_anchor_list = images_to_levels(concat_anchor_list,
num_level_anchors)

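The anchor_head.py changes above lean on box-type-agnostic helpers so that anchors can be either plain Tensors or box objects (e.g. HorizontalBoxes). A simplified sketch of the intended semantics, not the actual mmdet implementations:

import torch

def get_box_tensor_sketch(boxes):
    # Unwrap a box object to its raw tensor; plain tensors pass through.
    return boxes.tensor if hasattr(boxes, 'tensor') else boxes

def cat_boxes_sketch(box_list, dim=0):
    # Concatenate boxes of one type along `dim`, preserving the box type
    # when the inputs are box objects rather than plain tensors.
    if hasattr(box_list[0], 'cat'):
        return type(box_list[0]).cat(box_list, dim=dim)
    return torch.cat(box_list, dim=dim)

This is also why the target and loss code now reads target_dim and encode_size instead of hard-coding 4: the regression dimension follows the box coder and the box type rather than being assumed to be (x1, y1, x2, y2).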
19 changes: 10 additions & 9 deletions mmdet/models/dense_heads/base_dense_head.py
@@ -14,7 +14,8 @@
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptMultiConfig
from ..test_time_augs import merge_aug_results
from ..utils import (filter_scores_and_topk, select_single_mlvl,
from ..utils import (cat_boxes, filter_scores_and_topk, get_box_tensor,
get_box_wh, scale_boxes, select_single_mlvl,
unpack_gt_instances)


@@ -360,7 +361,8 @@ def _predict_by_feat_single(self,

assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
dim = self.bbox_coder.encode_size
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
if with_score_factors:
score_factor = score_factor.permute(1, 2,
0).reshape(-1).sigmoid()
@@ -401,7 +403,7 @@ def _predict_by_feat_single(self,
mlvl_score_factors.append(score_factor)

bbox_pred = torch.cat(mlvl_bbox_preds)
priors = torch.cat(mlvl_valid_priors)
priors = cat_boxes(mlvl_valid_priors)
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

results = InstanceData()
@@ -452,11 +454,10 @@ def _bbox_post_process(self,
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""

if rescale:
assert img_meta.get('scale_factor') is not None
results.bboxes /= results.bboxes.new_tensor(
img_meta['scale_factor']).repeat((1, 2))
scale_factor = [1 / s for s in img_meta['scale_factor']]
results.bboxes = scale_boxes(results.bboxes, scale_factor)

if hasattr(results, 'score_factors'):
# TODO: Add sqrt operation in order to be consistent with
@@ -466,15 +467,15 @@ def _bbox_post_process(self,

# filter small size bboxes
if cfg.get('min_bbox_size', -1) >= 0:
w = results.bboxes[:, 2] - results.bboxes[:, 0]
h = results.bboxes[:, 3] - results.bboxes[:, 1]
w, h = get_box_wh(results.bboxes)
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
if not valid_mask.all():
results = results[valid_mask]

# TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
if with_nms and results.bboxes.numel() > 0:
det_bboxes, keep_idxs = batched_nms(results.bboxes, results.scores,
bboxes = get_box_tensor(results.bboxes)
det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
results.labels, cfg.nms)
results = results[keep_idxs]
# some nms would reweight the score, such as softnms
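The _bbox_post_process changes follow the same pattern: rescaling and size filtering go through helpers that work for both plain tensors and box objects. A rough illustration of the assumed behavior for plain (x1, y1, x2, y2) tensors:

import torch
from torch import Tensor

def scale_boxes_sketch(boxes: Tensor, scale_factor) -> Tensor:
    # Multiply (x1, y1, x2, y2) by (w_scale, h_scale, w_scale, h_scale).
    # _bbox_post_process passes the inverse scale factor, so predictions
    # are mapped back to the original image resolution.
    repeat_num = boxes.size(-1) // 2
    return boxes * boxes.new_tensor(scale_factor).repeat(repeat_num)

def get_box_wh_sketch(boxes: Tensor):
    # Width and height of horizontal boxes, used by the min_bbox_size filter.
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]
    return w, h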
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/retina_head.py
@@ -88,8 +88,9 @@ def _init_layers(self):
self.num_base_priors * self.cls_out_channels,
3,
padding=1)
reg_dim = self.bbox_coder.encode_size
self.retina_reg = nn.Conv2d(
self.feat_channels, self.num_base_priors * 4, 3, padding=1)
self.feat_channels, self.num_base_priors * reg_dim, 3, padding=1)

def forward_single(self, x):
"""Forward feature of a single scale level.
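Reading the regression dimension from the coder, rather than hard-coding 4, lets the same head drive coders whose encoding is not a 4-value delta. A toy illustration of the dependency (hypothetical coder, not from mmdet):

import torch.nn as nn

class ToyBoxCoder:
    # A coder advertises how many values it produces per box; the head
    # sizes its regression branch from this instead of assuming 4.
    encode_size = 4

coder = ToyBoxCoder()
num_base_priors, feat_channels = 9, 256
retina_reg = nn.Conv2d(
    feat_channels, num_base_priors * coder.encode_size, 3, padding=1)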
2 changes: 2 additions & 0 deletions mmdet/models/detectors/base.py
@@ -8,6 +8,7 @@

from mmdet.structures import DetDataSample, OptSampleList, SampleList
from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
from ..utils import samplelist_boxlist2tensor

ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample],
Tuple[torch.Tensor], torch.Tensor]
@@ -151,4 +152,5 @@ def add_pred_to_datasample(self, data_samples: SampleList,
"""
for data_sample, pred_instances in zip(data_samples, results_list):
data_sample.pred_instances = pred_instances
samplelist_boxlist2tensor(data_samples)
return data_samples
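samplelist_boxlist2tensor presumably converts any box-object fields left in the predictions back to plain tensors before the samples are returned, so downstream code keeps seeing ordinary Tensors. A rough sketch of that idea (assumed behavior, not the actual helper):

def samplelist_boxlist2tensor_sketch(batch_data_samples):
    # Replace box objects (which wrap a `.tensor`) in pred_instances with
    # their raw tensors.
    for data_sample in batch_data_samples:
        if hasattr(data_sample, 'pred_instances'):
            bboxes = data_sample.pred_instances.bboxes
            if hasattr(bboxes, 'tensor'):
                data_sample.pred_instances.bboxes = bboxes.tensor
    return batch_data_samples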
129 changes: 129 additions & 0 deletions mmdet/models/detectors/detection_transformer.py
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple, Union

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .base import BaseDetector


@MODELS.register_module()
class TransformerDetector(BaseDetector):
"""Base class for Transformer-based detectors.

Transformer-based detectors use an encoder to process output features of
the backbone (+ neck) and a decoder that pools features into a set of
learnable queries. Each query predicts a bounding box.
"""

def __init__(self,
backbone: ConfigType,
neck: OptConfigType = None,
encoder_cfg: OptConfigType = None,
decoder_cfg: OptConfigType = None,
positional_encoding_cfg: OptConfigType = None,
bbox_head: OptConfigType = None,
num_query: int = 100,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
# process args
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.encoder_cfg = encoder_cfg
self.decoder_cfg = decoder_cfg
self.positional_encoding_cfg = positional_encoding_cfg
self.num_query = num_query

# init model layers
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
self.bbox_head = MODELS.build(bbox_head)
self._init_layers()

def _init_layers(self) -> None:
self._init_transformer()

def _init_transformer(self) -> None:
"""1. Initialize positional_encoding
2. Initialize encoder and decoder of transformer
3. Get self.embed_dims from the transformer
4. Initialize query_embed"""
raise NotImplementedError(
'The _init_transformer should be implemented for the detector.')

# def init_weight # TODO !!!!
# def _load_from_state_dict # TODO !!!!

def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Union[dict, list]:
img_feats = self.extract_feat(batch_inputs)
seq_feats = self.forward_pretransformer(img_feats, batch_data_samples)
outs_dec = self.forward_transformer(
**seq_feats, query_embed=self.query_embedding.weight)
losses = self.bbox_head.loss(outs_dec, batch_data_samples)
return losses

def predict(self,
batch_inputs: Tensor,
batch_data_samples: SampleList,
rescale: bool = True) -> SampleList:
img_feats = self.extract_feat(batch_inputs)
seq_feats = self.forward_pretransformer(img_feats, batch_data_samples)
outs_dec = self.forward_transformer(
**seq_feats, query_embed=self.query_embedding.weight)
results_list = self.bbox_head.predict(
outs_dec, batch_data_samples, rescale=rescale)
batch_data_samples = self.add_pred_to_datasample(
batch_data_samples, results_list)
return batch_data_samples

def _forward(
self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
img_feats = self.extract_feat(batch_inputs)
seq_feats = self.forward_pretransformer(img_feats, batch_data_samples)
outs_dec = self.forward_transformer(
**seq_feats, query_embed=self.query_embedding.weight)
results = self.bbox_head.forward(outs_dec)
return results

def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
"""Extract features.

Args:
batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).

Returns:
tuple[Tensor]: Multi-level features that may have
different resolutions.
"""
x = self.backbone(batch_inputs)
if self.with_neck:
x = self.neck(x)
return x

def forward_pretransformer(
self,
img_feats: Tuple[Tensor],
batch_data_samples: OptSampleList = None) -> Dict[str, Tensor]:
"""1. Get batch padding mask.
2. Convert image feature maps to sequential features.
3. Get image positional embedding of features."""
raise NotImplementedError(
'The forward_pretransformer should be implemented '
'for the detector.')

def forward_transformer(self, **kwargs) -> Tuple[Tensor]:
"""Process sequential features with transformer."""
raise NotImplementedError(
'The forward_transformer should be implemented for the detector.')
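loss(), predict(), and _forward() above all share one pipeline: extract_feat -> forward_pretransformer -> forward_transformer -> bbox_head. A minimal sketch of what a concrete subclass has to provide (names, channel counts, and the single-level flattening are illustrative assumptions, not the refactored DETR itself; later commits in this PR rename the file to base_detr.py and the class to DetectionTransformer):

import torch.nn as nn
from mmdet.models.detectors.detection_transformer import TransformerDetector

class ToyDETR(TransformerDetector):
    """Hypothetical subclass illustrating the three abstract hooks."""

    def _init_transformer(self) -> None:
        # A real detector builds the positional encoding, encoder, and
        # decoder from self.positional_encoding_cfg / encoder_cfg /
        # decoder_cfg; here we only create the pieces the base class uses.
        self.embed_dims = 256
        self.input_proj = nn.Conv2d(2048, self.embed_dims, kernel_size=1)
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)

    def forward_pretransformer(self, img_feats, batch_data_samples=None):
        # Project the last feature map and flatten it to a (bs, h*w, c)
        # sequence; a real implementation also builds the padding mask and
        # positional embedding from batch_data_samples here.
        feat = self.input_proj(img_feats[-1])
        return dict(feat=feat.flatten(2).permute(0, 2, 1))

    def forward_transformer(self, feat, query_embed):
        # Placeholder for encoder/decoder processing; returns "decoder
        # outputs" of shape (num_decoder_layers, bs, num_query, c) for the
        # bbox_head.
        query = query_embed.unsqueeze(0).expand(feat.size(0), -1, -1)
        return query.unsqueeze(0)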