
Add the input channel dependency pruning. #2865

Merged 5 commits on Sep 17, 2020. Changes from 2 commits.
48 changes: 47 additions & 1 deletion src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
@@ -10,7 +10,7 @@
ADD_TYPES = ['aten::add', 'aten::add_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')

RESHAPE_OPS = [CAT_TYPE, 'aten::view', 'aten::reshape', 'aten::flatten']

class Dependency:
    def __init__(self, model=None, dummy_input=None, traced_model=None):
@@ -185,6 +185,52 @@ def dependency_sets(self):
            d_sets.append(tmp_set)
        return d_sets

class InputChannelDependency(ChannelDependency):
    def __init__(self, model, dummy_input=None, traced_model=None):
        super(InputChannelDependency, self).__init__(model, dummy_input, traced_model)

    def _get_following_convs(self, tensor):
        # breadth-first search from this tensor to find the first
        # Conv2d/Linear layers reached along each path
        queue = []
        key_layers = []
        queue.extend(self.graph.input_to_node[tensor])
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
                # stop at the first conv/linear layer encountered
                key_layers.append(curnode.name)
                continue
            elif curnode.op_type in RESHAPE_OPS:
                # reshape operations also break the dependency relationship
                continue
            successors = self.graph.find_successors(curnode.unique_name)
            successors = [self.graph.name_to_node[name] for name in successors]
            for layer in successors:
                queue.append(layer)
        return key_layers

    def build_dependency(self):
        """
        Build the input channel dependencies.
        The `InputChannelDependency` indicates the layers that have
        dependencies when pruning the input channels of the conv layers.
        In contrast, `ChannelDependency` indicates the dependent layers
        when pruning the output channels of conv layers (for example, L1FilterPruner).
        """
        # unpack the tuples/lists manually
        self.graph.unpack_manually()
        for tensor in self.graph.input_to_node:
            # starting from this tensor, find all the conv layers that
            # take it as input. Similar to `ChannelDependency`,
            # the conv layers truncate the dependency propagation.
            layers = self._get_following_convs(tensor)
            dependency_set = set(layers)
            for layer in layers:
                if layer in self.dependency:
                    dependency_set.update(self.dependency[layer])
            for layer in dependency_set:
                self.dependency[layer] = dependency_set


class CatPaddingDependency(ChannelDependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
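
Usage note: a minimal sketch of how the new class might be exercised, assuming the import path implied by the file touched in this PR; the model choice (torchvision's resnet18) and the printing loop are illustrative only, not part of the change.

import torch
from torchvision.models import resnet18
# assumed import path, derived from src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
from nni.compression.torch.utils.shape_dependency import InputChannelDependency

model = resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

# trace the model and build the input-channel dependency sets
dep = InputChannelDependency(model, dummy_input)

# layers that end up in the same set consume the same tensor (possibly through
# element-wise ops such as aten::add), so their input channels must be pruned
# with an identical mask
for layer, dep_set in dep.dependency.items():
    print(layer, '->', sorted(dep_set))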