
Dependency aware pruning and speeding up network with squeeze-excite blocks #4720

Closed
lopusz opened this issue Mar 30, 2022 · 1 comment

lopusz commented Mar 30, 2022

Describe the issue:

When I apply dependency-aware pruning and speed-up to models with squeeze-and-excite blocks, I run into errors.

The pruning and speed-up steps themselves finish without errors. However, when I then try to run predictions with the sped-up model, I get shape-mismatch errors such as this one:

RuntimeError: The size of tensor a (156) must match the size of tensor b (159) at non-singleton dimension 

Do I need some special config for this case?

Thank you in advance for your help.

Environment:
- NNI version: nni==2.6.1
- Training service (local|remote|pai|aml|etc): local
- Client OS: Ubuntu 20.04
- Server OS (for remote mode only):
- Python version: 3.8.5
- PyTorch/TensorFlow version: 1.11
- Is conda/virtualenv/venv used?: no
- Is running in Docker?: no

Configuration:

- Experiment config (remember to remove secrets!):
- Search space:

Log message:

- nnimanager.log:
- dispatcher.log:
- nnictl stdout and stderr:
[2022-03-30 16:48:48] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: block4a_project_conv
[2022-03-30 16:48:48] INFO (torch filter pruners/MainThread) Prune the 8,12,21,49,53,77,87,88,92 channels for all dependent
[2022-03-30 16:48:48] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: block4a_expand_conv
[2022-03-30 16:48:48] INFO (torch filter pruners/MainThread) Prune the 6,11,12,18,23,25,26,40,46,62,68,73,89,92,123,155,163,166,186 channels for all dependent
[2022-03-30 16:48:48] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: block4a_dwconv2
[2022-03-30 16:48:48] INFO (torch filter pruners/MainThread) Prune the 9,30,45,50,65,77,79,90,93,110,113,121,123,132,134,180,184,188,190 channels for all dependent
[2022-03-30 16:48:48] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: block4a_se_reduce
[2022-03-30 16:48:48] INFO (torch filter pruners/MainThread) Prune the 3 channels for all dependent
[2022-03-30 16:48:48] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: block4a_se_expand
[2022-03-30 16:48:48] INFO (torch filter pruners/MainThread) Prune the 2,9,37,41,43,44,67,77,78,103,122,138,144,150,155,166,173,186,191 channels for all dependent
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to ./pruned_model.pth
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to ./mask.pth
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model
[2022-03-30 16:48:48] INFO (FixMaskConflict/MainThread) {'block4a_expand_conv': 1, 'block4a_dwconv2': 1, 'block4a_se_reduce': 1, 'block4a_se_expand': 1, 'block4a_project_conv': 1}
[2022-03-30 16:48:48] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.097953
[2022-03-30 16:48:48] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000
[2022-03-30 16:48:48] INFO (FixMaskConflict/MainThread) Dectected conv prune dim" 0
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_expand_conv
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_expand_bn
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::relu.7
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_dwconv2
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_bn
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::relu.8
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::mean.9
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_se_reduce
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::relu.10
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_se_expand
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::sigmoid.11
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::mul.12
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for block4a_project_conv
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_project_conv
/home/lopusz/PyEnvs/venvs/p39-tmp/lib/python3.9/site-packages/torch/_tensor.py:1104: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.)
 return self._grad
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::mul.12
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::sigmoid.11
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_se_expand
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::relu.10
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_se_reduce
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::mean.9
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::relu.8
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_bn
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_dwconv2
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::relu.7
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_expand_bn
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the block4a_expand_conv
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) resolve the mask conflict
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace compressed modules...
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_expand_conv, op_type: Conv2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_expand_bn, op_type: BatchNorm2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace batchnorm2d with num_features: 155
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::relu.7, op_type: aten::relu) which is func type
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_dwconv2, op_type: Conv2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_bn, op_type: BatchNorm2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace batchnorm2d with num_features: 155
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::relu.8, op_type: aten::relu) which is func type
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::mean.9, op_type: aten::mean) which is func type
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_se_reduce, op_type: Conv2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::relu.10, op_type: aten::relu) which is func type
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_se_expand, op_type: Conv2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::sigmoid.11, op_type: aten::sigmoid) which is func type
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::mul.12, op_type: aten::mul) which is func type
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: block4a_project_conv, op_type: Conv2d)
[2022-03-30 16:48:48] INFO (nni.compression.pytorch.speedup.compressor/MainThread) speedup done
Traceback (most recent call last):
 File "/home/lopusz/TCL/ml-experiments/2022-03-03-torchify/tmp.py", line 92, in <module>
   dummy_output2 = model(dummy_input)
 File "/home/lopusz/PyEnvs/venvs/p39-tmp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
   return forward_call(*input, **kwargs)
 File "/home/lopusz/TCL/ml-experiments/2022-03-03-torchify/tmp.py", line 74, in forward
   x_41 = x_35 * x_40
RuntimeError: The size of tensor a (155) must match the size of tensor b (151) at non-singleton dimension 1

How to reproduce it?:

import torch

from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch import ModelSpeedup


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__() 
        self.block4a_project_conv = torch.nn.Conv2d(
            in_channels=192,
            out_channels=96,
            kernel_size=1,
            stride=1,
            bias=False,
            padding="same",
        )

        self.block4a_expand_conv = torch.nn.Conv2d(
            in_channels=48,
            out_channels=192,
            kernel_size=1,
            stride=1,
            bias=False,
            padding="same",
        )
        self.block4a_expand_bn = torch.nn.BatchNorm2d(
            num_features=192, eps=0.001, momentum=0.9
        )
        self.block4a_dwconv2 = torch.nn.Conv2d(
            in_channels=192,
            out_channels=192,
            kernel_size=3,
            stride=2,
            bias=False,
            padding=0,
            groups=192,
        )
        self.block4a_bn = torch.nn.BatchNorm2d(
            num_features=192, eps=0.001, momentum=0.9
        )
        self.block4a_se_reduce = torch.nn.Conv2d(
            in_channels=192,
            out_channels=12,
            kernel_size=1,
            stride=1,
            bias=True,
            padding="same",
        )
        self.block4a_se_expand = torch.nn.Conv2d(
            in_channels=12,
            out_channels=192,
            kernel_size=1,
            stride=1,
            bias=True,
            padding="same",
        )
        self.block4a_project_bn = torch.nn.BatchNorm2d(
            num_features=96, eps=0.001, momentum=0.9
        )

    def forward(self, x_29):
        x_30 = self.block4a_expand_conv(x_29)
        x_31 = self.block4a_expand_bn(x_30)
        x_32 = torch.nn.functional.relu(x_31)
        x_33 = self.block4a_dwconv2(x_32)
        x_34 = self.block4a_bn(x_33)
        x_35 = torch.nn.functional.relu(x_34)
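        # Squeeze-and-excite branch: global average pool, reduce, expand,
        # then scale the main-branch feature map x_35 by the gate x_40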
        x_36 = torch.mean(x_35, [2, 3], keepdim=True)
        x_37 = self.block4a_se_reduce(x_36)
        x_38 = torch.nn.functional.relu(x_37)
        x_39 = self.block4a_se_expand(x_38)
        x_40 = torch.sigmoid(x_39)
        x_41 = x_35 * x_40
        x_42 = self.block4a_project_conv(x_41)
        return x_42


model = Model()

# The block is taken from an inner part of a larger model, so the input/output shapes are unusual

dummy_input = torch.rand(1, 48, 28, 28)
dummy_output = model(dummy_input)
config_list = [{"op_types": ["Conv2d"], "sparsity": 0.1}]
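# Dependency-aware mode uses the dummy input to trace channel dependencies between layers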
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
pruner.export_model("./pruned_model.pth", "./mask.pth")

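# Remove the pruner wrappers so ModelSpeedup works on the original module structure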
pruner._unwrap_model()

ms = ModelSpeedup(model, dummy_input, "./mask.pth")
ms.speedup_model()

# Everything seems to go fine until one tries predicting with the model after speed-up

dummy_output2 = model(dummy_input)
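
A hedged diagnostic sketch (added for illustration, not part of the original script): placing the loop below between ms.speedup_model() and the final forward call prints the channel counts of the replaced modules. It makes the failure mode visible: per the log above, the main branch is left at 155 channels (block4a_bn), while the squeeze-excite branch is pruned to a different width, so the element-wise multiply x_35 * x_40 receives tensors with mismatched channel dimensions.

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(f"{name}: in_channels={module.in_channels}, out_channels={module.out_channels}")
    elif isinstance(module, torch.nn.BatchNorm2d):
        print(f"{name}: num_features={module.num_features}")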

J-shang commented Apr 1, 2022

Hello @lopusz, thanks for your issue. It seems we have fixed this in #4594; you could merge those changes. Alternatively, we now have a test package for v2.7, which you can try with python3 -m pip install --extra-index-url https://test.pypi.org/simple/ nni==2.7a1.

If it still doesn't solve the problem, feel free to contact us.
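
A quick sanity check (a hedged sketch, not part of the original reply): after installing the test package, confirm that the upgraded version is the one being imported before re-running the reproduction script above.

import nni

# Expect something like "2.7a1" if the test package from test.pypi.org was installed
print(nni.__version__)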

@J-shang J-shang self-assigned this Apr 1, 2022
@J-shang J-shang closed this as completed Sep 7, 2022