This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

issue 4540 #4594

Merged
merged 1 commit on Mar 4, 2022
10 changes: 7 additions & 3 deletions nni/compression/pytorch/speedup/infer_mask.py
@@ -171,10 +171,14 @@ def __apply_input_mask(self):
            # apply the input mask
            for tid, in_tensor in enumerate(self.dummy_input):
                if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
                    # in_tensor.data = in_tensor.data * \
                    #     self.in_masks[tid] + \
                    #     (1-self.in_masks[tid]) * self.in_constants[tid]
                    # issue-4540: when two tensors are multiplied, the constant part makes
                    # the propagation weaker and leads to shape misalignment. We currently do
                    # not support constant folding, so we just remove the constant here.
                    in_tensor.data = in_tensor.data * \
                        self.in_masks[tid] + \
                        (1-self.in_masks[tid]) * self.in_constants[tid]
                    in_tensor.data = in_tensor.data * \
                        self.in_masks[tid]

    def __apply_weight_mask(self):
        """
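
To make the change above concrete, here is a minimal standalone sketch in plain PyTorch (not NNI code; the tensors, the mask, and the 0.5 constant are invented for illustration). Keeping the constant term means the masked-out positions of both operands carry non-zero constants, so their product is also non-zero and the zeros that mask propagation relies on never reach the downstream shapes; multiplying by the mask alone keeps them exactly zero.

import torch

a = torch.randn(1, 4)
b = torch.randn(1, 4)
mask = torch.tensor([[1., 1., 0., 0.]])   # last two channels pruned
const = torch.full_like(a, 0.5)           # hypothetical in_constants value

# old behaviour: masked-out positions are replaced by the constant
a_old = a * mask + (1 - mask) * const
b_old = b * mask + (1 - mask) * const
print((a_old * b_old)[0, 2:])             # tensor([0.2500, 0.2500]) -> pruning no longer visible

# behaviour after this PR: masked-out positions become exactly zero
a_new, b_new = a * mask, b * mask
print((a_new * b_new)[0, 2:])             # tensor([0., 0.]) -> the zeros propagate
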
8 changes: 7 additions & 1 deletion nni/compression/pytorch/utils/shape_dependency.py
@@ -163,7 +163,13 @@ def build_dependency(self):
            parent_layers = []
            # find the node that contains aten::add
            # or aten::cat operations
            if node.op_type in ADD_TYPES:
            if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
                # Refer to issue 4540 for more details. Multiplication does not
                # actually introduce a channel dependency, because the misaligned
                # channels can propagate to each other. However, when one of the
                # input tensors comes from a skip connection (residual), the channel
                # propagation may fail (that input is also used by another layer and
                # cannot be pruned); in this case we need to fix the conflict manually.
                parent_layers = self._get_parent_layers(node)
            elif node.op_type == CAT_TYPE:
                # To determine if this cat operation will introduce channel
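
The situation described in the comment above can be reproduced with a small standalone sketch (plain PyTorch rather than NNI code; the shapes and the kept channel indices are hypothetical). An SE-style scale branch is multiplied elementwise with a trunk tensor that is also consumed by another layer, so the trunk's channels cannot simply be pruned; pruning the branch alone then leaves the two operands with different channel counts and the multiplication can no longer be shrunk, which is exactly the conflict that has to be fixed manually.

import torch
import torch.nn as nn

trunk = torch.randn(1, 8, 32, 32)                        # skip/residual tensor, also used elsewhere
branch = nn.Conv2d(8, 8, 1)(torch.randn(1, 8, 1, 1))     # SE-style scale, candidate for pruning

keep = [0, 1, 2, 3, 4]                                   # suppose only these channels survive pruning
pruned_branch = branch[:, keep]                          # shape (1, 5, 1, 1)

try:
    out = trunk * pruned_branch                          # 8 channels vs 5 channels
except RuntimeError as err:
    print("shape misalignment:", err)
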
40 changes: 40 additions & 0 deletions test/ut/compression/v1/test_model_speedup.py
@@ -512,6 +512,46 @@ def forward(self, x):
print("Fine-grained speeduped model")
print(model)

def test_multiplication_speedup(self):
"""
Model from issue 4540.
"""
class Net(torch.nn.Module):
def __init__(self,):
super(Net, self).__init__()
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.input = torch.nn.Conv2d(3, 8, 3)
self.bn = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Conv2d(8, 16, 1)
self.fc2 = torch.nn.Conv2d(16, 8, 1)
self.activation = torch.nn.ReLU()
self.scale_activation = torch.nn.Hardsigmoid()
self.out = torch.nn.Conv2d(8, 12, 1)

def forward(self, input):
input = self.activation(self.bn(self.input(input)))
scale = self.avgpool(input)
out1 = self.activation(self.fc1(scale))
out1 = self.scale_activation(self.fc2(out1))
return self.out(out1 * input)

model = Net().to(device)
model.eval()
im = torch.ones(1, 3, 512, 512).to(device)
model(im)
cfg_list = []

for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.3, 'op_names':[name]})

pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
ms=ModelSpeedup(model, im, MASK_FILE)
ms.speedup_model()

def tearDown(self):
if os.path.exists(MODEL_FILE):
os.remove(MODEL_FILE)
Expand Down
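
As a usage note (not part of the PR), the test above could be extended after ms.speedup_model() with a quick sanity check; model and im refer to the objects defined in test_multiplication_speedup, and no exact channel counts are asserted because they depend on the pruner's choices.

out = model(im)      # the forward pass still works on the compacted model
print(out.shape)
print(model)         # the printed modules show the reduced channel counts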