diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index e671acbe7d752..3fc3ca85184a2 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -205,6 +205,12 @@ def _upsampling(inputs, attrs): new_attrs = {'scale':int(scale)} return _get_nnvm_op('upsampling')(inputs[0], **new_attrs) +def _clip(inputs, attrs): + op_name, new_attrs = "clip", {} + new_attrs['a_min'] = _required_attr(attrs, 'a_min') + new_attrs['a_max'] = _required_attr(attrs, 'a_max') + return _get_nnvm_op(op_name)(*inputs, **new_attrs) + _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', @@ -248,6 +254,7 @@ def _upsampling(inputs, attrs): 'reshape' : _reshape, 'sum_axis' : _rename('sum'), 'UpSampling' : _upsampling, + 'clip' : _clip } def _convert_symbol(op_name, inputs, attrs, diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py b/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py index c9243d3610fa3..68215bb80aaac 100644 --- a/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py +++ b/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py @@ -71,7 +71,7 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', ** 13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]), 16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]), 19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])} - if not vgg_spec.has_key(num_layers): + if num_layers not in vgg_spec: raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers)) layers, filters = vgg_spec[num_layers] data = mx.sym.Variable(name="data") diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index 0f9747538ce92..fca19a693855e 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -8,24 +8,41 @@ from nnvm.testing.config import ctx_list from nnvm import frontend import mxnet as mx +from mxnet import gluon +from mxnet.gluon.model_zoo import vision import model_zoo -def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)): +def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000), + gluon_impl=False, name=None): """Use name different from test to avoid let nose pick it up""" - def get_mxnet_output(symbol, x, dtype='float32'): - from collections import namedtuple - Batch = namedtuple('Batch', ['data']) - mod = mx.mod.Module(symbol, label_names=None) - mod.bind(data_shapes=[('data', x.shape)], for_training=False) - mod.init_params() - mod.forward(Batch([mx.nd.array(x.astype(dtype))])) - out = mod.get_outputs()[0].asnumpy() - args, auxs = mod.get_params() - return out, args, auxs + if gluon_impl: + def get_gluon_output(name, x): + net = vision.get_model(name) + net.collect_params().initialize(mx.init.Xavier()) + net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')), + inputs=mx.sym.var('data'), + params=net.collect_params()) + out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy() + return out, net_sym + else: + def get_mxnet_output(symbol, x, dtype='float32'): + from collections import namedtuple + Batch = namedtuple('Batch', ['data']) + mod = mx.mod.Module(symbol, label_names=None) + mod.bind(data_shapes=[('data', x.shape)], for_training=False) + mod.init_params() + mod.forward(Batch([mx.nd.array(x.astype(dtype))])) + out = mod.get_outputs()[0].asnumpy() + args, auxs = mod.get_params() + return out, args, auxs def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): - new_sym, params = frontend.from_mxnet(symbol, args, auxs) + if gluon_impl: + new_sym, params = frontend.from_mxnet(symbol) + else: + new_sym, params = frontend.from_mxnet(symbol, args, auxs) + dshape = x.shape shape_dict = {'data': dshape} with nnvm.compiler.build_config(opt_level=3): @@ -42,11 +59,17 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): # random input dtype = 'float32' x = np.random.uniform(size=data_shape) - mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype) - assert "data" not in args - for target, ctx in ctx_list(): - tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype) - np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) + if gluon_impl: + gluon_out, gluon_sym = get_gluon_output(name, x) + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype) + np.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5) + else: + mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype) + assert "data" not in args + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype) + np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) def test_forward_mlp(): mlp = model_zoo.mx_mlp @@ -91,6 +114,12 @@ def test_forward_fc_flatten(): except: pass +def test_forward_clip(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicity + mx_sym = mx.sym.clip(data, a_min=0, a_max=1) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -99,3 +128,4 @@ def test_forward_fc_flatten(): test_forward_rrelu() test_forward_softrelu() test_forward_fc_flatten() + test_forward_clip()