diff --git a/configs/distill/fgd/README.md b/configs/distill/fgd/README.md new file mode 100644 index 000000000..303f49f76 --- /dev/null +++ b/configs/distill/fgd/README.md @@ -0,0 +1,24 @@ +# FGD +> [Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837) + + +## Abstract + +Knowledge distillation has been applied to image classification successfully. However, object detection is much more sophisticated and most knowledge distillation methods have failed on it. In this paper, we point out that in object detection, the features of the teacher and student vary greatly in different areas, especially in the foreground and background. If we distill them equally, the uneven differences between feature maps will negatively affect the distillation. Thus, we propose Focal and Global Distillation (FGD). Focal distillation separates the foreground and background, forcing the student to focus on the teacher's critical pixels and channels. Global distillation rebuilds the relation between different pixels and transfers it from teachers to students, compensating for missing global information in focal distillation. As our method only needs to calculate the loss on the feature map, FGD can be applied to various detectors. We experiment on various detectors with different backbones and the results show that the student detector achieves excellent mAP improvement. For example, ResNet-50 based RetinaNet, Faster RCNN, RepPoints and Mask RCNN with our distillation method achieve 40.7%, 42.0%, 42.0% and 42.1% mAP on COCO2017, which are 3.3, 3.6, 3.4 and 2.9 higher than the baseline, respectively. + + +![pipeline](/docs/en/imgs/model_zoo/fgd/pipeline.png) + + + + +## Citation + +```latex +@article{yang2021focal, + title={Focal and Global Knowledge Distillation for Detectors}, + author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun}, + journal={arXiv preprint arXiv:2111.11837}, + year={2021} +} +``` diff --git a/configs/distill/fgd/fgd_gfl_r101_distill_gfl_r50_fpn_1x_coco.py b/configs/distill/fgd/fgd_gfl_r101_distill_gfl_r50_fpn_1x_coco.py new file mode 100644 index 000000000..10e99f5e3 --- /dev/null +++ b/configs/distill/fgd/fgd_gfl_r101_distill_gfl_r50_fpn_1x_coco.py @@ -0,0 +1,209 @@ +_base_ = [ + '../../_base_/datasets/mmdet/coco_detection.py', + '../../_base_/schedules/mmdet/schedule_1x.py', + '../../_base_/mmdet_runtime.py' +] + +# model settings +t_weight = 'https://download.openmmlab.com/mmdetection/v2.0/' + \ + 'gfl/gfl_r101_fpn_mstrain_2x_coco/' + \ + 'gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' +student = dict( + type='mmdet.GFL', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5, + init_cfg=dict(type='Pretrained', prefix='neck', checkpoint=t_weight)), + bbox_head=dict( + type='GFLHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25), + reg_max=16, + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + init_cfg=dict( + type='Pretrained', prefix='bbox_head', checkpoint=t_weight)), + # training and testing settings + train_cfg=dict( + assigner=dict(type='ATSSAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +teacher = dict( + type='mmdet.GFL', + init_cfg=dict(type='Pretrained', checkpoint=t_weight), + backbone=dict( + type='ResNet', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=None), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='GFLHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25), + reg_max=16, + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + # training and testing settings + train_cfg=dict( + assigner=dict(type='ATSSAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# algorithm setting +temp = 0.5 +alpha_fgd = 0.001 +beta_fgd = 0.0005 +gamma_fgd = 0.0005 +lambda_fgd = 0.000005 +algorithm = dict( + type='GeneralDistill', + architecture=dict( + type='MMDetArchitecture', + model=student, + ), + distiller=dict( + type='SingleTeacherDistiller', + teacher=teacher, + teacher_trainable=False, + components=[ + dict( + student_module='neck.fpn_convs.0.conv', + teacher_module='neck.fpn_convs.0.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_0', + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.1.conv', + teacher_module='neck.fpn_convs.1.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_1', + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.2.conv', + teacher_module='neck.fpn_convs.2.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_2', + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.3.conv', + teacher_module='neck.fpn_convs.3.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_3', + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + dict( + student_module='neck.fpn_convs.4.conv', + teacher_module='neck.fpn_convs.4.conv', + losses=[ + dict( + type='FGDLoss', + name='loss_fgd_4', + alpha_fgd=alpha_fgd, + beta_fgd=beta_fgd, + gamma_fgd=gamma_fgd, + lambda_fgd=lambda_fgd, + ) + ]), + ]), +) + +find_unused_parameters = True + +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict( + _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) diff --git a/docs/en/imgs/model_zoo/fgd/pipeline.png b/docs/en/imgs/model_zoo/fgd/pipeline.png new file mode 100644 index 000000000..4b3f38396 Binary files /dev/null and b/docs/en/imgs/model_zoo/fgd/pipeline.png differ diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index c161c5684..607aa8484 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cwd import ChannelWiseDivergence +from .fgd import FGDLoss from .kl_divergence import KLDivergence from .weighted_soft_label_distillation import WSLD -__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD'] +__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD', 'FGDLoss'] diff --git a/mmrazor/models/losses/fgd.py b/mmrazor/models/losses/fgd.py new file mode 100644 index 000000000..afd51ca61 --- /dev/null +++ b/mmrazor/models/losses/fgd.py @@ -0,0 +1,225 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import constant_init, kaiming_init + +from ..builder import LOSSES + + +@LOSSES.register_module() +class FGDLoss(nn.Module): + """PyTorch version of 'Focal and Global Knowledge Distillation for + Detectors'. + + + + Args: + student_channels(int): Number of channels in the student's feature map. + teacher_channels(int): Number of channels in the teacher's feature map. + temp (float, optional): Temperature coefficient. Defaults to 0.5. + name (str): the loss name of the layer + alpha_fgd (float, optional): Weight of fg_loss. + beta_fgd (float, optional): Weight of bg_loss. + gamma_fgd (float, optional): Weight of mask_loss. + lambda_fgd (float, optional): Weight of relation_loss. + """ + + def __init__( + self, + student_channels, + teacher_channels, + temp=0.5, + alpha_fgd=0.001, + beta_fgd=0.0005, + gamma_fgd=0.001, + lambda_fgd=0.000005, + ): + super(FGDLoss, self).__init__() + self.temp = temp + self.alpha_fgd = alpha_fgd + self.beta_fgd = beta_fgd + self.gamma_fgd = gamma_fgd + self.lambda_fgd = lambda_fgd + + self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1) + self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1) + self.channel_add_conv_s = nn.Sequential( + nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1), + nn.LayerNorm([teacher_channels // 2, 1, 1]), nn.ReLU(inplace=True), + nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1)) + self.channel_add_conv_t = nn.Sequential( + nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1), + nn.LayerNorm([teacher_channels // 2, 1, 1]), nn.ReLU(inplace=True), + nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1)) + + self.reset_parameters() + + def forward(self, preds_S, preds_T): + """Forward function. + + Args: + preds_S(Tensor): Bs*C*H*W, student's feature map + preds_T(Tensor): Bs*C*H*W, teacher's feature map + gt_bboxes(tuple): Bs*[nt*4], (tl_x, tl_y, br_x, br_y) + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + """ + assert preds_S.shape[-2:] == preds_T.shape[-2:] + N, C, H, W = preds_S.shape + gt_bboxes = self.current_data['gt_boxxes'] + metas = self.current_data['img_metas'] + + S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp) + S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp) + + M_fg = torch.zeros_like(S_attention_t) + M_bg = torch.ones_like(S_attention_t) + wmin, wmax, hmin, hmax = [], [], [], [] + for i in range(N): + new_boxx = torch.ones_like(gt_bboxes[i]) + new_boxx[:, 0] = gt_bboxes[i][:, 0] / metas[i]['img_shape'][1] * W + new_boxx[:, 2] = gt_bboxes[i][:, 2] / metas[i]['img_shape'][1] * W + new_boxx[:, 1] = gt_bboxes[i][:, 1] / metas[i]['img_shape'][0] * H + new_boxx[:, 3] = gt_bboxes[i][:, 3] / metas[i]['img_shape'][0] * H + + wmin.append(torch.floor(new_boxx[:, 0]).int()) + wmax.append(torch.ceil(new_boxx[:, 2]).int()) + hmin.append(torch.floor(new_boxx[:, 1]).int()) + hmax.append(torch.ceil(new_boxx[:, 3]).int()) + + height = hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1) + width = wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1) + area = 1.0 / height / width + + for j in range(len(gt_bboxes[i])): + M_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \ + torch.maximum(M_fg[i][hmin[i][j]:hmax[i][j]+1, + wmin[i][j]:wmax[i][j]+1], area[0][j]) + + M_bg[i] = torch.where(M_fg[i] > 0, 0, 1) + if torch.sum(M_bg[i]): + M_bg[i] /= torch.sum(M_bg[i]) + + fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, M_fg, M_bg, + C_attention_s, C_attention_t, + S_attention_s, S_attention_t) + mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, + S_attention_s, S_attention_t) + rela_loss = self.get_rela_loss(preds_S, preds_T) + + loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss + + return loss + + def get_attention(self, preds, temp): + """ preds: Bs*C*H*W """ + N, C, H, W = preds.shape + + value = torch.abs(preds) + # Bs*W*H + fea_map = value.mean(axis=1, keepdim=True) + S_attention = (H * W * F.softmax( + (fea_map / temp).view(N, -1), dim=1)).view(N, H, W) + + # Bs*C + channel_map = value.mean( + axis=2, keepdim=False).mean( + axis=2, keepdim=False) + C_attention = C * F.softmax(channel_map / temp, dim=1) + + return S_attention, C_attention + + def get_fea_loss(self, preds_S, preds_T, M_fg, M_bg, C_s, C_t, S_s, S_t): + loss_mse = nn.MSELoss(reduction='sum') + + M_fg = M_fg.unsqueeze(dim=1) + M_bg = M_bg.unsqueeze(dim=1) + + C_t = C_t.unsqueeze(dim=-1) + C_t = C_t.unsqueeze(dim=-1) + + S_t = S_t.unsqueeze(dim=1) + + fea_t = torch.mul(preds_T, torch.sqrt(S_t)) + fea_t = torch.mul(fea_t, torch.sqrt(C_t)) + fg_fea_t = torch.mul(fea_t, torch.sqrt(M_fg)) + bg_fea_t = torch.mul(fea_t, torch.sqrt(M_bg)) + + fea_s = torch.mul(preds_S, torch.sqrt(S_t)) + fea_s = torch.mul(fea_s, torch.sqrt(C_t)) + fg_fea_s = torch.mul(fea_s, torch.sqrt(M_fg)) + bg_fea_s = torch.mul(fea_s, torch.sqrt(M_bg)) + + fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(M_fg) + bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(M_bg) + + return fg_loss, bg_loss + + def get_mask_loss(self, C_s, C_t, S_s, S_t): + + mask_loss = torch.sum(torch.abs( + (C_s - C_t))) / len(C_s) + torch.sum(torch.abs( + (S_s - S_t))) / len(S_s) + + return mask_loss + + def spatial_pool(self, x, in_type): + batch, channel, width, height = x.size() + input_x = x + # [N, C, H * W] + input_x = input_x.view(batch, channel, height * width) + # [N, 1, C, H * W] + input_x = input_x.unsqueeze(1) + # [N, 1, H, W] + if in_type == 0: + context_mask = self.conv_mask_s(x) + else: + context_mask = self.conv_mask_t(x) + # [N, 1, H * W] + context_mask = context_mask.view(batch, 1, height * width) + # [N, 1, H * W] + context_mask = F.softmax(context_mask, dim=2) + # [N, 1, H * W, 1] + context_mask = context_mask.unsqueeze(-1) + # [N, 1, C, 1] + context = torch.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = context.view(batch, channel, 1, 1) + + return context + + def get_rela_loss(self, preds_S, preds_T): + loss_mse = nn.MSELoss(reduction='sum') + + context_s = self.spatial_pool(preds_S, 0) + context_t = self.spatial_pool(preds_T, 1) + + out_s = preds_S + out_t = preds_T + + channel_add_s = self.channel_add_conv_s(context_s) + out_s = out_s + channel_add_s + + channel_add_t = self.channel_add_conv_t(context_t) + out_t = out_t + channel_add_t + + rela_loss = loss_mse(out_s, out_t) / len(out_s) + + return rela_loss + + def last_zero_init(self, m): + if isinstance(m, nn.Sequential): + constant_init(m[-1], val=0) + else: + constant_init(m, val=0) + + def reset_parameters(self): + kaiming_init(self.conv_mask_s, mode='fan_in') + kaiming_init(self.conv_mask_t, mode='fan_in') + self.conv_mask_s.inited = True + self.conv_mask_t.inited = True + + self.last_zero_init(self.channel_add_conv_s) + self.last_zero_init(self.channel_add_conv_t)