From 26f7064af41b7e4f4cad208d9fc1a2540a47da89 Mon Sep 17 00:00:00 2001
From: nizhen
Date: Sat, 26 Feb 2022 12:29:00 +0900
Subject: [PATCH] Fix speedup shape misalignment for element-wise multiplication (issue 4540)

---
 nni/compression/pytorch/speedup/infer_mask.py | 10 +++--
 .../pytorch/utils/shape_dependency.py         |  8 +++-
 test/ut/compression/v1/test_model_speedup.py  | 40 +++++++++++++++++++
 3 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/nni/compression/pytorch/speedup/infer_mask.py b/nni/compression/pytorch/speedup/infer_mask.py
index 8ace639207..312f60ecd4 100644
--- a/nni/compression/pytorch/speedup/infer_mask.py
+++ b/nni/compression/pytorch/speedup/infer_mask.py
@@ -171,10 +171,14 @@ def __apply_input_mask(self):
         # apply the input mask
         for tid, in_tensor in enumerate(self.dummy_input):
             if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
+                # in_tensor.data = in_tensor.data * \
+                #     self.in_masks[tid] + \
+                #     (1-self.in_masks[tid]) * self.in_constants[tid]
+                # issue-4540: when two tensors are multiplied, the constant term makes
+                # the mask propagation weaker and leads to a shape misalignment. We do
+                # not support constant folding yet, so we just drop the constant here.
                 in_tensor.data = in_tensor.data * \
-                    self.in_masks[tid] + \
-                    (1-self.in_masks[tid]) * self.in_constants[tid]
-
+                    self.in_masks[tid]
 
     def __apply_weight_mask(self):
         """
diff --git a/nni/compression/pytorch/utils/shape_dependency.py b/nni/compression/pytorch/utils/shape_dependency.py
index f972212a5a..436e84139b 100644
--- a/nni/compression/pytorch/utils/shape_dependency.py
+++ b/nni/compression/pytorch/utils/shape_dependency.py
@@ -163,7 +163,13 @@ def build_dependency(self):
             parent_layers = []
             # find the node that contains aten::add
             # or aten::cat operations
-            if node.op_type in ADD_TYPES:
+            if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
+                # refer to issue 4540 for more details. Multiplication itself does
+                # not introduce a channel dependency, because the misaligned channels
+                # can propagate to each other. However, when one of the input tensors
+                # comes from a skip connection (residual), the channel propagation
+                # may fail (that input is also used by another layer and cannot be
+                # pruned); in this case, we need to fix the conflict manually.
                 parent_layers = self._get_parent_layers(node)
             elif node.op_type == CAT_TYPE:
                 # To determine if this cat operation will introduce channel
diff --git a/test/ut/compression/v1/test_model_speedup.py b/test/ut/compression/v1/test_model_speedup.py
index 9d0ff7cf86..a70f010f6d 100644
--- a/test/ut/compression/v1/test_model_speedup.py
+++ b/test/ut/compression/v1/test_model_speedup.py
@@ -512,6 +512,46 @@ def forward(self, x):
         print("Fine-grained speeduped model")
         print(model)
 
+    def test_multiplication_speedup(self):
+        """
+        Model from issue 4540: an SE-style scale branch multiplied back onto the trunk.
+        """
+        class Net(torch.nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+                self.input = torch.nn.Conv2d(3, 8, 3)
+                self.bn = torch.nn.BatchNorm2d(8)
+                self.fc1 = torch.nn.Conv2d(8, 16, 1)
+                self.fc2 = torch.nn.Conv2d(16, 8, 1)
+                self.activation = torch.nn.ReLU()
+                self.scale_activation = torch.nn.Hardsigmoid()
+                self.out = torch.nn.Conv2d(8, 12, 1)
+
+            def forward(self, input):
+                input = self.activation(self.bn(self.input(input)))
+                scale = self.avgpool(input)
+                out1 = self.activation(self.fc1(scale))
+                out1 = self.scale_activation(self.fc2(out1))
+                return self.out(out1 * input)
+
+        model = Net().to(device)
+        model.eval()
+        im = torch.ones(1, 3, 512, 512).to(device)
+        model(im)
+        cfg_list = []
+
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.Conv2d):
+                cfg_list.append({'op_types': ['Conv2d'], 'sparsity': 0.3, 'op_names': [name]})
+
+        pruner = L1FilterPruner(model, cfg_list)
+        pruner.compress()
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        pruner._unwrap_model()
+        ms = ModelSpeedup(model, im, MASK_FILE)
+        ms.speedup_model()
+
     def tearDown(self):
         if os.path.exists(MODEL_FILE):
             os.remove(MODEL_FILE)
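
Note (illustration only, not part of the patch): the sketch below is a minimal, plain-PyTorch picture of the failure mode behind the shape_dependency change; the layer names and channel counts are invented. If the two producers feeding an element-wise multiplication end up with different channel sets after speedup (for example because one of them also feeds a skip connection and cannot be pruned), the multiplication itself hits a shape misalignment; the patch avoids this by putting both producers into one channel-dependency set.

import torch

# Two branches that meet at an element-wise multiplication, similar to the
# SE-style block in the new test. Channel counts are illustrative only.
trunk = torch.nn.Conv2d(3, 8, 3, padding=1)   # also consumed elsewhere, so it keeps all 8 channels
scale = torch.nn.Conv2d(3, 6, 3, padding=1)   # pretend speedup shrank this branch to 6 channels

x = torch.randn(1, 3, 32, 32)
try:
    out = trunk(x) * scale(x)                 # 8 vs. 6 channels at dim 1: shape misalignment
except RuntimeError as err:
    print(err)

# With MUL_TYPES handled in build_dependency, both producers of the multiplication
# land in the same dependency set, so a dependency-aware pruner prunes them to the
# same channel indices and the shapes stay aligned after speedup.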