
Commit 79d7428

committed Aug 8, 2023
feat: support distillation strategy
1 parent 85450da

File tree: 5 files changed (+137, -2 lines)

 

config.py (+16)

@@ -256,6 +256,22 @@ def create_parser():
     group.add_argument('--drop_overflow_update', type=bool, default=False,
                        help='Whether to execute optimizer if there is an overflow (default=False)')
 
+    # distillation
+    group = parser.add_argument_group('Distillation parameters')
+    group.add_argument('--distillation_type', type=str, default=None,
+                       choices=['hard', 'soft'],
+                       help='The type of distillation (default=None)')
+    group.add_argument('--teacher_model', type=str, default=None,
+                       help='Name of teacher model (default=None)')
+    group.add_argument('--teacher_ckpt_path', type=str, default='',
+                       help='Initialize teacher model from this checkpoint. '
+                            'If use distillation, specify the checkpoint path (default="").')
+    group.add_argument('--teacher_ema', type=str2bool, nargs='?', const=True, default=False,
+                       help='Whether teacher model training with ema (default=False)')
+    group.add_argument('--distillation_alpha', type=float, default=0.5,
+                       help='The coefficient to balance the distillation loss and base loss. '
+                            '(default=0.5)')
+
     # modelarts
     group = parser.add_argument_group('modelarts')
     group.add_argument('--enable_modelarts', type=str2bool, nargs='?', const=True, default=False,
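
For reference, a run that exercises these new options might be launched like this; the student/teacher model names, the data path, and the pre-existing --model/--data_dir options are illustrative assumptions, not values taken from this commit:

    python train.py --model resnet18 --data_dir /path/to/imagenet \
        --distillation_type soft --teacher_model resnet50 \
        --teacher_ckpt_path /path/to/teacher.ckpt --distillation_alpha 0.5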

mindcv/utils/__init__.py (+1)

@@ -2,6 +2,7 @@
 from .amp import *
 from .callbacks import *
 from .checkpoint_manager import *
+from .distillation import *
 from .download import *
 from .logger import *
 from .path import *

mindcv/utils/distillation.py (+87, new file)

@@ -0,0 +1,87 @@
+""" distillation related functions """
+from types import MethodType
+
+import mindspore as ms
+from mindspore import nn
+from mindspore.ops import functional as F
+
+
+class DistillLossCell(nn.WithLossCell):
+    """
+    Wraps the network with hard distillation loss function.
+
+    Get the loss of student network and an extra knowledge distillation loss by
+    taking a teacher model prediction and using it as additional supervision.
+
+    Args:
+        backbone (Cell): The student network to train and calculate base loss.
+        loss_fn (Cell): The loss function used to compute loss of student network.
+        distillation_type (str): The type of distillation.
+        teacher_model (Cell): The teacher network to calculate distillation loss.
+        alpha (float): The coefficient to balance the distillation loss and base loss. Default: 0.5.
+        tau (float): Distillation temperature. The higher the temperature, the lower the
+            dispersion of the loss calculated by Kullback-Leibler divergence loss. Default: 1.0.
+    """
+
+    def __init__(self, backbone, loss_fn, distillation_type, teacher_model, alpha=0.5, tau=1.0):
+        super().__init__(backbone, loss_fn)
+        if distillation_type == "hard":
+            self.hard_type = True
+        elif distillation_type == "soft":
+            self.hard_type = False
+        else:
+            raise ValueError(f"Distillation type only support ['hard', 'soft'], but got {distillation_type}.")
+        self.teacher_model = teacher_model
+        self.alpha = alpha
+        self.tau = tau
+
+    def construct(self, data, label):
+        out = self._backbone(data)
+
+        out, out_kd = out
+        base_loss = self._loss_fn(out, label)
+
+        teacher_out = F.stop_gradient(self.teacher_model(data))
+
+        if self.hard_type:
+            distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1))
+        else:
+            T = self.tau
+            out_kd = F.cast(out_kd, ms.float32)
+            distillation_loss = (
+                F.kl_div(
+                    F.log_softmax(out_kd / T, axis=1),
+                    F.log_softmax(teacher_out / T, axis=1),
+                    reduction="sum",
+                )
+                * (T * T)
+                / F.size(out_kd)
+            )
+
+        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
+
+        return loss
+
+
+def bn_infer_only(self, x):
+    return self.bn_infer(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]
+
+
+def dropout_infer_only(self, x):
+    return x
+
+
+def set_validation(network):
+    """
+    Since MindSpore cannot automatically set some cells to validation mode
+    during training in the teacher network, we need to manually set these
+    cells to validation mode in this function.
+    """
+
+    for _, cell in network.cells_and_names():
+        if isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
+            cell.construct = MethodType(bn_infer_only, cell)
+        elif isinstance(cell, nn.Dropout):
+            cell.construct = MethodType(dropout_infer_only, cell)
+        else:
+            cell.set_train(False)
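
A minimal sketch of how DistillLossCell and set_validation might be exercised on their own. Note that construct() unpacks `out, out_kd = out`, so the student is expected to return a pair of outputs (classification logits plus distillation logits); the toy networks, shapes, and class counts below are illustrative assumptions, not part of this commit:

    import numpy as np
    import mindspore as ms
    from mindspore import nn

    from mindcv.utils.distillation import DistillLossCell, set_validation


    class ToyStudent(nn.Cell):
        """Toy student with a classification head and an extra distillation head."""

        def __init__(self, num_classes=10):
            super().__init__()
            self.backbone = nn.Dense(32, 64)
            self.head = nn.Dense(64, num_classes)
            self.head_kd = nn.Dense(64, num_classes)

        def construct(self, x):
            feat = self.backbone(x)
            return self.head(feat), self.head_kd(feat)  # (out, out_kd)


    student = ToyStudent()
    teacher = nn.SequentialCell([nn.Dense(32, 10)])  # stand-in for a pretrained teacher
    set_validation(teacher)  # no BN/Dropout here, so this just calls set_train(False)

    net_with_loss = DistillLossCell(student, nn.CrossEntropyLoss(), "soft", teacher, alpha=0.5, tau=1.0)

    data = ms.Tensor(np.random.randn(4, 32), ms.float32)
    label = ms.Tensor(np.array([0, 1, 2, 3]), ms.int32)
    loss = net_with_loss(data, label)  # scalar: (1 - alpha) * base CE + alpha * soft KD term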

mindcv/utils/trainer_factory.py (+15, -2)

@@ -9,6 +9,7 @@
 from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model
 
 from .amp import auto_mixed_precision
+from .distillation import DistillLossCell
 from .train_step import TrainStep
 
 __all__ = [
@@ -38,6 +39,7 @@ def require_customized_train_step(
     clip_grad: bool = False,
     gradient_accumulation_steps: int = 1,
     amp_cast_list: Optional[str] = None,
+    distillation_type: Optional[str] = None,
 ):
     if ema:
         return True
@@ -47,6 +49,8 @@ def require_customized_train_step(
         return True
     if amp_cast_list:
         return True
+    if distillation_type:
+        return True
     return False
 
 
@@ -88,6 +92,9 @@ def create_trainer(
     clip_grad: bool = False,
     clip_value: float = 15.0,
     gradient_accumulation_steps: int = 1,
+    distillation_type: Optional[str] = None,
+    teacher_network: Optional[nn.Cell] = None,
+    distillation_alpha: float = 0.5,
 ):
     """Create Trainer.
 
@@ -106,6 +113,9 @@ def create_trainer(
         clip_grad: whether to gradient clip.
         clip_value: The value at which to clip gradients.
         gradient_accumulation_steps: Accumulate the gradients of n batches before update.
+        distillation_type: The type of distillation.
+        teacher_network: The teacher network for distillation.
+        distillation_alpha: The coefficient to balance the distillation loss and base loss.
 
     Returns:
         mindspore.Model
@@ -120,7 +130,7 @@ def create_trainer(
     if gradient_accumulation_steps < 1:
         raise ValueError("`gradient_accumulation_steps` must be >= 1!")
 
-    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
+    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list, distillation_type):
         mindspore_kwargs = dict(
             network=network,
             loss_fn=loss,
@@ -149,7 +159,10 @@ def create_trainer(
     else:  # require customized train step
         eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
         auto_mixed_precision(network, amp_level, amp_cast_list)
-        net_with_loss = add_loss_network(network, loss, amp_level)
+        if distillation_type:
+            net_with_loss = DistillLossCell(network, loss, distillation_type, teacher_network, distillation_alpha)
+        else:
+            net_with_loss = add_loss_network(network, loss, amp_level)
         train_step_kwargs = dict(
             network=net_with_loss,
             optimizer=optimizer,
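
As a quick illustration of the new check, setting a distillation type alone is now enough to route training through the customized TrainStep. The keyword arguments below follow the signature shown above; the leading `ema` parameter is inferred from the function body and is an assumption:

    from mindcv.utils.trainer_factory import require_customized_train_step

    require_customized_train_step(
        ema=False,                      # assumed first parameter (checked in the body above)
        clip_grad=False,
        gradient_accumulation_steps=1,
        amp_cast_list=None,
        distillation_type="hard",
    )  # -> True, so create_trainer builds the DistillLossCell-wrapped train step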

train.py (+18)

@@ -19,6 +19,7 @@
     require_customized_train_step,
     set_logger,
     set_seed,
+    set_validation,
 )
 
 from config import parse_args, save_args  # isort: skip
@@ -180,6 +181,19 @@ def train(args):
         aux_factor=args.aux_factor,
     )
 
+    # create teacher model
+    teacher_network = None
+    if args.distillation_type:
+        if not args.teacher_ckpt_path:
+            logger.warning("You are using distillation, but your teacher model has not loaded weights.")
+        teacher_network = create_model(
+            model_name=args.teacher_model,
+            num_classes=num_classes,
+            checkpoint_path=args.teacher_ckpt_path,
+            ema=args.teacher_ema,
+        )
+        set_validation(teacher_network)
+
     # create learning rate schedule
     lr_scheduler = create_scheduler(
         num_batches,
@@ -213,6 +227,7 @@ def train(args):
             args.clip_grad,
             args.gradient_accumulation_steps,
             args.amp_cast_list,
+            args.distillation_type,
         )
     ):
         optimizer_loss_scale = args.loss_scale
@@ -250,6 +265,9 @@ def train(args):
         clip_grad=args.clip_grad,
         clip_value=args.clip_value,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        distillation_type=args.distillation_type,
+        teacher_network=teacher_network,
+        distillation_alpha=args.distillation_alpha,
     )
 
     # callback
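
In isolation, the teacher-construction block added above amounts to roughly the following standalone snippet; the teacher architecture name and checkpoint path are placeholders, and importing create_model from mindcv.models is an assumption about where the existing factory lives:

    from mindcv.models import create_model      # assumed location of the existing model factory
    from mindcv.utils import set_validation     # re-exported via mindcv/utils/__init__.py above

    teacher_network = create_model(
        model_name="resnet50",                   # placeholder teacher architecture
        num_classes=1000,
        checkpoint_path="/path/to/teacher.ckpt", # placeholder checkpoint with teacher weights
        ema=False,                               # set True if the checkpoint stores EMA weights
    )
    set_validation(teacher_network)              # lock BN/Dropout into inference behavior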
