[Feature] Support Relational Knowledge Distillation (#127)
* add rkd

* add rkd pytest

* add rkd configs

* fix readme

* fix rkd

* split rkd loss to distance-wise and angle-wise losses

* rename rkd losses

* add rkd metafile

* add rkd related links

* rename rkd metafile and add to model index

* delete cifar100

Co-authored-by: caoweihan <caoweihan@sensetime.com>
Co-authored-by: pppppM <gjf_mail@126.com>
3 people authored Apr 2, 2022
1 parent f9920a4 commit de4dd13
Showing 8 changed files with 417 additions and 1 deletion.
42 changes: 42 additions & 0 deletions configs/distill/rkd/README.md
@@ -0,0 +1,42 @@
# RKD

> [Relational Knowledge Distillation](https://arxiv.org/abs/1904.05068)
<!-- [ALGORITHM] -->
## Abstract
Knowledge distillation aims at transferring knowledge acquired in one model (a teacher) to another model (a student) that is typically smaller. Previous approaches can be expressed as a form of training the student to mimic output activations of individual data examples represented by the teacher. We introduce a novel approach, dubbed relational knowledge distillation (RKD), that transfers mutual relations of data examples instead. For concrete realizations of RKD, we propose distance-wise and angle-wise distillation losses that penalize structural differences in relations. Experiments conducted on different tasks show that the proposed method improves educated student models by a significant margin. In particular for metric learning, it allows students to outperform their teachers, achieving state-of-the-art performance on standard benchmark datasets.

![pipeline](/docs/en/imgs/model_zoo/rkd/pipeline.png)
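For reference, the two losses realized in this commit are, in the paper's notation (teacher outputs $t_i$, student outputs $s_i$, $\mu$ the mean pairwise distance within a batch, and $l_\delta$ the Huber loss, i.e. smooth L1):

```latex
% Distance-wise loss: match mean-normalized pairwise distances.
\psi_D(t_i, t_j) = \frac{1}{\mu} \lVert t_i - t_j \rVert_2 , \qquad
\mathcal{L}_{\text{RKD-D}} = \sum_{(i,j)} l_\delta\bigl(\psi_D(t_i, t_j),\, \psi_D(s_i, s_j)\bigr)

% Angle-wise loss: match cosines of angles formed by triplets of examples.
\psi_A(t_i, t_j, t_k) = \cos \angle\, t_i t_j t_k = \bigl\langle \mathbf{e}^{ij}, \mathbf{e}^{kj} \bigr\rangle , \quad
\mathbf{e}^{ij} = \frac{t_i - t_j}{\lVert t_i - t_j \rVert_2} , \qquad
\mathcal{L}_{\text{RKD-A}} = \sum_{(i,j,k)} l_\delta\bigl(\psi_A(t_i, t_j, t_k),\, \psi_A(s_i, s_j, s_k)\bigr)
```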

## Results and models
### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :-----: | :-----: | :-----: | :--: | :----: | :----: | :----: | :------- |
| neck | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.23 | 73.62 | 69.90 | [config](./rkd_neck_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth) &#124; [model](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth) &#124; [log](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_20220312_130419.log.json) |



## Citation
```latex
@inproceedings{park2019relational,
title={Relational knowledge distillation},
author={Park, Wonpyo and Kim, Dongju and Lu, Yan and Cho, Minsu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={3967--3976},
year={2019}
}
```
31 changes: 31 additions & 0 deletions configs/distill/rkd/metafile.yml
@@ -0,0 +1,31 @@
Collections:
  - Name: RKD
    Metadata:
      Training Data:
        - ImageNet-1k
    Paper:
      URL: https://arxiv.org/abs/1904.05068
      Title: Relational Knowledge Distillation
    README: configs/distill/rkd/README.md
    Code:
      URL: https://github.com/open-mmlab/mmrazor/blob/v0.3.0/mmrazor/models/losses/relational_kd.py
      Version: v0.3.0
    Converted From:
      Code: https://github.com/lenscloth/RKD
Models:
  - Name: rkd_neck_resnet34_resnet18_8xb32_in1k
    In Collection: RKD
    Metadata:
      Location: neck
      Student: R-18
      Teacher: R-34
      Teacher Checkpoint: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 70.23
          Top 1 Accuracy (S): 69.90
          Top 1 Accuracy (T): 73.62
    Config: configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
    Weights: https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth
79 changes: 79 additions & 0 deletions configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
@@ -0,0 +1,79 @@
_base_ = [
    '../../_base_/datasets/mmcls/imagenet_bs32.py',
    '../../_base_/schedules/mmcls/imagenet_bs256.py',
    '../../_base_/mmcls_runtime.py'
]

# model settings
student = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# teacher settings
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa: E501

teacher = dict(
    type='mmcls.ImageClassifier',
    init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
    backbone=dict(
        type='ResNet',
        depth=34,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# algorithm setting
algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMClsArchitecture',
        model=student,
    ),
    with_student_loss=True,
    with_teacher_loss=False,
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        teacher_norm_eval=True,
        components=[
            dict(
                student_module='neck.gap',
                teacher_module='neck.gap',
                losses=[
                    dict(
                        type='DistanceWiseRKD',
                        name='distance_wise_loss',
                        loss_weight=25.0,
                        with_l2_norm=True),
                    dict(
                        type='AngleWiseRKD',
                        name='angle_wise_loss',
                        loss_weight=50.0,
                        with_l2_norm=True),
                ])
        ]),
)

find_unused_parameters = True
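For orientation, the `components` entry above registers both RKD losses on the pooled `neck.gap` outputs of the student and the teacher. A minimal sketch of what the distiller computes each iteration (plain PyTorch rather than the `GeneralDistill` machinery; the feature shapes are made up for illustration, and it assumes `mmrazor` is importable):

```python
import torch

from mmrazor.models.losses import AngleWiseRKD, DistanceWiseRKD

distance_wise = DistanceWiseRKD(loss_weight=25.0, with_l2_norm=True)
angle_wise = AngleWiseRKD(loss_weight=50.0, with_l2_norm=True)

# Stand-ins for the features captured at `neck.gap` of each model.
feat_s = torch.randn(32, 512, requires_grad=True)  # student
with torch.no_grad():                              # teacher_trainable=False
    feat_t = torch.randn(32, 512)

# GeneralDistill adds this to the student's own classification loss
# (with_student_loss=True); each weight is folded into its loss module.
distill_loss = distance_wise(feat_s, feat_t) + angle_wise(feat_s, feat_t)
distill_loss.backward()  # gradients flow only through the student features
```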
Binary file added docs/en/imgs/model_zoo/rkd/pipeline.png
6 changes: 5 additions & 1 deletion mmrazor/models/losses/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cwd import ChannelWiseDivergence
 from .kl_divergence import KLDivergence
+from .relational_kd import AngleWiseRKD, DistanceWiseRKD
 from .weighted_soft_label_distillation import WSLD

-__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
+__all__ = [
+    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
+    'WSLD'
+]
149 changes: 149 additions & 0 deletions mmrazor/models/losses/relational_kd.py
@@ -0,0 +1,149 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


def euclidean_distance(pred, squared=False, eps=1e-12):
    """Calculate the pairwise Euclidean distances between examples in the
    output representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
        squared (bool): Whether to return the squared Euclidean distances.
            Defaults to False.
        eps (float): Lower bound on the squared distances, for numerical
            stability. Defaults to 1e-12.

    Returns:
        torch.Tensor: The (N, N) distance matrix with a zeroed diagonal.
    """
    pred_square = pred.pow(2).sum(dim=-1)  # (N, )
    prod = torch.mm(pred, pred.t())  # (N, N)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>
    distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
                2 * prod).clamp(min=eps)  # (N, N)

    if not squared:
        distance = distance.sqrt()

    distance = distance.clone()
    distance[range(len(prod)), range(len(prod))] = 0
    return distance
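# Editor's note (illustrative, not part of the committed file): the
# expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b> used above agrees
# with torch.cdist away from the eps-clamped, zeroed diagonal, e.g.:
#     feats = torch.randn(8, 16)
#     torch.allclose(euclidean_distance(feats),
#                    torch.cdist(feats, feats), atol=1e-5)  # -> True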


def angle(pred):
    """Calculate the angle-wise relational potentials, i.e. the cosines of
    the angles formed by every triplet of examples in the output
    representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).

    Returns:
        torch.Tensor: The flattened (N * N * N, ) tensor of cosines.
    """
    pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1)  # (N, N, C)
    norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
    angle = torch.bmm(norm_pred_vec,
                      norm_pred_vec.transpose(1, 2)).view(-1)  # (N*N*N, )
    return angle
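# Editor's note (illustrative, not part of the committed file): for an
# (N, C) input, `angle` returns N * N * N cosines; before the final
# .view(-1), entry (j, i, k) is <e_ij, e_kj>, the cosine of the angle at
# vertex x_j formed with x_i and x_k, e.g.:
#     feats = torch.randn(4, 16)
#     angle(feats).shape  # -> torch.Size([64])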


@LOSSES.register_module()
class DistanceWiseRKD(nn.Module):
    """PyTorch version of the distance-wise loss of `Relational Knowledge
    Distillation <https://arxiv.org/abs/1904.05068>`_.

    Args:
        loss_weight (float): Weight of the distance-wise distillation loss.
            Defaults to 25.0.
        with_l2_norm (bool): Whether to L2-normalize the model predictions
            before calculating the loss. Defaults to True.
    """

    def __init__(self, loss_weight=25.0, with_l2_norm=True):
        super(DistanceWiseRKD, self).__init__()

        self.loss_weight = loss_weight
        self.with_l2_norm = with_l2_norm

    def distance_loss(self, preds_S, preds_T):
        """Calculate the distance-wise distillation loss."""
        d_T = euclidean_distance(preds_T, squared=False)
        # The mean distance is a normalization factor that makes the
        # potentials invariant to the scale of the representations.
        mean_d_T = d_T[d_T > 0].mean()
        d_T = d_T / mean_d_T

        d_S = euclidean_distance(preds_S, squared=False)
        mean_d_S = d_S[d_S > 0].mean()
        d_S = d_S / mean_d_S

        return F.smooth_l1_loss(d_S, d_T)

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        preds_S = preds_S.view(preds_S.shape[0], -1)
        preds_T = preds_T.view(preds_T.shape[0], -1)
        if self.with_l2_norm:
            preds_S = F.normalize(preds_S, p=2, dim=1)
            preds_T = F.normalize(preds_T, p=2, dim=1)

        loss = self.distance_loss(preds_S, preds_T) * self.loss_weight

        return loss


@LOSSES.register_module()
class AngleWiseRKD(nn.Module):
    """PyTorch version of the angle-wise loss of `Relational Knowledge
    Distillation <https://arxiv.org/abs/1904.05068>`_.

    Args:
        loss_weight (float): Weight of the angle-wise distillation loss.
            Defaults to 50.0.
        with_l2_norm (bool): Whether to L2-normalize the model predictions
            before calculating the loss. Defaults to True.
    """

    def __init__(self, loss_weight=50.0, with_l2_norm=True):
        super(AngleWiseRKD, self).__init__()

        self.loss_weight = loss_weight
        self.with_l2_norm = with_l2_norm

    def angle_loss(self, preds_S, preds_T):
        """Calculate the angle-wise distillation loss."""
        angle_T = angle(preds_T)
        angle_S = angle(preds_S)
        return F.smooth_l1_loss(angle_S, angle_T)

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        preds_S = preds_S.view(preds_S.shape[0], -1)
        preds_T = preds_T.view(preds_T.shape[0], -1)
        if self.with_l2_norm:
            preds_S = F.normalize(preds_S, p=2, dim=-1)
            preds_T = F.normalize(preds_T, p=2, dim=-1)

        loss = self.angle_loss(preds_S, preds_T) * self.loss_weight

        return loss
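Both loss modules accept either pooled features with shape (N, C) or feature maps with shape (N, C, H, W); 4D inputs are flattened per sample before the relational potentials are computed. A quick usage sketch (an illustration with made-up shapes, assuming `mmrazor` is importable):

```python
import torch

from mmrazor.models.losses import AngleWiseRKD, DistanceWiseRKD

preds_S = torch.randn(8, 256, 7, 7)  # student feature map
preds_T = torch.randn(8, 256, 7, 7)  # teacher feature map

# Default weights match the config above (25.0 and 50.0) and are applied
# inside forward(), so each call returns an already-weighted scalar.
print(DistanceWiseRKD()(preds_S, preds_T))
print(AngleWiseRKD()(preds_S, preds_T))
```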
1 change: 1 addition & 0 deletions model-index.yml
@@ -1,6 +1,7 @@
 Import:
   - configs/distill/cwd/metafile.yml
   - configs/distill/wsld/metafile.yml
+  - configs/distill/rkd/metafile.yml
   - configs/nas/darts/metafile.yml
   - configs/nas/detnas/metafile.yml
   - configs/nas/spos/metafile.yml