# [Feature] Support Relational Knowledge Distillation (#127)

* add rkd
* add rkd pytest
* add rkd configs
* fix readme
* fix rkd
* split rkd loss to distance-wise and angle-wise losses
* rename rkd losses
* add rkd metafile
* add rkd related links
* rename rkd metafile and add to model index
* delete cifar100

Co-authored-by: caoweihan <caoweihan@sensetime.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Parent: f9920a4 · Commit: de4dd13

Showing 8 changed files with 417 additions and 1 deletion.
**configs/distill/rkd/README.md** (new file, 42 additions)
# RKD

> [Relational Knowledge Distillation](https://arxiv.org/abs/1904.05068)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation aims at transferring knowledge acquired in one model (a teacher) to another model (a student) that is typically smaller. Previous approaches can be expressed as a form of training the student to mimic the output activations of individual data examples represented by the teacher. We introduce a novel approach, dubbed relational knowledge distillation (RKD), that transfers mutual relations of data examples instead. For concrete realizations of RKD, we propose distance-wise and angle-wise distillation losses that penalize structural differences in relations. Experiments conducted on different tasks show that the proposed method improves educated student models by a significant margin. In particular, for metric learning it allows students to outperform their teachers, achieving state-of-the-art results on standard benchmark datasets.

![pipeline](/docs/en/imgs/model_zoo/rkd/pipeline.png)
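For quick reference, the two losses can be written out as follows (paraphrasing the paper's formulation, with notation adapted): given teacher embeddings $t_i$ and student embeddings $s_i$, and with $l_\delta$ the Huber (smooth L1) loss used in the implementation, RKD penalizes differences between relational potentials computed separately on the two embedding sets.

```latex
% Distance-wise: pairwise distances, each normalized by the batch mean
% distance \mu of the respective model
\psi_D(t_i, t_j) = \frac{1}{\mu} \lVert t_i - t_j \rVert_2, \qquad
\mathcal{L}_{\mathrm{RKD\text{-}D}} = \sum_{(i,j)} l_\delta\big(\psi_D(s_i, s_j),\; \psi_D(t_i, t_j)\big)

% Angle-wise: cosine of the angle formed by each triplet of examples,
% where e^{ij} = (t_i - t_j) / \lVert t_i - t_j \rVert_2
\psi_A(t_i, t_j, t_k) = \cos \angle\, t_i t_j t_k = \langle e^{ij}, e^{kj} \rangle, \qquad
\mathcal{L}_{\mathrm{RKD\text{-}A}} = \sum_{(i,j,k)} l_\delta\big(\psi_A(s_i, s_j, s_k),\; \psi_A(t_i, t_j, t_k)\big)
```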
## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------- |
| neck | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.23 | 73.62 | 69.90 | [config](./rkd_neck_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_20220312_130419.log.json) |
## Citation

```latex
@inproceedings{park2019relational,
  title={Relational knowledge distillation},
  author={Park, Wonpyo and Kim, Dongju and Lu, Yan and Cho, Minsu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={3967--3976},
  year={2019}
}
```
**configs/distill/rkd/metafile.yml** (new file, 31 additions)
```yaml
Collections:
  - Name: RKD
    Metadata:
      Training Data:
        - ImageNet-1k
    Paper:
      URL: https://arxiv.org/abs/1904.05068
      Title: Relational Knowledge Distillation
    README: configs/distill/rkd/README.md
    Code:
      URL: https://github.com/open-mmlab/mmrazor/blob/v0.3.0/mmrazor/models/losses/relational_kd.py
      Version: v0.3.0
    Converted From:
      Code: https://github.com/lenscloth/RKD
Models:
  - Name: rkd_neck_resnet34_resnet18_8xb32_in1k
    In Collection: RKD
    Metadata:
      Location: neck
      Student: R-18
      Teacher: R-34
      Teacher Checkpoint: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 70.23
          Top 1 Accuracy (S): 69.90
          Top 1 Accuracy (T): 73.62
    Config: configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
    Weights: https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth
```
**configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py** (new file, 79 additions)
```python
_base_ = [
    '../../_base_/datasets/mmcls/imagenet_bs32.py',
    '../../_base_/schedules/mmcls/imagenet_bs256.py',
    '../../_base_/mmcls_runtime.py'
]

# model settings
student = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# teacher settings
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa: E501

teacher = dict(
    type='mmcls.ImageClassifier',
    init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
    backbone=dict(
        type='ResNet',
        depth=34,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# algorithm setting
algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMClsArchitecture',
        model=student,
    ),
    with_student_loss=True,
    with_teacher_loss=False,
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        teacher_norm_eval=True,
        components=[
            dict(
                student_module='neck.gap',
                teacher_module='neck.gap',
                losses=[
                    dict(
                        type='DistanceWiseRKD',
                        name='distance_wise_loss',
                        loss_weight=25.0,
                        with_l2_norm=True),
                    dict(
                        type='AngleWiseRKD',
                        name='angle_wise_loss',
                        loss_weight=50.0,
                        with_l2_norm=True),
                ])
        ]),
)

find_unused_parameters = True
```
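To make the distiller wiring concrete, here is a minimal sketch of what the `components` entry above computes at the shared `neck.gap` hook during one training step. The feature tensors are random stand-ins, and the import path assumes the two losses are exported from `mmrazor.models.losses` as in the `__init__.py` change below.

```python
import torch

from mmrazor.models.losses import AngleWiseRKD, DistanceWiseRKD

# Stand-ins for the pooled features captured at `neck.gap` for one batch of
# 32 images; both ResNet-18 and ResNet-34 produce 512-d vectors after GAP.
feats_student = torch.randn(32, 512)
feats_teacher = torch.randn(32, 512)

# Same hyperparameters as the two loss dicts in the config above.
distance_loss = DistanceWiseRKD(loss_weight=25.0, with_l2_norm=True)
angle_loss = AngleWiseRKD(loss_weight=50.0, with_l2_norm=True)

# The distiller sums the per-component losses, so the total RKD term is:
rkd_loss = (distance_loss(feats_student, feats_teacher) +
            angle_loss(feats_student, feats_teacher))
print(rkd_loss)  # scalar tensor, already scaled by the loss weights
```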
**docs/en/imgs/model_zoo/rkd/pipeline.png** (new binary file; preview not available)
**mmrazor/models/losses/__init__.py** (modified)
```diff
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cwd import ChannelWiseDivergence
 from .kl_divergence import KLDivergence
+from .relational_kd import AngleWiseRKD, DistanceWiseRKD
 from .weighted_soft_label_distillation import WSLD

-__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
+__all__ = [
+    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
+    'WSLD'
+]
```
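Exporting the two classes here is what lets the `type='DistanceWiseRKD'` and `type='AngleWiseRKD'` strings in the config resolve through the `LOSSES` registry. A small sketch of that round-trip (assuming mmcv's standard `Registry.build` behaviour; note the `name` key in the config's loss dicts is consumed by the distiller, not by the loss constructor):

```python
from mmrazor.models.builder import LOSSES

# `@LOSSES.register_module()` in relational_kd.py maps the class name to the
# class; a config dict is then instantiated by name:
loss_cfg = dict(type='DistanceWiseRKD', loss_weight=25.0, with_l2_norm=True)
distance_loss = LOSSES.build(loss_cfg)  # equivalent to DistanceWiseRKD(...)
```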
**mmrazor/models/losses/relational_kd.py** (new file, 149 additions)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


def euclidean_distance(pred, squared=False, eps=1e-12):
    """Calculate the Euclidean distance between two examples in the output
    representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
        squared (bool): Whether to calculate the squared Euclidean
            distance. Defaults to False.
        eps (float): The minimum Euclidean distance between two
            examples. Defaults to 1e-12.
    """
    pred_square = pred.pow(2).sum(dim=-1)  # (N, )
    prod = torch.mm(pred, pred.t())  # (N, N)
    distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
                2 * prod).clamp(min=eps)  # (N, N)

    if not squared:
        distance = distance.sqrt()

    # zero out self-distances on the diagonal
    distance = distance.clone()
    distance[range(len(prod)), range(len(prod))] = 0
    return distance


def angle(pred):
    """Calculate the angle-wise relational potential, which measures the
    angle formed by three examples in the output representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
    """
    pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1)  # (N, N, C)
    norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
    angle = torch.bmm(norm_pred_vec,
                      norm_pred_vec.transpose(1, 2)).view(-1)  # (N*N*N, )
    return angle


@LOSSES.register_module()
class DistanceWiseRKD(nn.Module):
    """PyTorch version of the distance-wise loss of `Relational Knowledge
    Distillation <https://arxiv.org/abs/1904.05068>`_.

    Args:
        loss_weight (float): Weight of distance-wise distillation loss.
            Defaults to 25.0.
        with_l2_norm (bool): Whether to normalize the model predictions
            before calculating the loss. Defaults to True.
    """

    def __init__(self, loss_weight=25.0, with_l2_norm=True):
        super(DistanceWiseRKD, self).__init__()

        self.loss_weight = loss_weight
        self.with_l2_norm = with_l2_norm

    def distance_loss(self, preds_S, preds_T):
        """Calculate distance-wise distillation loss."""
        d_T = euclidean_distance(preds_T, squared=False)
        # mean_d_T is a normalization factor for distance
        mean_d_T = d_T[d_T > 0].mean()
        d_T = d_T / mean_d_T

        d_S = euclidean_distance(preds_S, squared=False)
        mean_d_S = d_S[d_S > 0].mean()
        d_S = d_S / mean_d_S

        return F.smooth_l1_loss(d_S, d_T)

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        preds_S = preds_S.view(preds_S.shape[0], -1)
        preds_T = preds_T.view(preds_T.shape[0], -1)
        if self.with_l2_norm:
            preds_S = F.normalize(preds_S, p=2, dim=1)
            preds_T = F.normalize(preds_T, p=2, dim=1)

        loss = self.distance_loss(preds_S, preds_T) * self.loss_weight

        return loss


@LOSSES.register_module()
class AngleWiseRKD(nn.Module):
    """PyTorch version of the angle-wise loss of `Relational Knowledge
    Distillation <https://arxiv.org/abs/1904.05068>`_.

    Args:
        loss_weight (float): Weight of angle-wise distillation loss.
            Defaults to 50.0.
        with_l2_norm (bool): Whether to normalize the model predictions
            before calculating the loss. Defaults to True.
    """

    def __init__(self, loss_weight=50.0, with_l2_norm=True):
        super(AngleWiseRKD, self).__init__()

        self.loss_weight = loss_weight
        self.with_l2_norm = with_l2_norm

    def angle_loss(self, preds_S, preds_T):
        """Calculate the angle-wise distillation loss."""
        angle_T = angle(preds_T)
        angle_S = angle(preds_S)
        return F.smooth_l1_loss(angle_S, angle_T)

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        preds_S = preds_S.view(preds_S.shape[0], -1)
        preds_T = preds_T.view(preds_T.shape[0], -1)
        if self.with_l2_norm:
            preds_S = F.normalize(preds_S, p=2, dim=-1)
            preds_T = F.normalize(preds_T, p=2, dim=-1)

        loss = self.angle_loss(preds_S, preds_T) * self.loss_weight

        return loss
```
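As a quick sanity check of the two relational potentials, a minimal sketch (assuming `euclidean_distance` and `angle` from the file above are in scope):

```python
import torch

pred = torch.randn(4, 8)  # a batch of 4 embeddings with 8 dimensions each

d = euclidean_distance(pred)
assert d.shape == (4, 4)              # pairwise distance matrix
assert torch.all(d.diagonal() == 0)   # self-distances are zeroed explicitly

a = angle(pred)
assert a.shape == (4 * 4 * 4,)        # one cosine per ordered triplet
```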