Merge pull request #94 from belgraviton/mxnet_attrs_correction
MXNet parser 'attrs' corrections
kitstar authored Mar 7, 2018
2 parents 9bba9ad + f13da3c commit e66cd5c
Showing 1 changed file with 23 additions and 54 deletions.
77 changes: 23 additions & 54 deletions mmdnn/conversion/mxnet/mxnet_parser.py
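
Every hunk below replaces the same copy-pasted lookup (try the "attr" key, then fall back to "param") with a call to a shared helper, `_get_layer_attr`. The helper's own definition is not shown in this diff; a minimal sketch of what it presumably looks like, assuming it also recognizes the "attrs" key referenced in the PR title:

    # Assumed shape of the shared helper; not part of the hunks below.
    # Resolves the node's attribute dict regardless of which key the
    # MXNet symbol JSON used to store it.
    @staticmethod
    def _get_layer_attr(source_node):
        layer_attr = dict()
        if "attrs" in source_node.layer:
            layer_attr = source_node.layer["attrs"]
        elif "attr" in source_node.layer:
            layer_attr = source_node.layer["attr"]
        elif "param" in source_node.layer:
            layer_attr = source_node.layer["param"]
        return layer_attr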
@@ -557,11 +557,7 @@ def rename_BatchNorm(self, source_node):
         # output shape
         self.set_output_shape(source_node, IR_node)

-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        layer_attr = self._get_layer_attr(source_node)

         # axis
         if self.data_format in MXNetParser.channels_first or self.data_format == 'None':
@@ -605,11 +601,8 @@ def rename_Pooling(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)

-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)

         # pooling type (sum not allowed yet)
         pool_type = layer_attr.get("pool_type")
@@ -678,11 +671,8 @@ def rename_softmax(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)

-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)

         # dim
         if self.data_format in MXNetParser.channels_first or self.data_format == 'None':
@@ -713,11 +703,8 @@ def rename_Deconvolution(self, source_node):

         dim = 0
         layout = 'None'
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)

         # padding
         if "pad" in layer_attr:
@@ -807,11 +794,7 @@ def rename_Embedding(self, source_node):
         self.convert_inedge(source_node, IR_node)

         # attr
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        layer_attr = self._get_layer_attr(source_node)

         # input_dim
         IR_node.attr["input_dim"].i = int(layer_attr.get("input_dim"))
@@ -830,14 +813,13 @@ def rename_Embedding(self, source_node):
     def rename_LeakyReLU(self, source_node):
         # judge whether meaningful
         assert "attr"
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            if "act_type" in source_node.layer["attr"]:
-                if not source_node.layer["attr"]["act_type"] == "elu":
-                    print("Warning: Activation Type %s is not supported yet." % source_node.layer["attr"]["act_type"])
-                    return
+        # attr
+        layer_attr = self._get_layer_attr(source_node)
+
+        if "act_type" in layer_attr:
+            if not layer_attr["act_type"] == "elu":
+                print("Warning: Activation Type %s is not supported yet." % layer_attr["act_type"])
+                # return

         IR_node = self.IR_graph.node.add()

@@ -848,7 +830,7 @@ def rename_LeakyReLU(self, source_node):
         self.convert_inedge(source_node, IR_node)

         # attr
-        layer_attr = source_node.layer["attr"]
+        # layer_attr = source_node.layer["attr"]

         # alpha [exp(x) - alpha], but mxnet attr slope [slope*(exp(x) - 1)] when x < 0
         if "slope" in layer_attr:
@@ -878,11 +860,8 @@ def rename_LRN(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)

-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)

         # alpha
         IR_node.attr["alpha"].f = float(layer_attr.get("alpha", "0.0001"))
@@ -911,11 +890,8 @@ def rename_Dropout(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)

-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)

         # keep_prob
         IR_node.attr["keep_prob"].f = float(layer_attr.get("p", "0.5"))
@@ -941,11 +917,8 @@ def rename_reshape(self, source_node):
         # input edge
         self.convert_inedge(source_node, IR_node)

-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        # attr
+        layer_attr = self._get_layer_attr(source_node)

         # old API target_shape not support yet
         shape = layer_attr.get("shape")
@@ -988,11 +961,7 @@ def rename_Concat(self, source_node):
         self.convert_inedge(source_node, IR_node)

         # attr
-        layer_attr = dict()
-        if "attr" in source_node.layer:
-            layer_attr = source_node.layer["attr"]
-        elif "param" in source_node.layer:
-            layer_attr = source_node.layer["param"]
+        layer_attr = self._get_layer_attr(source_node)

         # dim
         if self.data_format in MXNetParser.channels_first or self.data_format == 'None':
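For context on why a shared lookup is needed at all: MXNet has stored per-node attributes in its symbol JSON under different keys across releases ("param" in older versions, then "attr", then "attrs"), so a parser hard-wired to one key silently drops attributes for models exported by another version. Hypothetical symbol-JSON nodes illustrating the difference (data invented for illustration):

    # Same layer, serialized by different MXNet versions; only the key differs.
    node_legacy = {"op": "Pooling", "name": "pool1",
                   "param": {"pool_type": "max", "kernel": "(2, 2)"}}
    node_recent = {"op": "Pooling", "name": "pool1",
                   "attrs": {"pool_type": "max", "kernel": "(2, 2)"}}
    # Either way the parser resolves one dict and reads it uniformly:
    #     layer_attr = self._get_layer_attr(source_node)
    #     layer_attr.get("pool_type")   # "max" in both cases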
