Commit ca07946

feat: support distillation strategy
1 parent 85450da commit ca07946

File tree

5 files changed: +131 -2 lines changed

config.py  (+14)
@@ -256,6 +256,20 @@ def create_parser():
     group.add_argument('--drop_overflow_update', type=bool, default=False,
                        help='Whether to execute optimizer if there is an overflow (default=False)')
 
+    # distillation
+    group = parser.add_argument_group('Distillation parameters')
+    group.add_argument('--distillation_type', type=str, default=None,
+                       choices=['hard', 'soft'],
+                       help='The type of distillation (default=None)')
+    group.add_argument('--teacher_model', type=str, default=None,
+                       help='Name of teacher model (default=None)')
+    group.add_argument('--teacher_ckpt_path', type=str, default='',
+                       help='Initialize teacher model from this checkpoint. '
+                            'If resume training, specify the checkpoint path (default="").')
+    group.add_argument('--distillation_alpha', type=float, default=0.5,
+                       help='The coefficient balancing the distillation loss and base loss '
+                            '(default=0.5)')
+
     # modelarts
     group = parser.add_argument_group('modelarts')
     group.add_argument('--enable_modelarts', type=str2bool, nargs='?', const=True, default=False,
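
For reference, here is a minimal, hedged sketch of how the new options surface through the parser; the student flag --model, the model names, and the checkpoint path are illustrative placeholders (not files shipped with this commit), and the remaining options are assumed to keep their defaults:

# Hypothetical usage of the new 'Distillation parameters' group (illustration only).
from config import create_parser

args = create_parser().parse_args([
    "--model", "resnet18",                     # existing flag: the student network
    "--distillation_type", "soft",             # 'hard' or 'soft'
    "--teacher_model", "resnet50",             # teacher name resolvable by create_model
    "--teacher_ckpt_path", "./resnet50.ckpt",  # placeholder path to pretrained teacher weights
    "--distillation_alpha", "0.5",             # weight of the distillation loss term
])
print(args.distillation_type, args.teacher_model, args.distillation_alpha)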

mindcv/utils/__init__.py  (+1)
@@ -2,6 +2,7 @@
 from .amp import *
 from .callbacks import *
 from .checkpoint_manager import *
+from .distill_loss_cell import *
 from .download import *
 from .logger import *
 from .path import *

mindcv/utils/distill_loss_cell.py  (+83, new file)
@@ -0,0 +1,83 @@
+""" Distillation loss cell definitions. """
+from mindspore import nn
+from mindspore.ops import functional as F
+
+
+class HardDistillLossCell(nn.WithLossCell):
+    """
+    Wraps the network with the hard distillation loss function.
+
+    Computes the loss of the student network plus an extra hard knowledge-distillation
+    loss, taking the teacher model's prediction as additional supervision.
+
+    Args:
+        backbone (Cell): The student network to train and to compute the base loss on.
+        loss_fn (Cell): The loss function used to compute the loss of the student network.
+        teacher_model (Cell): The teacher network used to compute the distillation loss.
+        alpha (float): Distillation factor, the coefficient balancing the distillation
+            loss and the base loss. Default: 0.5.
+    """
+
+    def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5):
+        super().__init__(backbone, loss_fn)
+        self.teacher_model = teacher_model
+        self.alpha = alpha
+
+    def construct(self, data, label):
+        out = self._backbone(data)
+
+        out, out_kd = out
+        base_loss = self._loss_fn(out, label)
+
+        teacher_out = self.teacher_model(data)
+
+        distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1))
+        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
+
+        return loss
+
+
+class SoftDistillLossCell(nn.WithLossCell):
+    """
+    Wraps the network with the soft distillation loss function.
+
+    Computes the loss of the student network plus an extra soft knowledge-distillation
+    loss, taking the teacher model's prediction as additional supervision.
+
+    Args:
+        backbone (Cell): The student network to train and to compute the base loss on.
+        loss_fn (Cell): The loss function used to compute the loss of the student network.
+        teacher_model (Cell): The teacher network used to compute the distillation loss.
+        alpha (float): Distillation factor, the coefficient balancing the distillation
+            loss and the base loss. Default: 0.5.
+        tau (float): Distillation temperature. The higher the temperature, the lower the
+            dispersion of the Kullback-Leibler divergence loss. Default: 1.0.
+    """
+
+    def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5, tau=1.0):
+        super().__init__(backbone, loss_fn)
+        self.teacher_model = teacher_model
+        self.alpha = alpha
+        self.tau = tau
+
+    def construct(self, data, label):
+        out = self._backbone(data)
+
+        out, out_kd = out
+        base_loss = self._loss_fn(out, label)
+
+        teacher_out = self.teacher_model(data)
+
+        T = self.tau
+        distillation_loss = (
+            F.kl_div(
+                F.log_softmax(out_kd / T, axis=1),
+                F.log_softmax(teacher_out / T, axis=1),
+                reduction="sum",
+            )
+            * (T * T)
+            / F.size(out_kd)
+        )
+        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
+
+        return loss
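
To show how these cells are meant to be wired, here is a hedged, self-contained sketch: the student must return a (logits, distillation_logits) tuple, as in DeiT-style dual-head models. ToyStudent, the stand-in teacher, and all shapes below are made up for illustration and are not part of this commit.

# Toy usage sketch for HardDistillLossCell / SoftDistillLossCell.
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindcv.utils import HardDistillLossCell, SoftDistillLossCell


class ToyStudent(nn.Cell):
    """Stand-in student with a classification head and a distillation head."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Dense(32, 64)
        self.head = nn.Dense(64, num_classes)
        self.head_kd = nn.Dense(64, num_classes)

    def construct(self, x):
        feat = self.backbone(x)
        return self.head(feat), self.head_kd(feat)  # (logits, distillation logits)


student = ToyStudent()
teacher = nn.Dense(32, 10)  # stand-in for a pretrained, frozen teacher
teacher.set_train(False)

loss_fn = nn.CrossEntropyLoss()
data = Tensor(np.random.randn(4, 32), ms.float32)
label = Tensor(np.random.randint(0, 10, (4,)), ms.int32)

hard_cell = HardDistillLossCell(student, loss_fn, teacher, alpha=0.5)
soft_cell = SoftDistillLossCell(student, loss_fn, teacher, alpha=0.5, tau=2.0)
print(hard_cell(data, label), soft_cell(data, label))

In a real run these cells are constructed inside create_trainer rather than by hand; the sketch only demonstrates the expected input/output contract.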

mindcv/utils/trainer_factory.py  (+17 -2)
@@ -9,6 +9,7 @@
 from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model
 
 from .amp import auto_mixed_precision
+from .distill_loss_cell import HardDistillLossCell, SoftDistillLossCell
 from .train_step import TrainStep
 
 __all__ = [
@@ -38,6 +39,7 @@ def require_customized_train_step(
     clip_grad: bool = False,
     gradient_accumulation_steps: int = 1,
     amp_cast_list: Optional[str] = None,
+    distillation_type: Optional[str] = None,
 ):
     if ema:
         return True
@@ -47,6 +49,8 @@ def require_customized_train_step(
         return True
     if amp_cast_list:
         return True
+    if distillation_type:
+        return True
     return False
 
 
@@ -88,6 +92,9 @@ def create_trainer(
     clip_grad: bool = False,
     clip_value: float = 15.0,
     gradient_accumulation_steps: int = 1,
+    distillation_type: Optional[str] = None,
+    teacher_network: Optional[nn.Cell] = None,
+    distillation_alpha: float = 0.5,
 ):
     """Create Trainer.
@@ -120,7 +127,7 @@ def create_trainer(
     if gradient_accumulation_steps < 1:
         raise ValueError("`gradient_accumulation_steps` must be >= 1!")
 
-    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
+    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list, distillation_type):
         mindspore_kwargs = dict(
             network=network,
             loss_fn=loss,
@@ -149,7 +156,15 @@ def create_trainer(
     else:  # require customized train step
         eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
         auto_mixed_precision(network, amp_level, amp_cast_list)
-        net_with_loss = add_loss_network(network, loss, amp_level)
+        if distillation_type:
+            if distillation_type == "hard":
+                net_with_loss = HardDistillLossCell(network, loss, teacher_network, distillation_alpha)
+            elif distillation_type == "soft":
+                net_with_loss = SoftDistillLossCell(network, loss, teacher_network, distillation_alpha)
+            else:
+                raise ValueError(f"Distillation type only support ['hard', 'soft'], but got {distillation_type}.")
+        else:
+            net_with_loss = add_loss_network(network, loss, amp_level)
         train_step_kwargs = dict(
             network=net_with_loss,
             optimizer=optimizer,
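
Design note: distillation is only handled on the customized train-step branch, which is why any non-empty distillation_type must force require_customized_train_step to return True. Below is a small, hedged sanity check of that behavior (keyword arguments are used so the sketch does not depend on the other parameters' positions or defaults):

# Hypothetical check: setting a distillation type selects the customized TrainStep path.
from mindcv.utils.trainer_factory import require_customized_train_step

assert require_customized_train_step(ema=False) is False
assert require_customized_train_step(ema=False, distillation_type="hard") is True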

train.py  (+16)
@@ -180,6 +180,18 @@ def train(args):
         aux_factor=args.aux_factor,
     )
 
+    # create teacher model
+    teacher_network = None
+    if args.distillation_type:
+        if not args.teacher_ckpt_path:
+            logger.warning("You are using distillation, but your teacher model has not loaded weights.")
+        teacher_network = create_model(
+            model_name=args.teacher_model,
+            num_classes=num_classes,
+            checkpoint_path=args.teacher_ckpt_path,
+        )
+        teacher_network.set_train(False)
+
     # create learning rate schedule
     lr_scheduler = create_scheduler(
         num_batches,
@@ -213,6 +225,7 @@ def train(args):
             args.clip_grad,
             args.gradient_accumulation_steps,
             args.amp_cast_list,
+            args.distillation_type,
         )
     ):
         optimizer_loss_scale = args.loss_scale
@@ -250,6 +263,9 @@ def train(args):
         clip_grad=args.clip_grad,
         clip_value=args.clip_value,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        distillation_type=args.distillation_type,
+        teacher_network=teacher_network,
+        distillation_alpha=args.distillation_alpha,
     )
 
     # callback
