Skip to content

Commit

Permalink
[Frontend][MXNet] ones zeros ones_like zeros_like ops support (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
imorinaga authored and tqchen committed Oct 21, 2018
1 parent e312964 commit 66e4af5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
19 changes: 15 additions & 4 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ def _lrn(inputs, attrs):
new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _ones(_, attrs):
op_name = "ones"
return _get_nnvm_op(op_name)(**attrs)

def _zeros(_, attrs):
op_name = "zeros"
return _get_nnvm_op(op_name)(**attrs)

_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
Expand All @@ -281,8 +289,8 @@ def _lrn(inputs, attrs):
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh',
'transpose']
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'sum', 'tanh', 'transpose', 'zeros_like']

_convert_map = {
'_copy' : _rename('copy'),
Expand All @@ -294,6 +302,8 @@ def _lrn(inputs, attrs):
'_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'_ones' : _ones,
'_zeros' : _zeros,
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
Expand Down Expand Up @@ -397,13 +407,14 @@ def _from_mxnet_impl(symbol, graph):
if node:
return node[output_index]
attr = symbol.list_attr()
# op_name = symbol.attr('op_name')
op_name = symbol.attr('op_name')
childs = symbol.get_children()
if childs is not None:
op_name = symbol.attr('op_name')
childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr) # no input symbol
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
Expand Down
26 changes: 26 additions & 0 deletions nnvm/tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,28 @@ def test_forward_lrn():
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))

def test_forward_ones():
data = mx.sym.var('data')
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))

def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, zeros)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))

def test_forward_ones_like():
data = mx.sym.var('data')
mx_sym = mx.sym.ones_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))

def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))

if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand All @@ -168,3 +190,7 @@ def test_forward_lrn():
test_forward_expand_dims()
test_forward_pooling()
test_forward_lrn()
test_forward_ones()
test_forward_zeros()
test_forward_ones_like()
test_forward_zeros_like()

0 comments on commit 66e4af5

Please sign in to comment.