Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature]Support Relational Knowledge Distillation #127

Merged
merged 11 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@
))

# teacher settings
HIT-cwh marked this conversation as resolved.
Show resolved Hide resolved

# FIXME: replace it with your own path
teacher_ckpt = 'path/to/your/checkpoint.pth'

teacher = dict(
type='mmcls.ImageClassifier',
HIT-cwh marked this conversation as resolved.
Show resolved Hide resolved
init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
backbone=dict(
type='ResNet',
depth=34,
Expand Down Expand Up @@ -64,7 +69,7 @@
name='loss_rkd',
loss_weight_d=25.0,
loss_weight_a=50.0,
l2_norm=True)
with_l2_norm=True)
])
]),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
))

# teacher settings

# FIXME: replace it with your own path
teacher_ckpt = 'path/to/your/checkpoint.pth'
HIT-cwh marked this conversation as resolved.
Show resolved Hide resolved

teacher = dict(
type='mmcls.ImageClassifier',
init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
backbone=dict(
type='ResNet_CIFAR',
depth=50,
Expand Down Expand Up @@ -58,7 +63,7 @@
name='loss_rkd',
loss_weight_d=25.0,
loss_weight_a=50.0,
l2_norm=True)
with_l2_norm=True)
])
]),
)
Expand Down
103 changes: 58 additions & 45 deletions mmrazor/models/losses/rkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,46 @@
from ..builder import LOSSES


def euclidean_distance(pred, squared=False, eps=1e-12):
    """Calculate the pairwise Euclidean distance between the examples in the
    output representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
        squared (bool): Whether to return the squared Euclidean distance.
            Defaults to False.
        eps (float): Lower bound applied to the squared distance before the
            optional square root is taken. The smallest off-diagonal value
            is therefore ``eps`` when ``squared=True`` and ``sqrt(eps)``
            otherwise. Defaults to 1e-12.

    Returns:
        torch.Tensor: Distance matrix with shape (N, N). The diagonal
        (distance of every example to itself) is set to exactly zero.
    """
    pred_square = pred.pow(2).sum(dim=-1)  # (N, )
    prod = torch.mm(pred, pred.t())  # (N, N)
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * <x_i, x_j>
    distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
                2 * prod).clamp(min=eps)  # (N, N)

    if not squared:
        distance = distance.sqrt()

    # Zero the diagonal on a copy so the tensor computed above is left
    # unmodified.
    distance = distance.clone()
    diag = range(len(prod))
    distance[diag, diag] = 0
    return distance


def angle(pred):
    """Calculate the angle-wise relational potential which measures the angle
    formed by the three examples in the output representation space.

    For every triplet ``(i, j, k)`` the value is the inner product of the
    L2-normalized difference vectors ``pred[j] - pred[i]`` and
    ``pred[k] - pred[i]``, i.e. the cosine of the angle at example ``i``.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).

    Returns:
        torch.Tensor: Flattened cosine values with shape (N*N*N, ).
    """
    # pred_vec[i, j] = pred[j] - pred[i]
    pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1)  # (N, N, C)
    norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
    # Batched Gram matrix over the first dimension: (N, N, C) x (N, C, N).
    cos_angles = torch.bmm(norm_pred_vec,
                           norm_pred_vec.transpose(1, 2)).view(-1)  # (N*N*N, )
    return cos_angles


@LOSSES.register_module()
class RelationalKD(nn.Module):
"""PyTorch version of `Relational Knowledge Distillation.
Expand All @@ -16,67 +56,36 @@ class RelationalKD(nn.Module):
Defaults to 25.0.
loss_weight_a (float): Weight of angle-wise distillation loss.
Defaults to 50.0.
l2_norm (bool): Whether to normalize the model predictions before
with_l2_norm (bool): Whether to normalize the model predictions before
calculating the loss. Defaults to True.
"""

def __init__(self, loss_weight_d=25.0, loss_weight_a=50.0, l2_norm=True):
def __init__(self,
loss_weight_d=25.0,
loss_weight_a=50.0,
with_l2_norm=True):
super(RelationalKD, self).__init__()
self.loss_weight_d = loss_weight_d
self.loss_weight_a = loss_weight_a
self.l2_norm = l2_norm

def euclidean_distance(self, pred, squared=False, eps=1e-12):
"""Calculate the Euclidean distance between the two examples in the
output representation space.

Args:
pred (torch.Tensor): The prediction of the teacher or student with
shape (N, C).
squared (bool): Whether to calculate the squared Euclidean
distance. Defaults to False.
eps (float): The minimum Euclidean distance between the two
examples. Defaults to 1e-12.
"""
pred_square = pred.pow(2).sum(dim=-1)
prod = torch.mm(pred, pred.t())
distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
2 * prod).clamp(min=eps)

if not squared:
distance = distance.sqrt()

distance = distance.clone()
distance[range(len(prod)), range(len(prod))] = 0
return distance
self.with_l2_norm = with_l2_norm

def distance_loss(self, preds_S, preds_T):
"""Calculate distance-wise distillation loss."""
d_T = self.euclidean_distance(preds_T, squared=False)
d_T = euclidean_distance(preds_T, squared=False)
# mean_d_T is a normalization factor for distance
mean_d_T = d_T[d_T > 0].mean()
d_T = d_T / mean_d_T

d_S = self.euclidean_distance(preds_S, squared=False)
d_S = euclidean_distance(preds_S, squared=False)
mean_d_S = d_S[d_S > 0].mean()
d_S = d_S / mean_d_S

return F.smooth_l1_loss(d_S, d_T)

def angle(self, pred):
"""Calculate the angle-wise relational potential which measures the
angle formed by the three examples in the output representation
space."""
pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1) # (n, n, c)
norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
angle = torch.bmm(norm_pred_vec, norm_pred_vec.transpose(1,
2)).view(-1)
return angle

def angle_loss(self, preds_S, preds_T):
"""Calculate the angle-wise distillation loss."""
angle_T = self.angle(preds_T)
angle_S = self.angle(preds_S)
angle_T = angle(preds_T)
angle_S = angle(preds_S)
return F.smooth_l1_loss(angle_S, angle_T)

def forward(self, preds_S, preds_T):
Expand All @@ -92,10 +101,14 @@ def forward(self, preds_S, preds_T):
"""
preds_S = preds_S.view(preds_S.shape[0], -1)
preds_T = preds_T.view(preds_T.shape[0], -1)
HIT-cwh marked this conversation as resolved.
Show resolved Hide resolved
if self.l2_norm:
if self.with_l2_norm:
preds_S = F.normalize(preds_S, p=2, dim=-1)
preds_T = F.normalize(preds_T, p=2, dim=-1)
HIT-cwh marked this conversation as resolved.
Show resolved Hide resolved
loss_d = self.distance_loss(preds_S, preds_T)
loss_a = self.angle_loss(preds_S, preds_T)
loss = self.loss_weight_d * loss_d + self.loss_weight_a * loss_a

loss = 0.
if self.loss_weight_d > 0:
loss += self.distance_loss(preds_S, preds_T) * self.loss_weight_d
if self.loss_weight_a > 0:
loss += self.angle_loss(preds_S, preds_T) * self.loss_weight_a

return loss
4 changes: 2 additions & 2 deletions tests/test_models/test_algorithms/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def test_rkd():
name='loss_rkd',
loss_weight_d=25.0,
loss_weight_a=50.0,
l2_norm=True)
with_l2_norm=True)
])
]),
)
Expand Down Expand Up @@ -527,7 +527,7 @@ def test_rkd():
name='loss_rkd',
loss_weight_d=25.0,
loss_weight_a=50.0,
l2_norm=False)
with_l2_norm=False)
])
]

Expand Down