diff --git a/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py b/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
index 1c32fdb227..40e445cf19 100644
--- a/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
+++ b/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
@@ -4,13 +4,13 @@
 import csv
 import logging
 
-__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency']
+__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency', 'InputChannelDependency']
 
 CONV_TYPE = 'aten::_convolution'
 ADD_TYPES = ['aten::add', 'aten::add_']
 CAT_TYPE = 'aten::cat'
 logger = logging.getLogger('Shape_Dependency')
-
+RESHAPE_OPS = [CAT_TYPE, 'aten::view', 'aten::reshape', 'aten::flatten', 'aten::mean']
 
 class Dependency:
     def __init__(self, model=None, dummy_input=None, traced_model=None):
@@ -37,7 +37,7 @@ def export(self, filepath):
 class ChannelDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
-        This model analyze the channel dependencis between the conv
+        This class analyzes the channel dependencies between the conv
         layers in a model.
 
         Parameters
@@ -185,6 +185,109 @@ def dependency_sets(self):
             d_sets.append(tmp_set)
         return d_sets
 
+def reshape_break_channel_dependency(op_node):
+    """
+    Reshape operations such as `reshape`, `view` and `flatten` may break
+    the channel dependency. We need to check the input parameters of
+    these reshape operations to determine whether this reshape node will
+    break the channel dependency. However, it is complicated to analyze
+    the input parameters of each reshape function and infer whether it
+    will break the channel dependency. So currently, we just check whether
+    the input channel count and the output channel count are the same: if
+    so, the original reshape function does not intend to change the number
+    of channels, which means the channel dependency is not broken. In
+    contrast, if they differ, the reshape operation intends to change the
+    number of channels, so it breaks the channel dependency.
+
+    Parameters
+    ----------
+    op_node : NodePyOP
+        An op node of the graph.
+
+    Returns
+    -------
+    bool
+        True if this operation breaks the channel dependency.
+    """
+    in_shape = op_node.auxiliary['in_shape']
+    out_shape = op_node.auxiliary['out_shape']
+    in_channel = in_shape[1]
+    out_channel = out_shape[1]
+    return in_channel != out_channel
+
+class InputChannelDependency(ChannelDependency):
+    """
+    Some pruners may prune the input channels of the convolutional
+    layers. While pruning the input channels of the convolutional layers,
+    the layers that share the same input tensor should prune the same
+    channels, and we say these layers that share the same input
+    tensor/channel have an input channel dependency. If we only prune the
+    input channels of one layer in the dependency set, there will be a
+    shape conflict for the other layers in the same dependency set, which
+    may trigger a runtime error. Here we judge whether an operation breaks
+    the dependency propagation by analyzing whether the number of channels
+    changes before and after the operation. If not, the input channel
+    dependency is passed on to the following nodes.
+    """
+
+    def __init__(self, model, dummy_input=None, traced_model=None):
+        """
+        This class analyzes the input channel dependencies between the conv
+        layers in a model.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            The model to be analyzed.
+        dummy_input : torch.Tensor
+            The example input data to trace the network architecture.
+        traced_model : torch._C.Graph
+            If we already have the traced graph of the target model, we
+            do not need to trace the model again.
+        """
+        super(InputChannelDependency, self).__init__(model, dummy_input, traced_model)
+
+    def _get_following_convs(self, tensor):
+        queue = []
+        key_layers = []
+        queue.extend(self.graph.input_to_node[tensor])
+        while queue:
+            curnode = queue.pop(0)
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
+                # stop at the first conv/linear layer we meet
+                key_layers.append(curnode.name)
+                continue
+            elif curnode.op_type in RESHAPE_OPS:
+                # check if this reshape operation breaks the channel dependency
+                if reshape_break_channel_dependency(curnode):
+                    # if so, stop propagating the dependency along this path
+                    continue
+            successors = self.graph.find_successors(curnode.unique_name)
+            successors = [self.graph.name_to_node[name] for name in successors]
+            for layer in successors:
+                queue.append(layer)
+        return key_layers
+
+    def build_dependency(self):
+        """
+        Build the input channel dependencies.
+        The `InputChannelDependency` indicates the layers that have
+        dependencies when pruning the input channels of the conv layers.
+        In contrast, `ChannelDependency` indicates the dependent layers
+        when pruning the output channels of conv layers (for example,
+        L1FilterPruner).
+        """
+        # unpack the tuples or lists manually
+        self.graph.unpack_manually()
+        for tensor in self.graph.input_to_node:
+            # starting from this tensor, find all the conv layers that
+            # take it as input. Similar to `ChannelDependency`, the conv
+            # layers truncate the dependency propagation.
+            layers = self._get_following_convs(tensor)
+            dependency_set = set(layers)
+            for layer in layers:
+                if layer in self.dependency:
+                    dependency_set.update(self.dependency[layer])
+            for layer in dependency_set:
+                self.dependency[layer] = dependency_set
+
+
 class CatPaddingDependency(ChannelDependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
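For reviewers who want to try the new class, here is a minimal usage sketch. It relies only on the interface visible in this diff and on what `InputChannelDependency` inherits from `ChannelDependency`/`Dependency` (a constructor taking `model` and `dummy_input`, the `dependency` mapping filled by `build_dependency`, and the inherited `export` method shown in the hunk context). The ResNet model and the output file name are illustrative assumptions, not part of this PR.

```python
import torch
import torchvision.models as models

from nni.compression.torch.utils.shape_dependency import InputChannelDependency

# any traceable torch.nn.Module should work; resnet18 is just an example
model = models.resnet18()
dummy_input = torch.ones(1, 3, 224, 224)

# the base-class constructor traces the model and invokes build_dependency()
input_dep = InputChannelDependency(model, dummy_input)

# `dependency` maps each conv/linear layer name to the set of layers that
# must prune the same input channels
for layer, dep_set in input_dep.dependency.items():
    if len(dep_set) > 1:
        print(layer, '->', sorted(dep_set))

# write the dependency sets to CSV via the method inherited from the base class
input_dep.export('input_channel_dependency.csv')
```

In a ResNet-style network this should group the layers that consume the same residual output, since the element-wise add forces them to share input channels, while reshape nodes that change the channel count cut the propagation.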