[Feature] Add FBKD algorithm and torch_connectors #248

Merged · 4 commits · Aug 29, 2022
6 changes: 3 additions & 3 deletions configs/distill/mmcls/dfad/README.md
@@ -14,9 +14,9 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a

### Classification

-| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
-| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
-| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
+| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
+| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
+| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dfad_logits_r34_r18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |

## Citation

6 changes: 3 additions & 3 deletions configs/distill/mmcls/zskt/README.md
@@ -20,9 +20,9 @@ Performing knowledge transfer from a large teacher network to a smaller student

### Classification

-| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
-| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
-| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
+| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
+| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
+| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./zskt_backbone_logits_r34_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |

## Citation

37 changes: 37 additions & 0 deletions configs/distill/mmdet/fbkd/README.md
@@ -0,0 +1,37 @@
# IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION: TOWARDS ACCURATE AND EFFICIENT DETECTORS (FBKD)

> [IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION: TOWARDS ACCURATE AND EFFICIENT DETECTORS](https://openreview.net/pdf?id=uKhGRvM8QNH)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation, in which a student model is trained to mimic a teacher model, has been proved as an effective technique for model compression and model accuracy boosting. However, most knowledge distillation methods, designed for image classification, have failed on more challenging tasks, such as object detection. In this paper, we suggest that the failure of knowledge distillation on object detection is mainly caused by two reasons: (1) the imbalance between pixels of foreground and background and (2) lack of distillation on the relation between different pixels. Observing the above reasons, we propose attention-guided distillation and non-local distillation to address the two problems, respectively. Attention-guided distillation is proposed to find the crucial pixels of foreground objects with attention mechanism and then make the students take more effort to learn their features. Non-local distillation is proposed to enable students to learn not only the feature of an individual pixel but also the relation between different pixels captured by non-local modules. Experiments show that our methods achieve excellent AP improvements on both one-stage and two-stage, both anchor-based and anchor-free detectors. For example, Faster RCNN (ResNet101 backbone) with our distillation achieves 43.9 AP on COCO2017, which is 4.1 higher than the baseline.

![pipeline](/docs/en/imgs/model_zoo/fbkd/pipeline.png)
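The attention-guided component described in the abstract can be sketched roughly as follows. This is a minimal, hypothetical PyTorch illustration of attention-weighted feature mimicking only — the function names and the temperature value are our own choices, and it is not the PR's actual `FBKDLoss` implementation (which also includes the non-local relation term):

```python
import torch
import torch.nn.functional as F


def spatial_attention(feat: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Per-pixel attention: softmax over spatial positions of the
    channel-averaged absolute activation, rescaled so weights average to 1."""
    n, _, h, w = feat.shape
    attn = feat.abs().mean(dim=1).view(n, -1)        # (N, H*W)
    attn = F.softmax(attn / temperature, dim=1)
    return attn.view(n, 1, h, w) * h * w


def attention_guided_loss(s_feat: torch.Tensor, t_feat: torch.Tensor) -> torch.Tensor:
    """Weight the student-teacher feature MSE by the teacher's attention map,
    so foreground (high-activation) pixels dominate the distillation signal."""
    with torch.no_grad():
        mask = spatial_attention(t_feat)
    return (mask * (s_feat - t_feat).pow(2)).mean()


s = torch.randn(2, 256, 32, 32)  # student neck feature
t = torch.randn(2, 256, 32, 32)  # teacher neck feature
loss = attention_guided_loss(s, t)  # scalar tensor
```

The key design point mirrors the paper's motivation: instead of mimicking all pixels equally (which drowns foreground in background), the teacher's attention map decides where the student should "take more effort".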

## Results and models

### Detection

| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download |
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :--------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| neck | COCO | [fasterrcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [fasterrcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.1 | 39.4 | 37.8 | [config](./fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{DBLP:conf/iclr/ZhangM21,
author = {Linfeng Zhang and Kaisheng Ma},
title = {Improve Object Detection with Feature-based Knowledge Distillation:
Towards Accurate and Efficient Detectors},
booktitle = {9th International Conference on Learning Representations, {ICLR} 2021,
Virtual Event, Austria, May 3-7, 2021},
publisher = {OpenReview.net},
year = {2021},
url = {https://openreview.net/forum?id=uKhGRvM8QNH},
timestamp = {Wed, 23 Jun 2021 17:36:39 +0200},
biburl = {https://dblp.org/rec/conf/iclr/ZhangM21.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
125 changes: 125 additions & 0 deletions configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py
@@ -0,0 +1,125 @@
_base_ = [
'mmdet::_base_/datasets/coco_detection.py',
'mmdet::_base_/schedules/schedule_1x.py',
'mmdet::_base_/default_runtime.py'
]

model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
architecture=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
pretrained=True),
teacher=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py',
pretrained=False),
teacher_ckpt='faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
neck_s0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
neck_s1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
neck_s2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
neck_s3=dict(type='ModuleOutputs',
source='neck.fpn_convs.3.conv')),
teacher_recorders=dict(
neck_s0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
neck_s1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
neck_s2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
neck_s3=dict(type='ModuleOutputs',
source='neck.fpn_convs.3.conv')),
distill_losses=dict(
loss_s0=dict(type='FBKDLoss'),
loss_s1=dict(type='FBKDLoss'),
loss_s2=dict(type='FBKDLoss'),
loss_s3=dict(type='FBKDLoss')),
connectors=dict(
loss_s0_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=8),
loss_s0_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=8),
loss_s1_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=4),
loss_s1_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=4),
loss_s2_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
mode='dot_product',
sub_sample=True),
loss_s2_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
mode='dot_product',
sub_sample=True),
loss_s3_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
mode='dot_product',
sub_sample=True),
loss_s3_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
mode='dot_product',
sub_sample=True)),
loss_forward_mappings=dict(
loss_s0=dict(
s_input=dict(
from_student=True,
recorder='neck_s0',
connector='loss_s0_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s0',
connector='loss_s0_tfeat')),
loss_s1=dict(
s_input=dict(
from_student=True,
recorder='neck_s1',
connector='loss_s1_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s1',
connector='loss_s1_tfeat')),
loss_s2=dict(
s_input=dict(
from_student=True,
recorder='neck_s2',
connector='loss_s2_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s2',
connector='loss_s2_tfeat')),
loss_s3=dict(
s_input=dict(
from_student=True,
recorder='neck_s3',
connector='loss_s3_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s3',
connector='loss_s3_tfeat')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
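The `loss_forward_mappings` block in the config wires each loss argument to a recorded tensor plus a connector: the distiller looks up the recorded feature, passes it through the named connector, and feeds the result to the loss. A stripped-down, hypothetical sketch of that lookup — stand-in classes, not MMRazor's actual `ConfigurableDistiller`:

```python
import torch
import torch.nn as nn


class IdentityConnector(nn.Module):
    """Stand-in for an FBKD connector: transforms a recorded feature."""

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        return feat * 1.0


def compute_distill_loss(recorders, connectors, loss_fn, mapping):
    """Mimic the mapping resolution: for each loss argument, fetch the
    recorded tensor, run it through its connector, then call the loss."""
    kwargs = {}
    for arg_name, spec in mapping.items():
        feat = recorders[spec['recorder']]
        kwargs[arg_name] = connectors[spec['connector']](feat)
    return loss_fn(**kwargs)


# One FPN level, student and teacher recorded separately in the real setup;
# here the same tensor stands in for both sides.
recorders = {'neck_s0': torch.randn(2, 256, 32, 32)}
connectors = {'loss_s0_sfeat': IdentityConnector(),
              'loss_s0_tfeat': IdentityConnector()}
mse = lambda s_input, t_input: nn.functional.mse_loss(s_input, t_input)
mapping = {'s_input': {'recorder': 'neck_s0', 'connector': 'loss_s0_sfeat'},
           't_input': {'recorder': 'neck_s0', 'connector': 'loss_s0_tfeat'}}
loss = compute_distill_loss(recorders, connectors, mse, mapping)
```

This indirection is why the config can distill four FPN levels with four `FBKDLoss` instances purely declaratively: each level only needs its own recorder/connector names in the mapping.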
Binary file added docs/en/imgs/model_zoo/fbkd/pipeline.png
8 changes: 7 additions & 1 deletion mmrazor/models/architectures/connectors/__init__.py
@@ -2,5 +2,11 @@
from .byot_connector import BYOTConnector
from .convmodule_connector import ConvModuleConncetor
from .factor_transfer_connectors import Paraphraser, Translator
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
from .torch_connector import TorchFunctionalConnector, TorchNNConnector

__all__ = ['ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector']
__all__ = [
'ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector',
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
'TorchNNConnector'
]
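Judging by its name, `TorchFunctionalConnector` wraps a `torch.nn.functional` op so it can sit between a recorder and a loss. A hypothetical sketch of that idea — the class name, constructor arguments, and behavior here are our own, not necessarily the PR's API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalConnector(nn.Module):
    """Hypothetical sketch: apply a named torch.nn.functional op
    (e.g. pooling to match teacher/student feature sizes) as a connector."""

    def __init__(self, function_name: str, func_args: dict = None):
        super().__init__()
        self.func = getattr(F, function_name)  # e.g. F.avg_pool2d
        self.func_args = func_args or {}

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        return self.func(feature, **self.func_args)


pool = FunctionalConnector('avg_pool2d', {'kernel_size': 2})
out = pool(torch.randn(1, 8, 16, 16))  # halves the spatial size
```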
6 changes: 4 additions & 2 deletions mmrazor/models/architectures/connectors/base_connector.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Tuple, Union

import torch
from mmengine.model import BaseModule
@@ -32,7 +32,9 @@ def forward(self, feature: torch.Tensor) -> torch.Tensor:
return self.forward_train(feature)

@abstractmethod
def forward_train(self, feature) -> torch.Tensor:
def forward_train(
self, feature: torch.Tensor
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
"""Abstract train computation.

Args:
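The widened annotation `Union[Tuple[torch.Tensor, ...], torch.Tensor]` lets `forward_train` return several tensors, as the BYOT and FBKD connectors do. A minimal hypothetical subclass illustrating the tuple-returning case (plain `nn.Module` stands in for `BaseConnector`):

```python
import torch
import torch.nn as nn


class SplitConnector(nn.Module):
    """Hypothetical connector returning both the raw feature and a pooled
    view of it, matching the Tuple[torch.Tensor, ...] return type."""

    def forward_train(self, feature: torch.Tensor):
        pooled = nn.functional.adaptive_avg_pool2d(feature, 1)
        return feature, pooled


conn = SplitConnector()
raw, pooled = conn.forward_train(torch.randn(2, 64, 8, 8))
```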
2 changes: 1 addition & 1 deletion mmrazor/models/architectures/connectors/byot_connector.py
@@ -67,7 +67,7 @@ def __init__(
self.scala = nn.Sequential(*scala)
self.fc = nn.Linear(out_channel * expansion, num_classes)

def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
def forward_train(self, feature: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Forward computation.

Args: