From f13da3cbc606c5032dac0790f31cfe0b42f894d8 Mon Sep 17 00:00:00 2001
From: Vadzim Piatrou
Date: Mon, 5 Mar 2018 16:53:20 +0300
Subject: [PATCH] mxnet parser attrs corrections

---
 mmdnn/conversion/mxnet/mxnet_parser.py | 77 ++++++++------------------
 1 file changed, 23 insertions(+), 54 deletions(-)

diff --git a/mmdnn/conversion/mxnet/mxnet_parser.py b/mmdnn/conversion/mxnet/mxnet_parser.py
index c3ae0ee2..189a21fc 100644
--- a/mmdnn/conversion/mxnet/mxnet_parser.py
+++ b/mmdnn/conversion/mxnet/mxnet_parser.py
@@ -557,11 +557,7 @@ def rename_BatchNorm(self, source_node):
         # output shape
         self.set_output_shape(source_node, IR_node)
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        layer_attr = self._get_layer_attr(source_node)
 
         # axis
         if self.data_format in MXNetParser.channels_first or self.data_format == 'None':
@@ -605,11 +601,8 @@ def rename_Pooling(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
 
         # pooling type (sum not allowed yet)
         pool_type = layer_attr.get("pool_type")
@@ -678,11 +671,8 @@ def rename_softmax(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
 
         # dim
         if self.data_format in MXNetParser.channels_first or self.data_format == 'None':
@@ -713,11 +703,8 @@ def rename_Deconvolution(self, source_node):
         dim = 0
         layout = 'None'
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
 
         # padding
         if "pad" in layer_attr:
@@ -807,11 +794,7 @@ def rename_Embedding(self, source_node):
         self.convert_inedge(source_node, IR_node)
 
         # attr
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        layer_attr = self._get_layer_attr(source_node)
 
         # input_dim
         IR_node.attr["input_dim"].i = int(layer_attr.get("input_dim"))
@@ -830,14 +813,13 @@ def rename_LeakyReLU(self, source_node):
         # judge whether meaningful
         assert "attr"
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            if "act_type" in source_node.layer["attr"]:
-                if not source_node.layer["attr"]["act_type"] == "elu":
-                    print("Warning: Activation Type %s is not supported yet." % source_node.layer["attr"]["act_type"])
-                    return
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
+
+        if "act_type" in layer_attr:
+            if not layer_attr["act_type"] == "elu":
+                print("Warning: Activation Type %s is not supported yet." % layer_attr["act_type"])
+                # return
 
         IR_node = self.IR_graph.node.add()
 
@@ -848,7 +830,7 @@ def rename_LeakyReLU(self, source_node):
         self.convert_inedge(source_node, IR_node)
 
         # attr
-        layer_attr = source_node.layer["attr"]
+        # layer_attr = source_node.layer["attr"]
 
         # alpha [exp(x) - alpha], but mxnet attr slope [slope*(exp(x) - 1)] when x < 0
         if "slope" in layer_attr:
@@ -878,11 +860,8 @@ def rename_LRN(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
 
         # alpha
         IR_node.attr["alpha"].f = float(layer_attr.get("alpha", "0.0001"))
@@ -911,11 +890,8 @@ def rename_Dropout(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
 
         # keep_prob
         IR_node.attr["keep_prob"].f = float(layer_attr.get("p", "0.5"))
@@ -941,11 +917,8 @@ def rename_reshape(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)
 
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
 
         # old API target_shape not support yet
         shape = layer_attr.get("shape")
@@ -988,11 +961,7 @@ def rename_Concat(self, source_node):
         self.convert_inedge(source_node, IR_node)
 
         # attr
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        layer_attr = self._get_layer_attr(source_node)
 
         # dim
         if self.data_format in MXNetParser.channels_first or self.data_format == 'None':
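Note: every hunk above replaces the repeated attr/param lookup with a call to self._get_layer_attr(source_node), but the helper's definition is not part of these hunks. A minimal sketch of what such a helper presumably does, reconstructed from the inline logic being removed (the method name is taken from the call sites; the @staticmethod placement is an assumption):

    @staticmethod
    def _get_layer_attr(source_node):
        # MXNet symbol JSON stores layer attributes under "attr" or "param"
        # depending on the MXNet version; return whichever is present, or an
        # empty dict so callers can rely on layer_attr.get(...) safely.
        layer_attr = dict()
        if "attr" in source_node.layer:
            layer_attr = source_node.layer["attr"]
        elif "param" in source_node.layer:
            layer_attr = source_node.layer["param"]
        return layer_attr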