[BUG] Fix bugs in pruner (#126)
* fix bugs in pruner when pruning models with shared modules

* pruner can trace models with dilation conv2d

* fix deploy_subnet

* fix add_pruning_attrs

* fix bugs in modify_forward

* fix lint

* fix StructurePruner

* test tracing models with shared modules

Co-authored-by: caoweihan <caoweihan@sensetime.com>
HIT-cwh and caoweihan authored Apr 1, 2022
1 parent 31e21c0 commit 0ab5efe
Showing 2 changed files with 164 additions and 12 deletions.
95 changes: 83 additions & 12 deletions mmrazor/models/pruners/structure_pruning.py
@@ -6,7 +6,6 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from ordered_set import OrderedSet
from torch.nn.modules.batchnorm import _BatchNorm
@@ -89,6 +88,20 @@ def __init__(self, except_start_keys=['head.fc']):
else:
self.except_start_keys = except_start_keys

def trace_shared_module_hook(self, module, inputs, outputs):
"""Trace shared modules. Modules such as the detection head in
RetinaNet which are visited more than once during :func:`forward` are
shared modules.
Args:
module (:obj:`torch.nn.Module`): The module to register hook.
inputs (tuple): The input of the module.
outputs (tuple): The output of the module.
"""
module.cnt += 1
if module.cnt == 2:
self.shared_module.append(self.module2name[module])
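
The hook above only counts how many times a module is entered during a single forward pass; a count of two is enough to flag it as shared. Below is a minimal, self-contained sketch of the same idea on a toy two-branch model; the TwoBranch class, the make_hook helper, and all names are illustrative assumptions, not mmrazor code.

# Editor's sketch (not part of this commit): detect modules reused during
# one forward pass by counting forward-hook calls.
import torch
import torch.nn as nn

class TwoBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Conv2d(3, 8, 3, padding=1)  # used by both branches
        self.head_a = nn.Conv2d(8, 8, 1)
        self.head_b = nn.Conv2d(8, 8, 1)

    def forward(self, x):
        return self.head_a(self.shared(x)), self.head_b(self.shared(x))

model = TwoBranch()
counts, shared = {}, []

def make_hook(name):
    def hook(module, inputs, outputs):
        counts[name] = counts.get(name, 0) + 1
        if counts[name] == 2:  # visited twice -> shared module
            shared.append(name)
    return hook

handles = [m.register_forward_hook(make_hook(n))
           for n, m in model.named_modules() if hasattr(m, 'weight')]
model(torch.randn(1, 3, 32, 32))
for handle in handles:  # the hooks are only needed while tracing
    handle.remove()
print(shared)  # ['shared']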

def prepare_from_supernet(self, supernet):
"""Prepare for pruning."""

@@ -98,9 +111,25 @@ def prepare_from_supernet(self, supernet):

# record the visited module name during trace path
visited = dict()
# Record shared modules that will be visited more than once during
# forward, such as the shared detection head in RetinaNet.
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has already been
# visited.
self.shared_module = []
tmp_shared_module_hook_handles = list()

for name, module in supernet.model.named_modules():
if hasattr(module, 'weight'):
# trace shared modules
module.cnt = 0
# the handle is only used to remove the corresponding hook later
handle = module.register_forward_hook(
self.trace_shared_module_hook)
tmp_shared_module_hook_handles.append(handle)

module2name[module] = name
name2module[name] = module
var2module[id(module.weight)] = module
@@ -109,12 +138,35 @@
if isinstance(module, SwitchableBatchNorm2d):
name2module[name] = module
self.name2module = name2module
self.module2name = module2name

# Set requires_grad to True. If the `requires_grad` of a module's
# weight is False, we cannot trace this module by parsing its backward
# graph.
param_require_grad = dict()
for param in supernet.model.parameters():
param_require_grad[id(param)] = param.requires_grad
param.requires_grad = True

pseudo_img = torch.randn(1, 3, 224, 224)
# todo: support two stage detector and mmseg
pseudo_img = supernet.forward_dummy(pseudo_img)
pseudo_loss = supernet.cal_pseudo_loss(pseudo_img)

# `trace_shared_module_hook` and `cnt` are only used to trace the
# shared modules in a model and need to be removed afterwards
for name, module in supernet.model.named_modules():
if hasattr(module, 'weight'):
del module.cnt

for handle in tmp_shared_module_hook_handles:
handle.remove()

# We set requires_grad to True to trace the whole architecture
# topology, so it should be restored afterwards.
for param in supernet.model.parameters():
param.requires_grad = param_require_grad[id(param)]
del param_require_grad
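
Saving and restoring requires_grad matters because autograd only records operations on tensors that require gradients; a frozen backbone stage would otherwise be invisible to the backward-graph trace that follows. A standalone sketch of the pattern (the tiny nn.Sequential model and all names are illustrative assumptions, not mmrazor code):

# Editor's sketch (not part of this commit): temporarily enable
# requires_grad so every parametrized module shows up in the autograd
# graph of a pseudo loss, then restore the original flags.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 4, 1))
for param in model[0].parameters():  # e.g. a frozen backbone stage
    param.requires_grad = False

saved = {id(p): p.requires_grad for p in model.parameters()}
for param in model.parameters():
    param.requires_grad = True

pseudo_loss = model(torch.randn(1, 3, 32, 32)).sum()
grad_fn = pseudo_loss.grad_fn  # root of the graph the pruner walks
print(type(grad_fn).__name__)  # e.g. SumBackward0

for param in model.parameters():
    param.requires_grad = saved[id(param)]  # restore the original flags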

non_pass_paths = list()
cur_non_pass_path = list()
self.trace_non_pass_path(pseudo_loss.grad_fn, module2name, var2module,
@@ -366,38 +418,39 @@ def make_same_out_channel_groups(self, node2parents, name2module):
@staticmethod
def modify_conv_forward(module):
"""Modify the forward method of a conv layer."""
original_forward = module.forward

def modified_forward(self, feature):
feature = feature * self.in_mask
return F.conv2d(feature, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
return original_forward(feature)

return MethodType(modified_forward, module)

@staticmethod
def modify_fc_forward(module):
"""Modify the forward method of a linear layer."""
original_forward = module.forward

def modified_forward(self, feature):
if not len(self.in_mask.shape) == len(self.out_mask.shape):
self.in_mask = self.in_mask.reshape(self.in_mask.shape[:2])

feature = feature * self.in_mask
return F.linear(feature, self.weight, self.bias)
return original_forward(feature)

return MethodType(modified_forward, module)
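
Both wrappers above now delegate to the module's original forward instead of re-issuing F.conv2d or F.linear themselves, which is what lets dilation, groups, and every other constructor argument carry through unchanged. A compact sketch of that buffer-plus-MethodType pattern, mirroring what add_pruning_attrs below sets up (the add_in_mask helper and the toy conv are illustrative assumptions, not the exact mmrazor code):

# Editor's sketch (not part of this commit): register an input mask as a
# buffer and rebind forward so the mask is applied before delegating to
# the module's original forward.
from types import MethodType

import torch
import torch.nn as nn

def add_in_mask(conv):
    conv.register_buffer(
        'in_mask', conv.weight.new_ones((1, conv.in_channels, 1, 1)))
    original_forward = conv.forward  # keeps stride/padding/dilation/groups

    def masked_forward(self, feature):
        return original_forward(feature * self.in_mask)

    conv.forward = MethodType(masked_forward, conv)

conv = nn.Conv2d(4, 8, 3, padding=2, dilation=2)  # dilation is preserved
add_in_mask(conv)
conv.in_mask[:, 2:] = 0  # pretend two input channels are pruned
print(conv(torch.randn(1, 4, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])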

def add_pruning_attrs(self, module):
"""Add masks to a ``nn.Module``."""
if type(module).__name__ == 'Conv2d':
if isinstance(module, nn.Conv2d):
module.register_buffer(
'in_mask',
module.weight.new_ones((1, module.in_channels, 1, 1), ))
module.register_buffer(
'out_mask',
module.weight.new_ones((1, module.out_channels, 1, 1), ))
module.forward = self.modify_conv_forward(module)
if type(module).__name__ == 'Linear':
if isinstance(module, nn.Linear):
module.register_buffer(
'in_mask', module.weight.new_ones((1, module.in_features), ))
module.register_buffer(
@@ -480,14 +533,16 @@ def deploy_subnet(self, supernet, channel_cfg):
module.out_channels = out_channels
if hasattr(module, 'out_features'):
module.out_features = out_channels
if hasattr(module, 'num_features'):
module.num_features = out_channels
if hasattr(module, 'out_mask'):
module.out_mask = module.out_mask[:, :out_channels]

if 'in_channels' in channels_per_layer:
in_channels = channels_per_layer['in_channels']

if in_channels > 1:
temp_weight = temp_weight[:, :in_channels].data
# can also handle depthwise conv
temp_weight = temp_weight[:, :in_channels].data
if hasattr(module, 'in_channels'):
module.in_channels = in_channels
if hasattr(module, 'in_features'):
@@ -632,6 +687,7 @@ def find_backward_parser(self, grad_fn):
@register_parser(BACKWARD_PARSER_DICT, 'ThnnConv2DBackward')
@register_parser(BACKWARD_PARSER_DICT, 'CudnnConvolutionBackward')
@register_parser(BACKWARD_PARSER_DICT, 'MkldnnConvolutionBackward')
@register_parser(BACKWARD_PARSER_DICT, 'SlowConvDilated2DBackward')
def conv_backward_parser(self, grad_fn, module2name, var2module, cur_path,
result_paths, visited):
"""Parse the backward of a conv layer.
@@ -656,7 +712,12 @@ def conv_backward_parser(self, grad_fn, module2name, var2module, cur_path,
name = module2name[module]
parent = grad_fn.next_functions[0][0]
cur_path.append(name)
if visited[name]:
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has already been
# visited.
if visited[name] and name not in self.shared_module:
result_paths.append(copy.deepcopy(cur_path))
else:
visited[name] = True
@@ -691,9 +752,13 @@ def linear_backward_parser(self, grad_fn, module2name, var2module,
module = var2module[var_id]
name = module2name[module]
parent = grad_fn.next_functions[1][0]

cur_path.append(name)
if visited[name]:
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has already been
# visited.
if visited[name] and name not in self.shared_module:
result_paths.append(copy.deepcopy(cur_path))
else:
visited[name] = True
@@ -722,7 +787,13 @@ def concat_backward_parser(self, grad_fn, module2name, var2module,
concat_id = '_'.join([str(id(p)) for p in parents])
name = f'concat_{concat_id}'
cur_path.append(name)
if name in visited and visited[name]:
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has already been
# visited.
if (name in visited and visited[name]
and name not in self.shared_module):
result_paths.append(copy.deepcopy(cur_path))
else:
visited[name] = True
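
The check visited[name] and name not in self.shared_module, repeated in the three parsers above, is the heart of the fix: a shared module sits on several backward paths, so meeting it a second time must not end the trace, otherwise everything behind the second occurrence is never reached. The following standalone sketch reduces that traversal to conv nodes only; the SharedHead toy model, the walk helper, and the use of next_functions[1][0].variable to recover the weight are illustrative assumptions rather than the pruner's exact parser code.

# Editor's sketch (not part of this commit): walk the autograd graph from a
# loss, stop expanding a conv node once visited, unless its module is shared.
import torch
import torch.nn as nn

class SharedHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem_a = nn.Conv2d(3, 8, 3, padding=1)
        self.stem_b = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Conv2d(8, 4, 1)  # shared between the two branches

    def forward(self, xa, xb):
        return self.head(self.stem_a(xa)) + self.head(self.stem_b(xb))

model = SharedHead()
name_of_weight = {id(m.weight): n for n, m in model.named_modules()
                  if isinstance(m, nn.Conv2d)}
shared = {'head'}  # would be filled by the counting hooks during forward
visited, paths = {}, []

def walk(grad_fn, cur_path):
    if grad_fn is None:  # reached an input leaf
        paths.append(cur_path)
        return
    if 'Conv' in type(grad_fn).__name__:
        # next_functions[1] holds the AccumulateGrad node of the conv weight
        name = name_of_weight[id(grad_fn.next_functions[1][0].variable)]
        if visited.get(name) and name not in shared:
            paths.append(cur_path + [name])  # parents were traced already
            return
        visited[name] = True
        walk(grad_fn.next_functions[0][0], cur_path + [name])  # follow input
        return
    for parent, _ in grad_fn.next_functions:
        walk(parent, cur_path)

loss = model(torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)).sum()
walk(loss.grad_fn, [])
print(paths)  # [['head', 'stem_a'], ['head', 'stem_b']]

Without the shared-module exemption, the second encounter with head would terminate the walk and stem_b would never appear in any path, which is the failure mode this commit fixes for models such as RetinaNet.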
81 changes: 81 additions & 0 deletions tests/test_models/test_pruner.py
@@ -3,6 +3,7 @@

import pytest
import torch
from mmcv import ConfigDict

from mmrazor.models.builder import ARCHITECTURES, PRUNERS

@@ -80,3 +81,83 @@ def test_ratio_pruner():
pruner.deploy_subnet(architecture, subnet_dict)
losses = architecture(imgs, return_loss=True, gt_label=label)
assert losses['loss'].item() > 0

# test the group-making logic when there are shared modules in the model
model_cfg = ConfigDict(
type='mmdet.RetinaNet',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))

architecture_cfg = dict(
type='MMDetArchitecture',
model=model_cfg,
)

pruner_cfg = dict(
type='RatioPruner',
ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])

architecture = ARCHITECTURES.build(architecture_cfg)
pruner = PRUNERS.build(pruner_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
assert isinstance(subnet_dict, dict)
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
assert isinstance(subnet_dict, dict)
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)
