diff --git a/model_compiler/python/nnir_to_nnef.py b/model_compiler/python/nnir_to_nnef.py index 03330b3605..276d150773 100644 --- a/model_compiler/python/nnir_to_nnef.py +++ b/model_compiler/python/nnir_to_nnef.py @@ -27,6 +27,7 @@ def generateGraph(graph,outputFolder,label): with open(fileName, 'wb') as f: f.write( \ """# This file is generated by nnir2nnef.py +version 1.0; graph nnir (%s) -> (%s) { """ % (', '.join([tensor.name for tensor in graph.inputs]), ', '.join([tensor.name for tensor in graph.outputs]))) @@ -60,30 +61,33 @@ def generateGraph(graph,outputFolder,label): """ %s = conv(%s, %s, %sstride=[%d,%d], dilation=[%d,%d], padding=[(%d,%d),(%d,%d)], groups=%d, border = 'ignore'); """ % (node.outputs[0], node.inputs[0], node.inputs[1], node.inputs[2] + ', ' if len(node.inputs) == 3 else '', \ strides[0], strides[1], dilations[0], dilations[1], pads[0], pads[1], pads[2], pads[3], group)) - elif node.type == 'avg_pool' or 'max_pool': + elif node.type == 'avg_pool' or node.type == 'max_pool': kernel_shape = node.attr.get('kernel_shape') pads = node.attr.get('pads') + padding = '(0,0),(0,0),(%d,%d),(%d,%d)' % (pads[0], pads[1], pads[2], pads[3]) if len(pads) != 0 else '' strides = node.attr.get('strides') + stride = '1,1,%d,%d' % (strides[0], strides[1]) if len(strides) != 0 else '' dilations = node.attr.get('dilations') + dilation = '1,1,%d,%d' % (dilations[0], dilations[1]) if len(dilations) != 0 else '' f.write( \ -""" %s = %s(%s, size=[%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[(%d,%d),(%d,%d)], border = 'ignore'); +""" %s = %s(%s, size=[1,1,%d,%d], stride=[%s], dilation=[%s], padding=[%s], border = 'ignore'); """ % (node.outputs[0], node.type, node.inputs[0], kernel_shape[0], kernel_shape[1], \ - strides[0], strides[1], dilations[0], dilations[1], pads[0], pads[1], pads[2], pads[3])) - elif node.type == 'relu' or 'softmax': + stride, dilation, padding)) + elif node.type == 'relu' or node.type == 'softmax': f.write( \ """ %s = %s(%s); """ % (node.outputs[0], node.type, node.inputs[0])) - elif node.type == 'add': + elif node.type == 'sum': f.write( \ -""" %s = %s(%s, %s); -""" % (node.outputs[0], node.type, node.inputs[0], node.inputs[1])) +""" %s = add(%s, %s); +""" % (node.outputs[0], node.inputs[0], node.inputs[1])) elif node.type == 'batch_norm': f.write( \ -""" %s = batch_normalization(%s, mean = %s, variance = %s, offset = %s, scale = %s, epsilon = %ef); +""" %s = batch_normalization(%s, %s, %s, %s, %s, epsilon = %e); """ % (node.outputs[0], node.inputs[0], node.inputs[3], node.inputs[4], node.inputs[2], node.inputs[1], node.attr.get('epsilon'))) elif node.type == 'gemm': f.write( \ -""" %s = matmul(%s, %s, trA = %s, trB = %s); +""" %s = matmul(%s, %s, transposeA = %s, transposeB = %s); """ % (node.outputs[0], node.inputs[0], node.inputs[1], \ 'true' if node.attr.get('transA') == 1 else 'false', \ 'true' if node.attr.get('transB') == 1 else 'false'))