[TOP][COMPILER] sum, min, max, transpose, fix dense (apache#28)
tqchen committed May 29, 2018
1 parent 389a00f commit 8aa0100
Showing 10 changed files with 180 additions and 47 deletions.
23 changes: 13 additions & 10 deletions nnvm/python/nnvm/frontend/mxnet.py
@@ -1,3 +1,4 @@
# pylint: disable=invalid-name
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
@@ -155,14 +156,14 @@ def _split(attrs):
return op_name, new_attrs

_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
'__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
'broadcast_add', 'broadcast_div', 'broadcast_mul',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
'__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
'broadcast_add', 'broadcast_div', 'broadcast_mul',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']

_convert_map = {
'null' : _variable,
@@ -190,8 +191,8 @@ def _split(attrs):
}

def _convert_symbol(op_name, attrs,
identity_list=_identity_list,
convert_map=_convert_map):
identity_list=None,
convert_map=None):
"""Convert from mxnet op to nnvm op.
The converter must specify some conversions explicitly to
support gluon format ops such as conv2d...
@@ -214,6 +215,8 @@ def _convert_symbol(op_name, attrs,
(op_name, attrs)
Converted (op_name, attrs) for nnvm.
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
pass
elif op_name in convert_map:
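The frontend change above moves the conversion tables out of _convert_symbol's default arguments and resolves them at call time instead. A minimal, self-contained sketch of that fallback pattern; the tables here are small placeholders standing in for the real _identity_list and _convert_map in mxnet.py:

# Sketch only: placeholder tables, not the full lists from the file.
_identity_list = ['relu', 'sum', 'max', 'min', 'transpose']
_convert_map = {'Flatten': lambda attrs: ('flatten', attrs)}

def _convert_symbol(op_name, attrs, identity_list=None, convert_map=None):
    # Fall back to the module-level tables only when the caller passes nothing,
    # so custom tables can still be injected by other frontends or tests.
    identity_list = identity_list if identity_list else _identity_list
    convert_map = convert_map if convert_map else _convert_map
    if op_name in identity_list:
        return op_name, attrs
    if op_name in convert_map:
        return convert_map[op_name](attrs)
    raise NotImplementedError("Operator %s is not supported." % op_name)

print(_convert_symbol('relu', {}))     # ('relu', {})
print(_convert_symbol('Flatten', {}))  # ('flatten', {})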
1 change: 1 addition & 0 deletions nnvm/python/nnvm/top/__init__.py
@@ -3,3 +3,4 @@
from . import tensor
from . import nn
from . import transform
from . import reduction
4 changes: 1 addition & 3 deletions nnvm/python/nnvm/top/attr_dict.py
@@ -1,7 +1,5 @@
# pylint: disable=invalid-name
"""Attr dictionary object used by schedule functions"""

import json
import tvm

_dict_get = tvm.get_global_func("nnvm.compiler._dict_get")
@@ -51,7 +49,7 @@ def get_int_tuple(self, key):
tuple : tuple of int
The result tuple
"""
return tuple(json.loads(self[key]))
return tuple(int(x) for x in self[key][1:-1].split(",") if x)

def get_int(self, key):
"""Get integer from attr dict
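The get_int_tuple change above drops json.loads, presumably because attribute values are serialized in Python tuple syntax such as "(0, 2, 1)", which is not valid JSON. A standalone sketch of the new parsing, assuming values arrive in that tuple-string form:

def get_int_tuple(value):
    # Strip the surrounding parentheses and parse each comma-separated entry;
    # "if x" skips the empty string produced by "()" so it maps to ().
    return tuple(int(x) for x in value[1:-1].split(",") if x)

assert get_int_tuple("(0, 2, 1)") == (0, 2, 1)
assert get_int_tuple("()") == ()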
13 changes: 6 additions & 7 deletions nnvm/python/nnvm/top/nn.py
@@ -20,9 +20,9 @@ def compute_relu(attrs, inputs, _):

# leaky_relu
@reg.register_compute("leaky_relu")
def compute_relu(attrs, inputs, _):
def compute_leaky_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.leaky_relu(inputs[0])
return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))

reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
@@ -62,20 +62,19 @@ def schedule_softmax(_, outs, target):
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.fully_connected_with_bias(
inputs[0], inputs[1], inputs[2])
return topi.nn.fully_connected(inputs[0], inputs[1])
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1])

@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
"""Schedule definition of dense"""
if target == "cuda":
raise ValueError("fully_connected not yet implemented")
return topi.cuda.schedule_dense(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])

# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.OPAQUE)
reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv
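The dense fix above routes both the bias and no-bias paths through topi.nn.dense, uses topi.cuda.schedule_dense on CUDA, and marks the op as OUT_ELEMWISE_FUSABLE instead of OPAQUE. As a quick reference for the semantics the tests below verify (x of shape (batch, in_dim), weight of shape (units, in_dim)), a numpy sketch, not the topi implementation:

import numpy as np

def dense_ref(x, weight, bias=None):
    # dense(x, w[, b]) == x @ w.T (+ b), matching the check in test_dense below.
    out = np.dot(x, weight.T)
    return out + bias if bias is not None else out

x = np.random.uniform(size=(10, 100)).astype("float32")
w = np.random.uniform(size=(3, 100)).astype("float32")
b = np.random.uniform(size=(3,)).astype("float32")
assert dense_ref(x, w, b).shape == (10, 3)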
47 changes: 47 additions & 0 deletions nnvm/python/nnvm/top/reduction.py
@@ -0,0 +1,47 @@
# pylint: disable=invalid-name, unused-argument
"""Reduction ops"""
from __future__ import absolute_import

import tvm
import topi
import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern

def _schedule_reduce(_, outs, target):
"""Generic schedule for reduce"""
if target == "cuda":
return topi.cuda.schedule_reduce(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s

_fschedule_reduce = tvm.convert(_schedule_reduce)

def _compute_reduce(f):
"""auxiliary function"""
def _compute(attrs, inputs, out_info):
axis = attrs.get_int_tuple("axis")
keepdims = attrs.get_bool("keepdims")
if axis:
return f(inputs[0], axis=axis, keepdims=keepdims)
return f(inputs[0], keepdims=keepdims)
return _compute

# sum
reg.register_compute("sum", _compute_reduce(topi.sum))
reg.register_pattern("sum", OpPattern.COMM_REDUCE)
reg.register_schedule("sum", _fschedule_reduce)

# max
reg.register_compute("max", _compute_reduce(topi.max))
reg.register_pattern("max", OpPattern.COMM_REDUCE)
reg.register_schedule("max", _fschedule_reduce)

# min
reg.register_compute("min", _compute_reduce(topi.min))
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)
3 changes: 2 additions & 1 deletion nnvm/python/nnvm/top/tensor.py
@@ -19,9 +19,10 @@ def _schedule_injective(_, outs, target):
s[x].fuse(s[x].op.axis)
return s


def _compute_binary_scalar(f):
"""auxiliary function"""
@tvm.tag_scope("ewise")
@tvm.tag_scope(topi.tag.ELEMWISE)
def _compute(attrs, x, _):
x = x[0]
scalar = attrs.get_float("scalar")
16 changes: 13 additions & 3 deletions nnvm/python/nnvm/top/transform.py
@@ -4,11 +4,11 @@

import tvm
import topi
from .tensor import _fschedule_broadcast
from .tensor import _fschedule_broadcast, _fschedule_injective
from ..compiler import registry as reg
from ..compiler import OpPattern

# Need add reshape, transpose
# Need add reshape
@reg.register_compute("expand_dims")
def compute_expand_dims(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
@@ -19,6 +19,16 @@ def compute_expand_dims(attrs, inputs, out_info):
reg.register_schedule("expand_dims", _fschedule_broadcast)


@reg.register_compute("transpose")
def compute_transpose(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
axes = attrs.get_int_tuple("axes")
axes = tuple(axes) if axes else None
return topi.transpose(inputs[0], axes)
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)


def _flatten_index(indices, shape):
"""flatten the index to 1D"""
idx = 0
@@ -38,4 +48,4 @@ def compute_reshape(attrs, inputs, out_info):
x = inputs[0]
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_broadcast)
reg.register_schedule("reshape", _fschedule_injective)
21 changes: 21 additions & 0 deletions nnvm/tests/python/compiler/test_op_fusion.py
@@ -56,6 +56,27 @@ def test_conv_ewise_injective():
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


def test_injective_reduce_injective():
x = sym.Variable("x")
x = sym.flatten(x) + 1
y = sym.sum(x, axis=1)
dtype = "float32"
dshape = (32, 1, 18, 18)
shape_dict = {"x": dshape}

for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
assert graph.index.num_nodes == 2
data = np.random.uniform(size=dshape).astype(dtype)
m.run(x=data)
c_np = np.sum(data.reshape(32, 18 * 18) + 1, axis=1)
# get output
out = m.get_output(0, tvm.nd.empty(c_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


if __name__ == "__main__":
test_injective_reduce_injective()
test_ewise_injective()
test_conv_ewise_injective()
43 changes: 20 additions & 23 deletions nnvm/tests/python/compiler/test_top_level1.py
@@ -8,25 +8,21 @@

def test_relu():
x = sym.Variable("x")
y = sym.relu(x)
y = sym.leaky_relu(x, alpha=0.3) - 0.2
y = sym.relu(y)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = np.maximum(data.asnumpy(), 0.0)
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
data = np.random.uniform(size=dshape).astype(dtype)
m.run(x=data)
data = (data < 0) * data * 0.3 + (data > 0) * data - 0.2
data = (data > 0) * data
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
np.testing.assert_allclose(out.asnumpy(), data, atol=1e-5, rtol=1e-5)


def test_exp():
@@ -157,17 +153,18 @@ def test_dense():
"dense_weight" : (3, 100),
"dense_bias" : (3,),
}
graph, lib, _ = nnvm.compiler.build(y, "llvm", shape)
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
x_np = np.random.uniform(size=shape["x"]).astype(dtype)
w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
res = tvm.nd.empty((10, 3))
m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
m.get_output(0, res)
res_np = np.dot(x_np, w_np.T) + b_np
np.testing.assert_allclose(
res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape)
m = nnvm.runtime.create(graph, lib, ctx)
x_np = np.random.uniform(size=shape["x"]).astype(dtype)
w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
res = tvm.nd.empty((10, 3))
m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
m.get_output(0, res)
res_np = np.dot(x_np, w_np.T) + b_np
np.testing.assert_allclose(
res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)


def test_batchnorm():
56 changes: 56 additions & 0 deletions nnvm/tests/python/compiler/test_top_level4.py
@@ -0,0 +1,56 @@
import numpy as np
import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import test_ctx_list

def verify_transpose(dshape, axes):
x = sym.Variable("x")
if axes:
y = sym.transpose(x, axes=axes)
else:
y = sym.transpose(x)
y = y + 1
dtype = "float32"
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
m.run(x=data)
out_np = np.transpose(data.asnumpy(), axes=axes) + 1
out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)


def verify_reduce(dshape, fnp, fsym, **kwargs):
x = sym.Variable("x")
y = fsym(x + 1, **kwargs)
dtype = "float32"
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# set input
data = np.random.uniform(size=dshape).astype(dtype)
out_np = fnp(data + 1, **kwargs)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)


def test_transpose():
verify_transpose((2, 3, 4), (0, 2, 1))
verify_transpose((2, 3, 4), None)


def test_reduce():
verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True)
verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))


if __name__ == "__main__":
test_reduce()
test_transpose()
