Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
issue 4540 (#4594)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin authored Mar 4, 2022
1 parent 21abc28 commit 3836689
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
10 changes: 7 additions & 3 deletions nni/compression/pytorch/speedup/infer_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,14 @@ def __apply_input_mask(self):
# apply the input mask
for tid, in_tensor in enumerate(self.dummy_input):
if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
# in_tensor.data = in_tensor.data * \
# self.in_masks[tid] + \
# (1-self.in_masks[tid]) * self.in_constants[tid]
# issue-4540 when two tensors are multiplied, the constants part make
# the propagation weaker, and lead to shape misaligment. Currently, we
# donnot support the constant folding, so, we just remove the constant here
in_tensor.data = in_tensor.data * \
self.in_masks[tid] + \
(1-self.in_masks[tid]) * self.in_constants[tid]

self.in_masks[tid]

def __apply_weight_mask(self):
"""
Expand Down
8 changes: 7 additions & 1 deletion nni/compression/pytorch/utils/shape_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ def build_dependency(self):
parent_layers = []
# find the node that contains aten::add
# or aten::cat operations
if node.op_type in ADD_TYPES:
if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
# refer issue 4540 for more details. Multiplication actually
# will not introduce the channel dependency, cause the misaligned
# channels can propagate to each other. However, when one of the input
# tensor is from skip connection(residual), the channel propagation
# may be failed(the input is also used by another layer and cannot be
# pruned), in this case, we need to fix the conflict maunally.
parent_layers = self._get_parent_layers(node)
elif node.op_type == CAT_TYPE:
# To determine if this cat operation will introduce channel
Expand Down
40 changes: 40 additions & 0 deletions test/ut/compression/v1/test_model_speedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,46 @@ def forward(self, x):
print("Fine-grained speeduped model")
print(model)

def test_multiplication_speedup(self):
"""
Model from issue 4540.
"""
class Net(torch.nn.Module):
def __init__(self,):
super(Net, self).__init__()
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.input = torch.nn.Conv2d(3, 8, 3)
self.bn = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Conv2d(8, 16, 1)
self.fc2 = torch.nn.Conv2d(16, 8, 1)
self.activation = torch.nn.ReLU()
self.scale_activation = torch.nn.Hardsigmoid()
self.out = torch.nn.Conv2d(8, 12, 1)

def forward(self, input):
input = self.activation(self.bn(self.input(input)))
scale = self.avgpool(input)
out1 = self.activation(self.fc1(scale))
out1 = self.scale_activation(self.fc2(out1))
return self.out(out1 * input)

model = Net().to(device)
model.eval()
im = torch.ones(1, 3, 512, 512).to(device)
model(im)
cfg_list = []

for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.3, 'op_names':[name]})

pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
ms=ModelSpeedup(model, im, MASK_FILE)
ms.speedup_model()

def tearDown(self):
if os.path.exists(MODEL_FILE):
os.remove(MODEL_FILE)
Expand Down

0 comments on commit 3836689

Please sign in to comment.