From 02a60d02d5d6a6d505554efa9871369dcb88a212 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Thu, 21 Sep 2017 21:52:08 -0700
Subject: [PATCH] [TOP] Add dense, batchnorm (#22)

* [TOP] Add dense, batchnorm

* update tvm
---
 nnvm/include/nnvm/compiler/op_attr_types.h    |   7 +-
 nnvm/python/nnvm/compiler/build_module.py     |  13 +-
 nnvm/python/nnvm/compiler/registry.py         |   4 +-
 nnvm/python/nnvm/top/__init__.py              |   1 +
 nnvm/python/nnvm/top/nn.py                    |  49 +++-
 nnvm/python/nnvm/top/tensor.py                | 121 +++++++--
 nnvm/python/nnvm/top/transform.py             |  31 +++
 nnvm/src/compiler/graph_fuse.cc               |  17 +-
 nnvm/src/compiler/packed_func_ext.cc          |   6 +-
 nnvm/src/compiler/simplify_batch_norm.cc      |  19 +-
 nnvm/src/pass/print_graph_ir.cc               |   2 +-
 .../compiler/test_simplify_batchnorm.py       |   9 +-
 nnvm/tests/python/compiler/test_top_level1.py | 249 +++++++++++-------
 nnvm/tests/python/compiler/test_top_level2.py |  86 +++---
 14 files changed, 401 insertions(+), 213 deletions(-)
 create mode 100644 nnvm/python/nnvm/top/transform.py

diff --git a/nnvm/include/nnvm/compiler/op_attr_types.h b/nnvm/include/nnvm/compiler/op_attr_types.h
index c77720da1669..8381733c33a1 100644
--- a/nnvm/include/nnvm/compiler/op_attr_types.h
+++ b/nnvm/include/nnvm/compiler/op_attr_types.h
@@ -44,11 +44,14 @@ using TOpPattern = int;
  * \brief Computation description interface
  * \param attrs The attribute of the node.
  * \param inputs The input tensors(placeholders)
+ * \param out_info Tensors holding shape/type information about output,
+ *                 these are always placeholders.
  * \return The output description of the tensor.
  */
 using FTVMCompute = std::function<
-  Array<Tensor>
-  (const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
+  Array<Tensor>(const NodeAttrs& attrs,
+                const Array<Tensor>& inputs,
+                const Array<Tensor>& out_info)>;
 
 /*!
  * \brief Build the computation schedule for
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index 891a0b65729d..04b1f0a2e96f 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"):
     """
     # pylint: disable=unused-argument
     cfg = BuildConfig.current
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    if graph.json_attr("shape_num_unknown_nodes"):
+        raise ValueError("InferShape fails..")
     if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
-        graph = graph_attr.set_shape_inputs(graph, shape)
-        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
+        graph = graph.apply("SimplifyBatchNormInference")
     return graph
 
 
@@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None):
     cfg = BuildConfig.current
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
     shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Initial pass do shape type inference
+    ishape, _ = graph_util.infer_shape(graph, **shape)
+    shape.update(zip(graph.index.input_names, ishape))
+    if not isinstance(dtype, str):
+        idtype, _ = graph_util.infer_dtype(graph, **dtype)
+        dtype.update(zip(graph.index.input_names, idtype))
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
diff --git a/nnvm/python/nnvm/compiler/registry.py b/nnvm/python/nnvm/compiler/registry.py
index d54b22d12681..c8094b1d345f 100644
--- a/nnvm/python/nnvm/compiler/registry.py
+++ b/nnvm/python/nnvm/compiler/registry.py
@@ -5,8 +5,10 @@ class OpPattern(object):
     ELEM_WISE = 0
     BROADCAST = 1
+    # Complex means we can fuse elemwise to it
     COMPLEX = 2
-    EXTERN = 2
+    # 
Extern means the op is not fusable + EXTERN = 3 _register_compute = tvm.get_global_func("nnvm._register_compute") _register_schedule = tvm.get_global_func("nnvm._register_schedule") diff --git a/nnvm/python/nnvm/top/__init__.py b/nnvm/python/nnvm/top/__init__.py index 21bf9ba33f43..11e776f1b0dd 100644 --- a/nnvm/python/nnvm/top/__init__.py +++ b/nnvm/python/nnvm/top/__init__.py @@ -2,3 +2,4 @@ from .attr_dict import AttrDict from . import tensor from . import nn +from . import transform diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index b4c0c44f0edc..5b0dfe2fe145 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -1,30 +1,37 @@ +# pylint: disable=invalid-name, unused-argument """Definition of nn ops""" from __future__ import absolute_import import tvm import topi from topi.util import get_const_int -from .tensor import schedule_elemwise +from .tensor import _fschedule_broadcast from ..compiler import registry as reg from ..compiler import OpPattern # relu @reg.register_compute("relu") -def compute_relu(_, inputs): +def compute_relu(attrs, inputs, _): """Compute definition of relu""" return topi.nn.relu(inputs[0]) -@reg.register_schedule("relu") -def schedule_relu(_, outs, target): - """Schedule definition of relu""" - return schedule_elemwise(_, outs, target) - +reg.register_schedule("relu", _fschedule_broadcast) reg.register_pattern("relu", OpPattern.ELEM_WISE) +# flatten +@reg.register_compute("flatten") +def compute_flatten(attrs, inputs, _): + """Compute definition of flatten""" + return topi.nn.flatten(inputs[0]) + +reg.register_schedule("flatten", _fschedule_broadcast) +reg.register_pattern("flatten", OpPattern.COMPLEX) + + # softmax @reg.register_compute("softmax") -def compute_softmax(attrs, inputs): +def compute_softmax(attrs, inputs, _): """Compute definition of softmax""" axis = attrs.get_int("axis") assert axis == -1, "only support axis == -1 for now" @@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target): # naive schedule return tvm.create_schedule([x.op for x in outs]) -reg.register_pattern("softmax", OpPattern.COMPLEX) +# Mark softmax as extern as we do not fuse it in call cases +reg.register_pattern("softmax", OpPattern.EXTERN) + + +# dense +@reg.register_compute("dense") +def compute_dense(attrs, inputs, _): + """Compute definition of dense""" + if attrs.get_bool("use_bias"): + return topi.nn.fully_connected_with_bias( + inputs[0], inputs[1], inputs[2]) + return topi.nn.fully_connected(inputs[0], inputs[1]) + +@reg.register_schedule("dense") +def schedule_dense(_, outs, target): + """Schedule definition of dense""" + if target == "cuda": + raise ValueError("fully_connected not yet implemented") + # naive schedule + return tvm.create_schedule([x.op for x in outs]) + +# register extern for now, change me when fusion is enabled. 
+reg.register_pattern("dense", OpPattern.EXTERN) # conv @reg.register_compute("conv2d") -def compute_conv2d(attrs, inputs): +def compute_conv2d(attrs, inputs, _): """Compute definition of conv2d""" padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py index c49aeae0d19f..0259e99b6aee 100644 --- a/nnvm/python/nnvm/top/tensor.py +++ b/nnvm/python/nnvm/top/tensor.py @@ -1,4 +1,4 @@ -# pylint: disable=invalid-name +# pylint: disable=invalid-name, unused-argument """Tensor ops""" from __future__ import absolute_import @@ -8,15 +8,6 @@ from ..compiler import registry as reg from ..compiler import OpPattern -def schedule_elemwise(_, outs, target): - """Generic schedule for elemwise operation""" - if target == "cuda": - return topi.cuda.schedule_elemwise(outs) - assert target.startswith("llvm") - s = tvm.create_schedule([x.op for x in outs]) - tvm.schedule.AutoInlineInjective(s) - return s - def _schedule_broadcast(_, outs, target): """Generic schedule for binary bcast""" if target == "cuda": @@ -29,7 +20,7 @@ def _schedule_broadcast(_, outs, target): def _compute_binary_scalar(f): """auxiliary function""" @tvm.tag_scope("ewise") - def _compute(attrs, x): + def _compute(attrs, x, _): x = x[0] scalar = attrs.get_float("scalar") scalar = tvm.const(scalar, x.dtype) @@ -37,58 +28,132 @@ def _compute(attrs, x): return _compute +def _compute_unary(f): + """auxiliary function""" + def _compute(attrs, x, _): + return f(x[0]) + return _compute + + +def _compute_binary(f): + """auxiliary function""" + def _compute(attrs, x, _): + return f(x[0], x[1]) + return _compute + + _fschedule_broadcast = tvm.convert(_schedule_broadcast) # exp -reg.register_compute("exp", - lambda _, x: topi.exp(x[0])) +reg.register_compute("exp", _compute_unary(topi.exp)) reg.register_pattern("exp", OpPattern.ELEM_WISE) reg.register_schedule("exp", _fschedule_broadcast) +# sqrt +reg.register_compute("sqrt", _compute_unary(topi.sqrt)) +reg.register_pattern("sqrt", OpPattern.ELEM_WISE) +reg.register_schedule("sqrt", _fschedule_broadcast) + # log -reg.register_compute("log", - lambda _, x: topi.log(x[0])) +reg.register_compute("log", _compute_unary(topi.log)) reg.register_pattern("log", OpPattern.ELEM_WISE) reg.register_schedule("log", _fschedule_broadcast) # tanh -reg.register_compute("tanh", - lambda _, x: topi.tanh(x[0])) +reg.register_compute("tanh", _compute_unary(topi.tanh)) reg.register_pattern("tanh", OpPattern.ELEM_WISE) reg.register_schedule("tanh", _fschedule_broadcast) +# negative +reg.register_compute("negative", _compute_unary(topi.negative)) +reg.register_pattern("negative", OpPattern.ELEM_WISE) +reg.register_schedule("negative", _fschedule_broadcast) + # sigmoid -reg.register_compute("sigmoid", - lambda _, x: topi.sigmoid(x[0])) +reg.register_compute("sigmoid", _compute_unary(topi.sigmoid)) reg.register_pattern("sigmoid", OpPattern.ELEM_WISE) reg.register_schedule("sigmoid", _fschedule_broadcast) -# add scalar +# add_scalar reg.register_compute("__add_scalar__", _compute_binary_scalar(lambda x, y: x + y)) reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE) reg.register_schedule("__add_scalar__", _fschedule_broadcast) +# sub_calar +reg.register_compute("__sub_scalar__", + _compute_binary_scalar(lambda x, y: x - y)) +reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE) +reg.register_schedule("__sub_scalar__", _fschedule_broadcast) + +# rsub_scalar +reg.register_compute("__rsub_scalar__", + 
_compute_binary_scalar(lambda x, y: y - x))
+reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
+
+# mul_scalar
+reg.register_compute("__mul_scalar__",
+                     _compute_binary_scalar(lambda x, y: x * y))
+reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
+
+# div_scalar
+reg.register_compute("__div_scalar__",
+                     _compute_binary_scalar(lambda x, y: x / y))
+reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__div_scalar__", _fschedule_broadcast)
+
+# rdiv_scalar
+reg.register_compute("__rdiv_scalar__",
+                     _compute_binary_scalar(lambda x, y: y / x))
+reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
+
+# elemwise_add
+reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
+reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_add", _fschedule_broadcast)
+
+# elemwise_sub
+reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
+reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_sub", _fschedule_broadcast)
+
+# elemwise_mul
+reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
+reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_mul", _fschedule_broadcast)
+
+# elemwise_div
+reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
+reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_div", _fschedule_broadcast)
+
 # broadcast_add
-reg.register_compute("broadcast_add",
-                     lambda _, x: topi.broadcast_add(x[0], x[1]))
+reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_add", _fschedule_broadcast)
 
 # broadcast_sub
-reg.register_compute("broadcast_sub",
-                     lambda _, x: topi.broadcast_sub(x[0], x[1]))
+reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_sub", _fschedule_broadcast)
 
 # broadcast_mul
-reg.register_compute("broadcast_mul",
-                     lambda _, x: topi.broadcast_mul(x[0], x[1]))
+reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_mul", _fschedule_broadcast)
 
 # broadcast_div
-reg.register_compute("broadcast_div",
-                     lambda _, x: topi.broadcast_div(x[0], x[1]))
+reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_div", _fschedule_broadcast)
+
+# broadcast_to
+@reg.register_compute("broadcast_to")
+def compute_broadcast_to(attrs, inputs, out_info):
+    """Compute definition of broadcast_to"""
+    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
+reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_to", _fschedule_broadcast)
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
new file mode 100644
index 000000000000..89e3c64a05ce
--- /dev/null
+++ b/nnvm/python/nnvm/top/transform.py
@@ -0,0 +1,31 @@
+# pylint: disable=invalid-name, unused-argument
+"""Tensor transformation ops"""
+from __future__ import absolute_import
+
+import tvm
+from .tensor import _fschedule_broadcast
+from ..compiler import registry as reg
+from ..compiler import OpPattern
+
+# Need add reshape, transpose
+
+def _flatten_index(indices, shape):
+    """flatten the index to 1D"""
+    idx = 0
+    for i, value in enumerate(shape):
+        if i != 0:
+            idx *= value
+        idx = idx + indices[i]
+    return idx
+
+# reshape
+@reg.register_compute("reshape")
+def compute_reshape(attrs, inputs, out_info):
+    """Compute definition of reshape"""
+    # TODO(sxj) add support for general reshape
+    assert len(inputs[0].shape) == 1, "Only support 1d input for now"
+    oshape = out_info[0].shape
+    x = inputs[0]
+    return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
+reg.register_pattern("reshape", OpPattern.COMPLEX)
+reg.register_schedule("reshape", _fschedule_broadcast)
diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index acf0f5677187..1daa5fd11394 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
     if (inode.source->is_variable()) continue;
     int root_id = group_vec[nid];
     FuseEntry& fe = fuse_vec[root_id];
-    Array<Tensor> inputs;
+    Array<Tensor> inputs, out_info;
     // input loading
     for (const auto& e : inode.inputs) {
       if (group_vec[e.node_id] != root_id) {
@@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
         inputs.push_back(t);
       }
     }
+    // output hint
+    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
+      Array<Expr> shape;
+      for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
+        CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
+        shape.push_back(make_const(Int(32), x));
+      }
+      out_info.push_back(
+          placeholder(shape,
+                      TVMType2Type(dltype_vec[idx.entry_id(nid, i)])));
+    }
     // get default
     Array<Tensor> out = fcompute[inode.source->op()](
-        inode.source->attrs, inputs);
+        inode.source->attrs, inputs, out_info);
     CHECK_EQ(out.size(), inode.source->num_outputs());
-
     // schedule on root node, and use master's schedule
     if (nid != root_id) {
       for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
@@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
       }
     }
   }
+
   tvm::runtime::Module module = fbuild(funcs, target);
   // Final step: Remap the node, with given attribute
   const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc
index e67f74312605..fed31be4033e 100644
--- a/nnvm/src/compiler/packed_func_ext.cc
+++ b/nnvm/src/compiler/packed_func_ext.cc
@@ -67,9 +67,11 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
     // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
     PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
-    auto fcompute = [f](const NodeAttrs& attrs, const Array<Tensor>& inputs)
+    auto fcompute = [f](const NodeAttrs& attrs,
+                        const Array<Tensor>& inputs,
+                        const Array<Tensor>& out_info)
         -> Array<Tensor> {
-      TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs);
+      TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info);
       if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
         return {ret.operator Tensor()};
       } else {
diff --git a/nnvm/src/compiler/simplify_batch_norm.cc b/nnvm/src/compiler/simplify_batch_norm.cc
index 16d7557f29a5..fdee4dca2352 100644
--- a/nnvm/src/compiler/simplify_batch_norm.cc
+++ b/nnvm/src/compiler/simplify_batch_norm.cc
@@ -21,7 +21,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
                        nnvm::NodeEntry beta,
                        nnvm::NodeEntry moving_mean,
                        nnvm::NodeEntry moving_var,
-                       int 
data_dim) { + TShape dshape) { CHECK(attrs.op); static const Op* bn_op = Op::Get("batch_norm"); CHECK(attrs.op == bn_op); @@ -57,19 +57,12 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, shift = MakeNode( "elemwise_add", bn_name + "_add_beta", {shift, beta}); } - // reshape to nhwc + // use broaodcast to reshape std::ostringstream oshape; - oshape << "("; - for (int i = 0; i < data_dim; ++i) { - if (i != 0) oshape << ", "; - if (i == param.axis) { - oshape << "-1"; - } else { - oshape << "1"; - } + for (dim_t i = 0; i < dshape.ndim(); ++i) { + dshape[i] = (i != param.axis) ? 1 : -1; } - oshape << ")"; - + oshape << dshape; scale = MakeNode("reshape", bn_name + "_sc_reshape", {scale}, {{"shape", oshape.str()}}); shift = MakeNode("reshape", bn_name + "_sh_reshape", @@ -98,7 +91,7 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) { n->inputs[2], n->inputs[3], n->inputs[4], - shape_vec[idx.entry_id(nid, 0)].ndim()); + shape_vec[idx.entry_id(nid, 0)]); return true; } else { return false; diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc index 6a42aabce616..52298c0a77a5 100644 --- a/nnvm/src/pass/print_graph_ir.cc +++ b/nnvm/src/pass/print_graph_ir.cc @@ -73,7 +73,7 @@ void PrintGraphIR_(Graph src, AttrPrinter fp = GetVectorPrinter(src, key); auto fprint = [&idx, key, fp]( uint32_t nid, std::ostream& os) { // NOLINT(*) - os << key << "="; + os << ", " << key << "="; fp(idx.entry_id(nid, 0), os); }; trigger.push_back(fprint); diff --git a/nnvm/tests/python/compiler/test_simplify_batchnorm.py b/nnvm/tests/python/compiler/test_simplify_batchnorm.py index ec6dfb86ac47..307c1348fbba 100644 --- a/nnvm/tests/python/compiler/test_simplify_batchnorm.py +++ b/nnvm/tests/python/compiler/test_simplify_batchnorm.py @@ -5,13 +5,13 @@ def test_simplify_batchnorm(): def simple_bn(x, gamma, beta, moving_mean, moving_var, - axis=1, epsilon=1e-5, dim=2): + axis=1, epsilon=1e-5, shape=None): # expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma) shift = sym.elemwise_add( sym.elemwise_mul(sym.negative(moving_mean), scale), beta) + shape = [-1 if i == axis else 1 for i in range(len(shape))] # for 2D - shape = tuple(1 if i != axis else -1 for i in range(dim)) scale = sym.reshape(scale, shape=shape) shift = sym.reshape(shift, shape=shape) return x * scale + shift @@ -26,15 +26,14 @@ def check(dim, axis, nstep): moving_var = sym.Variable("moving_var") moving_mean = sym.Variable("moving_mean") y1, y2 = x, x - + ishape = {"x": tuple(10 for i in range(dim))} for i in range(nstep): y1 = sym.batch_norm( y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var, - epsilon=eps, axis=axis, dim=dim) + epsilon=eps, axis=axis, shape=ishape["x"]) g = nnvm.graph.create(y1) g2 = nnvm.graph.create(y2) - ishape = {"x": tuple(10 for i in range(dim))} graph_attr.set_shape_inputs(g, ishape) g1 = g.apply("InferShape").apply("SimplifyBatchNormInference") # Some prints for debug diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index 14108d18db0e..5822e58f995b 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -6,19 +6,10 @@ import nnvm.compiler import nnvm.runtime -USE_GPU=True +def ctx_list(): + res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))] + return [x for x in res if x[1].exist] -def default_target(): - if USE_GPU: - 
return 'cuda' - else: - return 'llvm' - -def default_ctx(): - if USE_GPU: - return tvm.gpu(0) - else: - return tvm.cpu(0) def test_relu(): x = sym.Variable("x") @@ -26,20 +17,21 @@ def test_relu(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape}) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.maximum(data.asnumpy(), 0.0) - np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + set_input("x", data) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + y_np = np.maximum(data.asnumpy(), 0.0) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) def test_exp(): @@ -48,20 +40,21 @@ def test_exp(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape}) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.exp(data.asnumpy()) - np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + set_input("x", data) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + y_np = np.exp(data.asnumpy()) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) def test_log(): @@ -70,21 +63,22 @@ def test_log(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - with nnvm.compiler.build_config(opt_level=1): - graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape}) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.log(data.asnumpy()) - np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + with nnvm.compiler.build_config(opt_level=1): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = 
tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + set_input("x", data) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + y_np = np.log(data.asnumpy()) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) def test_tanh(): @@ -93,21 +87,22 @@ def test_tanh(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - with nnvm.compiler.build_config(opt_level=1): - graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape}) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy()) - np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + with nnvm.compiler.build_config(opt_level=1): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + set_input("x", data) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy()) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) def test_sigmoid(): @@ -116,20 +111,21 @@ def test_sigmoid(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape}) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = 1.0 / (1.0 + np.exp(-data.asnumpy())) - np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + set_input("x", data) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + y_np = 1.0 / (1.0 + np.exp(-data.asnumpy())) + np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) def test_softmax(): @@ -138,24 +134,79 @@ def test_softmax(): dtype = "float32" dshape = (10, 1000) oshape = dshape - with nnvm.compiler.build_config(opt_level=1): - graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape}) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = topi.testing.softmax_python(data.asnumpy()) - np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) + for target, ctx in 
ctx_list():
+        with nnvm.compiler.build_config(opt_level=1):
+            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = topi.testing.softmax_python(data.asnumpy())
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+
+
+def test_dense():
+    x = sym.Variable("x")
+    y = sym.dense(x, units=3, name="dense")
+    y = sym.flatten(y)
+    dtype = "float32"
+    shape = {
+        "x" : (10, 100),
+        "dense_weight" : (3, 100),
+        "dense_bias" : (3,),
+    }
+    graph, lib, _ = nnvm.compiler.build(y, "llvm", shape)
+    m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
+    x_np = np.random.uniform(size=shape["x"]).astype(dtype)
+    w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
+    b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
+    res = tvm.nd.empty((10, 3))
+    m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
+    m.get_output(0, res)
+    res_np = np.dot(x_np, w_np.T) + b_np
+    np.testing.assert_allclose(
+        res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
+
+
+def test_batchnorm():
+    x = sym.Variable("x")
+    beta = sym.Variable("beta")
+    gamma = sym.Variable("gamma")
+    moving_var = sym.Variable("moving_var")
+    moving_mean = sym.Variable("moving_mean")
+    shape = (10, 20)
+    eps = 1e-5
+    dtype = "float32"
+    y = sym.batch_norm(
+        x, gamma, beta, moving_mean, moving_var, epsilon=eps)
+
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": shape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        x_np = np.random.uniform(size=shape).astype(dtype)
+        mean_np = np.random.uniform(size=shape[1]).astype(dtype)
+        var_np = np.random.uniform(size=shape[1]).astype(dtype)
+        gamma_np = np.random.uniform(size=shape[1]).astype(dtype)
+        beta_np = np.random.uniform(size=shape[1]).astype(dtype)
+        res = tvm.nd.empty(shape)
+        m.run(x=x_np, moving_mean=mean_np, moving_var=var_np,
+              gamma=gamma_np, beta=beta_np)
+        m.get_output(0, res)
+        res_np = (x_np - mean_np) / np.sqrt(var_np + eps) * gamma_np + beta_np
+        np.testing.assert_allclose(
+            res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
 
 
 if __name__ == "__main__":
+    test_batchnorm()
+    test_dense()
     test_relu()
     test_exp()
     test_log()
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index 32d84158b336..7e39bff4017f 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -6,19 +6,9 @@
 import nnvm.compiler
 import nnvm.runtime
 
-USE_GPU=True
-
-def default_target():
-    if USE_GPU:
-        return 'cuda'
-    else:
-        return 'llvm'
-
-def default_ctx():
-    if USE_GPU:
-        return tvm.gpu(0)
-    else:
-        return tvm.cpu(0)
+def ctx_list():
+    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
+    return [x for x in res if x[1].exist]
 
 def test_conv2d():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=10, kernel_size=(3,3),
                    name="y", use_bias=False, padding=(1,1))
     dtype = "float32"
     dshape = (1, 3, 18, 18)
     kshape = (10, 3, 3, 3)
     oshape = (1, 10, 18, 18)
     shape_dict = {"x": dshape}
-    graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    kernel = 
tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) - set_input("x", data) - set_input("y_weight", kernel) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - c_np = topi.testing.conv2d_nchw_python( - data.asnumpy(), kernel.asnumpy(), 1, 1) - np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) + set_input("x", data) + set_input("y_weight", kernel) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + c_np = topi.testing.conv2d_nchw_python( + data.asnumpy(), kernel.asnumpy(), 1, 1) + np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) def test_grouped_conv2d(): @@ -57,23 +48,24 @@ def test_grouped_conv2d(): kshape = (32, 1, 3, 3) oshape = (1, 32, 18, 18) shape_dict = {"x": dshape} - graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict) - m = nnvm.runtime.create(graph, lib, default_ctx()) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) - set_input("x", data) - set_input("y_weight", kernel) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - c_np = topi.testing.depthwise_conv2d_python_nchw( - data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME') - np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + # get member functions + set_input, run, get_output = m["set_input"], m["run"], m["get_output"] + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) + set_input("x", data) + set_input("y_weight", kernel) + # execute + run() + # get output + out = tvm.nd.empty(oshape, dtype) + get_output(0, out) + c_np = topi.testing.depthwise_conv2d_python_nchw( + data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME') + np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) if __name__ == "__main__":
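
Note on the FTVMCompute change in this patch: compute functions now receive a
third argument, out_info, a list of placeholder tensors carrying the inferred
shape and dtype of each output (graph_fuse.cc builds these placeholders from
the InferShape results before calling fcompute). The sketch below is a
minimal, illustrative registration against that convention and is not part of
the patch; the op name "identity_copy" and its naive schedule are assumptions
made only for this note, while register_compute / register_schedule /
register_pattern, OpPattern, and the (attrs, inputs, out_info) signature are
exactly the pieces the patch itself uses.

    # Illustrative sketch only; "identity_copy" is a hypothetical op name.
    import tvm
    from nnvm.compiler import registry as reg
    from nnvm.compiler import OpPattern

    @reg.register_compute("identity_copy")
    def compute_identity_copy(attrs, inputs, out_info):
        """Copy the single input; out_info[0].shape is the inferred output shape."""
        oshape = out_info[0].shape
        return tvm.compute(oshape, lambda *i: inputs[0](*i))

    @reg.register_schedule("identity_copy")
    def schedule_identity_copy(_, outs, target):
        """Naive schedule, mirroring the softmax/dense fallback in nn.py."""
        return tvm.create_schedule([x.op for x in outs])

    reg.register_pattern("identity_copy", OpPattern.ELEM_WISE)

Passing out_info is what lets ops whose result shape is not a simple function
of their inputs, such as broadcast_to and reshape above, build their compute
expression directly from the inferred output shape.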