diff --git a/CMakeLists.txt b/CMakeLists.txt index cb9b2df2f284..494afbdff792 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -190,6 +190,7 @@ include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/NNPack.cmake) +include(cmake/modules/contrib/HybridDump.cmake) add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) diff --git a/cmake/modules/contrib/HybridDump.cmake b/cmake/modules/contrib/HybridDump.cmake new file mode 100644 index 000000000000..c8d6d6e07756 --- /dev/null +++ b/cmake/modules/contrib/HybridDump.cmake @@ -0,0 +1,3 @@ +message(STATUS "Build with contrib.hybriddump") +file(GLOB HYBRID_CONTRIB_SRC src/contrib/hybrid/*.cc) +list(APPEND COMPILER_SRCS ${HYBRID_CONTRIB_SRC}) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 7043281fcafb..9e80d5ba72ff 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -197,6 +197,20 @@ You can also do loop-thread bind by writing code like this: a[tx] = b[tx] +Assert Statement +~~~~~~~~~~~~~~~~ + +Assert statement is supported, you can simply use it as it is in standard Python. + +.. code-block:: python + + assert cond, mesg + +.. note:: + + ``Assert`` is NOT a function call. Users are encouraged to use assert in the way + presented above --- condition followed by message. It fits both Python AST and HalideIR. + Keywords ~~~~~~~~ - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr`` diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 2e270bc3b217..139f13364271 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -292,6 +292,25 @@ def get_binds(args, binds=None): return binds, arg_list +def form_body(sch): + """According to the given schedule, form the raw body + Parameters + ---------- + sch : tvm.schedule.Schedule + The given scheduler to form the raw body + + Returns + ------- + The body formed according to the given schedule + """ + # normalize schedule first + sch = sch.normalize() + bounds = schedule.InferBound(sch) + stmt = schedule.ScheduleOps(sch, bounds) + stmt = ir_pass.InjectPrefetch(stmt) + return stmt + + def lower(sch, args, name="default_function", @@ -337,11 +356,7 @@ def lower(sch, # Phase 0 if isinstance(sch, schedule.Schedule): - # normalize schedule first - sch = sch.normalize() - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - stmt = ir_pass.InjectPrefetch(stmt) + stmt = form_body(sch) for f in lower_phase0: stmt = f(stmt) diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py index 6c137490c38e..645ef992833f 100644 --- a/python/tvm/hybrid/__init__.py +++ b/python/tvm/hybrid/__init__.py @@ -4,8 +4,77 @@ 1. Users can write some preliminary versions of the computation patterns have not been supported yet and verify it across the real execution and python semantic emulation. -2. Developers can build HalideIR by writing Python code. +2. So far, it is a text format dedicated to HalideIR Phase 0. Refer tvm.lower +for more details. A larger ambition of this module is to support all levels of +HalideIR. """ -from .api import script -from .parser import parse_python +# TODO(@were): Make this module more complete. +# 1. Support HalideIR dumping to Hybrid Script +# 2. Support multi-level HalideIR + +from __future__ import absolute_import as _abs + +from .._ffi.base import decorate +from .._ffi.function import _init_api +from ..build_module import form_body + +from .module import HybridModule +from .parser import source_to_op +from .util import _pruned_source + + +def script(pyfunc): + """Decorate a python function function as hybrid script. + + The hybrid function support emulation mode and parsing to + the internal language IR. + + Returns + ------- + hybrid_func : function + A decorated hybrid script function. + """ + def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring + from .util import _is_tvm_arg_types + if _is_tvm_arg_types(args): + src = _pruned_source(func) + return source_to_op(src, func.__globals__, args) + + from .runtime import _enter_hybrid_runtime, _restore_runtime + intersect = _enter_hybrid_runtime(func) + value = func(*args, **kwargs) + _restore_runtime(func, intersect) + return value + + return decorate(pyfunc, wrapped_func) + + +def build(sch, inputs, outputs, name="hybrid_func"): + """Dump the corrent schedule to hybrid module + + Parameters + ---------- + sch: Schedule + The schedule to be dumped + + inputs: An array of Tensors or Vars + The inputs of the function body + + outputs: An array of Tensors + The outputs of the function body + + Returns + ------- + module: HybridModule + The built results is wrapped in a HybridModule. + The usage of HybridModule is roughly the same as normal TVM-built modules. + """ + + stmt = form_body(sch) + src = _Dump(stmt, inputs, outputs, name) + + return HybridModule(src, name) + + +_init_api("tvm.hybrid") diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py deleted file mode 100644 index d43217ca5dfc..000000000000 --- a/python/tvm/hybrid/api.py +++ /dev/null @@ -1,43 +0,0 @@ -"""APIs of lowering the Python subset to HalideIR""" -from __future__ import absolute_import as _abs - -from .._ffi.base import decorate -from .. import _api_internal as _tvm_internal -from ..tensor import Tensor - -from .parser import parse_python -from .util import _pruned_source - - -def script(pyfunc): - """Decorate a python function function as hybrid script. - - The hybrid function support emulation mode and parsing to - the internal language IR. - - Returns - ------- - hybrid_func : function - A decorated hybrid script function. - """ - def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring - from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types - if _is_tvm_arg_types(args): - src = _pruned_source(func) - parser = parse_python(src, func.__globals__, args) - - input_tensors = [] - for i in args: - if isinstance(i, Tensor): - input_tensors.append(i) - op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors, - parser.outputs, parser.parsed_body) - res = [op.output(i) for i in range(len(parser.outputs))] - return res[0] if len(res) == 1 else res - - intersect = _enter_hybrid_runtime(func) - value = func(*args, **kwargs) - _restore_runtime(func, intersect) - return value - - return decorate(pyfunc, wrapped_func) diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 3fd472c57afc..84ae537d49ab 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -8,6 +8,7 @@ from .. import ir_pass from ..stmt import For from .util import _internal_assert +from ..intrin import call_pure_intrin #pylint: disable=redefined-builtin @@ -104,3 +105,29 @@ def len(func_id, args): except: #pylint: disable=bare-except _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len") return _api.convert(args[0].shape[0]) + + +def _cast(func_id, args): + _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \ + "Only one expression can be cast") + return _make.Cast(func_id, args[0]) + +float16 = float32 = float64 = _cast #pylint: disable=invalid-name +int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name +uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name + + +def ceil_div(func_id, args): + _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!") + _internal_assert(args.__len__() == 2, "2 arguments expected for division!") + _internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div") + _internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div") + a, b = args[0], args[1] + return (a + b - 1) / b + + +def likely(func_id, args): + _internal_assert(args.__len__() == 1, \ + "Only one expression can be likely") + _internal_assert(func_id == "likely", "This function cannot be directly invoked!") + return call_pure_intrin(args[0].dtype, 'likely', *args) diff --git a/python/tvm/hybrid/module.py b/python/tvm/hybrid/module.py new file mode 100644 index 000000000000..01557ba8b179 --- /dev/null +++ b/python/tvm/hybrid/module.py @@ -0,0 +1,100 @@ +"""Methods and data structures to support dumping HalideIR to Hybrid Script. +This allows users to do quick hack to generated HalideIR and cast it back to +TVM modules. + +To enable this feature, you need to build with -DUSE_HYBRID_DUMP=ON. +""" + +import ast +import imp + +from ..contrib import util +from .util import _internal_assert +from .util import _is_tvm_arg_types +from .parser import source_to_op + + +class HybridModule(object): + """The usage of Hybrid Module is very similar to conventional TVM module, + but conventional TVM module requires a function body which is already fully + lowered. This contradicts to the fact that Hybrid Module is originally a text + format for Phase 0 HalideIR. Thus, a totally separated module is defined.""" + + + def __init__(self, src=None, name=None): + """The constructor of this a hybrid module + + Parameters + ---------- + src : str + The source code of this module + + name : str + The name of this module + """ + self.src_ = self.name = self.func_ = self.root_ = None + if src is not None: + temp = util.tempdir() + dst = temp.relpath("script.py") + with open(dst, 'w') as f: + f.write("import tvm\n@tvm.hybrid.script\n%s" % src) + + if name is not None: + self.name = name + self.load(dst) + + + def __call__(self, *args): + if _is_tvm_arg_types(args): + return source_to_op(self.root_, globals(), args) + return self.func_(*args) + + + def get_source(self): + return self.src_ + + + def save(self, path): + if not path.endswith('.py'): + path = path + '.py' + with open(path, 'w') as f: + f.write(self.src_) + + + def load(self, path): + """Load the module from a python file + + Parameters + ---------- + path : str + Path to the given python file + """ + with open(path, 'r') as f: + self.src_ = f.read() + + src = self.src_ + + class FindFunc(ast.NodeVisitor): + """ Find the function in module to be loaded module. """ + #pylint: disable=invalid-name + def __init__(self): + self.name = None + self.root = None + + + def visit_FunctionDef(self, node): + _internal_assert(self.name is None, "For now, only one function supported!") + self.name = node.name + _internal_assert(self.root is None, "For now, only one function supported!") + self.root = node + + root = ast.parse(src) + finder = FindFunc() + finder.visit(root) + _internal_assert(finder.name is not None and finder.root is not None, \ + "No function found!") + if self.name is None: + self.name = finder.name + self.root_ = finder.root + py_module = imp.load_source(self.name, path) + self.func_ = getattr(py_module, self.name) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 1e3fe3301191..b9d64866b305 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -17,28 +17,36 @@ from ..api import any as _any from ..container import Array from ..tensor import Tensor, Operation +from .. import _api_internal as _tvm_internal from .. import expr as _expr +from .. import stmt as _stmt from .. import make as _make from .. import api as _api from .. import ir_pass as _ir_pass -def pack_list_to_block(lst): - if len(lst) == 1: +def concat_list_to_block(lst): + """Concatenate a list of Python IR nodes to HalideIR Block""" + n = len(lst) + if n == 1: return lst[0] - body = lst[0] - for i in lst[1:]: - body = _make.Block(body, i) + body = lst[n - 1] + for i in range(1, n): + stmt = lst[n - 1 - i] + if isinstance(stmt, _stmt.AssertStmt): + body = _make.AssertStmt(stmt.condition, stmt.message, body) + else: + body = _make.Block(stmt, body) return body def visit_list_to_block(visit, lst): - """Convert a list of Python IR nodes to HalideIR Block""" + """Visit and concatenate a list of Python IR nodes to HalideIR Block""" lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())] if not lst: return util.make_nop() - return pack_list_to_block(lst) + return concat_list_to_block(lst) class Symbol(Enum): @@ -441,7 +449,7 @@ def visit_For(self, node): body = visit_list_to_block(self.visit, node.body) body = self.wrap_up_realize(node, body) bodies.append(body) - return pack_list_to_block(bodies) + return concat_list_to_block(bodies) elif iter_var is None: _internal_assert(for_type is not None, "The loop bind function parse error!") @@ -496,15 +504,22 @@ def visit_Str(self, node): return node.s + def visit_Assert(self, node): + test = self.visit(node.test) + mesg = _api.convert(self.visit(node.msg)) + return _make.AssertStmt(test, mesg, util.make_nop()) + + def parse_python(src, symbols, args): """The helper function of calling the AST visitor Parameters ---------- - src : str - The source code of the function to be parsed. + src : ast.node or str + If an ast.node, then directly lower it. + If a str, then parse it to ast and lower it. - src : str + symbols : str The symbol list of the global context of the function. args : list of Tensors or Vars @@ -517,9 +532,44 @@ def parse_python(src, symbols, args): root : Stmt The result Halide IR and the parser class instance. """ - root = ast.parse(src) + root = ast.parse(src) if isinstance(src, str) else src + _internal_assert(root, ast.AST) var_usage = determine_variable_usage(root, args, symbols) parser = HybridParser(args, var_usage, symbols) parser.parsed_body = parser.visit(root) _internal_assert(parser.returned, 'No valid return found in the function body!') return parser + + +def source_to_op(src, symbols, args): + """Another level of wrapper + + Parameters + ---------- + src : ast.node or str + If an ast.node, then directly lower it. + If a str, then parse it to ast and lower it. + + symbols : str + The symbol list of the global context of the function. + + args : list of Tensors or Vars + The argument lists to the function. + It is NOT encouraged to write a function without arguments. + It is NOT encouraged to write a function with side effect. + + Returns + ------- + res : list of output tensors + The result of output tensors of the formed OpNode. + """ + parser = parse_python(src, symbols, args) + + input_tensors = [] + for i in args: + if isinstance(i, Tensor): + input_tensors.append(i) + op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors, + parser.outputs, parser.parsed_body) + res = [op.output(i) for i in range(len(parser.outputs))] + return res[0] if len(res) == 1 else res diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/runtime.py similarity index 63% rename from python/tvm/hybrid/intrin.py rename to python/tvm/hybrid/runtime.py index cb6d0fdb74b8..293e069c24ea 100644 --- a/python/tvm/hybrid/intrin.py +++ b/python/tvm/hybrid/runtime.py @@ -73,7 +73,6 @@ def sigmoid(x): HYBRID_GLOBALS = { - 'len' : len, 'unroll' : range, 'vectorize' : range, 'parallel' : range, @@ -88,4 +87,37 @@ def sigmoid(x): 'exp' : numpy.exp, 'sigmoid' : sigmoid, 'popcount' : popcount, + 'likely' : lambda cond: cond, + 'uint8' : numpy.uint8, + 'uint16' : numpy.uint16, + 'uint32' : numpy.uint32, + 'uint64' : numpy.uint64, + 'int8' : numpy.int8, + 'int16' : numpy.int16, + 'int32' : numpy.int32, + 'int64' : numpy.int64, + 'float16' : numpy.float16, + 'float32' : numpy.float32, + 'float64' : numpy.float64, + 'ceil_div' : lambda a, b: (a + b - 1) / b } + + +def _enter_hybrid_runtime(func): + """Put hybrid runtime variables into the global scope""" + _globals = func.__globals__ + intersect = [] + for elem in list(HYBRID_GLOBALS.keys()): + if elem in _globals.keys(): + intersect.append((elem, _globals[elem])) + _globals[elem] = HYBRID_GLOBALS[elem] + return intersect + + +def _restore_runtime(func, intersect): + """Rollback the modification caused by hybrid runtime""" + _globals = func.__globals__ + for elem in list(HYBRID_GLOBALS.keys()): + _globals.pop(elem) + for k, v in intersect: + _globals[k] = v diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 44222d2d80f7..56190a82765e 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -5,14 +5,13 @@ import logging import sys import numpy -from .intrin import HYBRID_GLOBALS -from .._ffi.base import numeric_types from .. import api as _api from .. import make as _make from .. import expr as _expr from .. import stmt as _stmt -from ..container import Array +from .._ffi.base import numeric_types from ..tensor import Tensor +from ..container import Array #pylint: disable=invalid-name @@ -20,6 +19,7 @@ tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) + def _internal_assert(cond, err): """Simplify the code segment like if not XXX then raise an error""" if not cond: @@ -52,6 +52,23 @@ def _pruned_source(func): raise err +def replace_io(body, rmap): + """Replacing tensors usage according to the dict given""" + from .. import ir_pass + + def replace(op): + if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): + buf = rmap[op.func] + return _make.Provide(buf.op, op.value_index, op.value, op.args) + elif isinstance(op, _expr.Call) and op.func in rmap.keys(): + buf = rmap[op.func] + return _make.Call(buf.dtype, buf.name, op.args, \ + _expr.Call.Halide, buf.op, buf.value_index) + return None + + return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call']) + + def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. If neither is true, raise a value error.""" @@ -68,40 +85,3 @@ def _is_tvm_arg_types(args): _internal_assert(isinstance(elem, np_arg_types), \ "Expect a numpy type but %s get!" % str(type(elem))) return False - - -def _enter_hybrid_runtime(func): - """Put hybrid runtime variables into the global scope""" - _globals = func.__globals__ - intersect = [] - for elem in list(HYBRID_GLOBALS.keys()): - if elem in _globals.keys(): - intersect.append((elem, _globals[elem])) - _globals[elem] = HYBRID_GLOBALS[elem] - return intersect - - -def _restore_runtime(func, intersect): - """Rollback the modification caused by hybrid runtime""" - _globals = func.__globals__ - for elem in list(HYBRID_GLOBALS.keys()): - _globals.pop(elem) - for k, v in intersect: - _globals[k] = v - - -def replace_io(body, rmap): - """Replacing tensors usage according to the dict given""" - from .. import ir_pass - - def replace(op): - if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): - buf = rmap[op.func] - return _make.Provide(buf.op, op.value_index, op.value, op.args) - elif isinstance(op, _expr.Call) and op.func in rmap.keys(): - buf = rmap[op.func] - return _make.Call(buf.dtype, buf.name, op.args, \ - _expr.Call.Halide, buf.op, buf.value_index) - return None - - return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call']) diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index eb893a7f22a1..50b610567c74 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -2,7 +2,7 @@ import ast import sys -from .intrin import HYBRID_GLOBALS +from .runtime import HYBRID_GLOBALS from .util import _internal_assert @@ -45,7 +45,7 @@ def visit_Call(self, node): _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id") func_id = node.func.id _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \ - ['range', 'max', 'min'] + \ + ['range', 'max', 'min', 'len'] + \ list(self.symbols.keys()), \ "Function call id not in intrinsics' list") for elem in node.args: diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 32bb5f9d6617..f9190123a0a9 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -103,6 +103,8 @@ Target CreateTarget(const std::string& target_name, t->device_type = kDLCPU; } else if (target_name == "ext_dev") { t->device_type = kDLExtDev; + } else if (target_name == "hybrid") { + t->device_type = kDLCPU; } else { LOG(ERROR) << "Unknown target name " << target_name; return target::stackvm(); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc new file mode 100644 index 000000000000..2117d471eeee --- /dev/null +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -0,0 +1,491 @@ +/*! Copyright (c) 2019 by Contributors + * \file codegen_hybrid.cc + */ +#include +#include +#include "codegen_hybrid.h" + +namespace tvm { +namespace contrib { + +using namespace ir; + +std::string dot_to_underscore(std::string s) { + for (auto &ch : s) + if (ch == '.') ch = '_'; + return s; +} + +std::string CodeGenHybrid::GetUniqueName(std::string prefix) { + prefix = dot_to_underscore(prefix); + auto it = ids_allocated_.find(prefix); + if (it != ids_allocated_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (ids_allocated_.count(name) == 0) { + prefix = name; + break; + } + } + } + ids_allocated_[prefix] = 0; + return prefix; +} + +std::string CodeGenHybrid::Finish() { + return stream.str(); +} + +void CodeGenHybrid::PrintType(Type t, std::ostream &os) { + if (t.is_float()) { + os << "float"; + CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); + } else if (t.is_int()) { + os << "int"; + CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); + } else { + CHECK(t.is_uint()) << "Unsupported type " << t; + os << "uint"; + CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); + } + os << t.bits(); +} + +void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) + os << op->value; +} +void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) + PrintType(op->type, os); + os << "(" << op->value << ")"; +} +void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) + PrintType(op->type, os); + os << "(" << std::setprecision(20) << op->value << ")"; +} +void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) + os << "'" << op->value << "'"; +} + +template +inline void PrintBinaryExpr(const T* op, + const char *opstr, + std::ostream& os, // NOLINT(*) + CodeGenHybrid* p) { + CHECK(op->type.lanes() == 1) << "vec bin op not implemented"; + if (isalpha(opstr[0])) { + os << opstr << '('; + p->PrintExpr(op->a, os); + os << ", "; + p->PrintExpr(op->b, os); + os << ')'; + } else { + os << '('; + p->PrintExpr(op->a, os); + if (!strcmp(opstr, "&&")) opstr = "and"; + if (!strcmp(opstr, "||")) opstr = "or"; + os << ' ' << opstr << ' '; + p->PrintExpr(op->b, os); + os << ')'; + } +} + +inline void PrintBinaryIntrinsitc(const Call* op, + const char *opstr, + std::ostream& os, // NOLINT(*) + CodeGenHybrid* p) { + CHECK(op->type.lanes() == 1) << "vec bin intrin not implemented"; + CHECK_EQ(op->args.size(), 2U); + os << '('; + p->PrintExpr(op->args[0], os); + os << opstr; + p->PrintExpr(op->args[1], os); + os << ')'; +} + +void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) + if (op->type == op->value.type()) { + PrintExpr(op->value, stream); + } else { + PrintType(op->type, os); + os << "("; + PrintExpr(op->value, os); + os << ")"; + } +} + +void CodeGenHybrid::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) + os << GetVarID(op); +} +void CodeGenHybrid::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "+", os, this); +} +void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "-", os, this); +} +void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "*", os, this); +} +void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) + if (op->type.is_int()) + PrintBinaryExpr(op, "//", os, this); + else + PrintBinaryExpr(op, "/", os, this); +} +void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "%", os, this); +} +void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "min", os, this); +} +void CodeGenHybrid::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "max", os, this); +} +void CodeGenHybrid::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "==", os, this); +} +void CodeGenHybrid::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "!=", os, this); +} +void CodeGenHybrid::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "<", os, this); +} +void CodeGenHybrid::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "<=", os, this); +} +void CodeGenHybrid::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, ">", os, this); +} +void CodeGenHybrid::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, ">=", os, this); +} +void CodeGenHybrid::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "&&", os, this); +} +void CodeGenHybrid::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "||", os, this); +} +void CodeGenHybrid::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*) + os << "not "; + PrintExpr(op->a, os); +} + +void CodeGenHybrid::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) + if (op->call_type == Call::Halide) { + os << GetTensorID(op->func, op->value_index); + os << "["; + for (size_t i = 0; i < op->args.size(); ++i) { + if (i) os << ", "; + std::stringstream idx; + PrintExpr(op->args[i], idx); + os << idx.str(); + } + os << "]"; + } else if (op->is_intrinsic(Call::bitwise_and)) { + PrintBinaryIntrinsitc(op, "&", os, this); + } else if (op->is_intrinsic(Call::bitwise_xor)) { + PrintBinaryIntrinsitc(op, "^", os, this); + } else if (op->is_intrinsic(Call::bitwise_or)) { + PrintBinaryIntrinsitc(op, "|", os, this); + } else if (op->is_intrinsic(Call::shift_left)) { + PrintBinaryIntrinsitc(op, "<<", os, this); + } else if (op->is_intrinsic(Call::shift_right)) { + PrintBinaryIntrinsitc(op, ">>", os, this); + } else if (op->is_intrinsic(Call::bitwise_not)) { + CHECK_EQ(op->args.size(), 1U); + os << "(~"; + PrintExpr(op->args[0], os); + os << ')'; + } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + PrintExpr(op->args[1], os); + os << " if "; + PrintExpr(op->args[0], os); + os << " else "; + PrintExpr(op->args[2], os); + } else { + os << op->name << "("; + for (size_t i = 0; i < op->args.size(); i++) { + PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; + } +} + +void CodeGenHybrid::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Phase 0 has no Load(s)!"; +} + +void CodeGenHybrid::VisitStmt_(const Store* op) { + LOG(FATAL) << "Phase 0 has no Store(s)!"; +} + +void CodeGenHybrid::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Phase 0 has no Let(s)!"; +} + +void CodeGenHybrid::VisitStmt_(const Allocate* op) { + LOG(FATAL) << "Phase 0 has no Allocate(s)!"; +} + +void CodeGenHybrid::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Ramp to be supported yet"; +} + +void CodeGenHybrid::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Broadcast: not supported "; +} + +void CodeGenHybrid::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*) + PrintExpr(op->true_value, os); + os << " if "; + PrintExpr(op->condition, os); + os << " else "; + PrintExpr(op->false_value, os); + os << "\n"; +} + +void CodeGenHybrid::VisitStmt_(const LetStmt* op) { + std::string value = PrintExpr(op->value); + stream << GetVarID(op->var.get()) << " = " << value << ";\n"; + PrintStmt(op->body); +} + +void CodeGenHybrid::VisitStmt_(const AttrStmt* op) { + if (op->attr_key == ir::attr::thread_extent) { + auto iter_var = op->node.as(); + CHECK(iter_var); + binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); + PrintIndent(); + stream << "for " << binds_[iter_var->var.get()] << " in bind('" + << iter_var->var->name_hint << "', "; + PrintExpr(op->value, stream); + stream << "):\n"; + indent_ += tab_; + PrintStmt(op->body); + indent_ -= tab_; + } else if (op->attr_key == ir::attr::realize_scope) { + auto v = FunctionRef(op->node.node_); + alloc_storage_scope_[v] = op->value.as()->value; + PrintStmt(op->body); + } else { + // For now we ignore the unsupported AttrStmt + PrintStmt(op->body); + } +} + +void CodeGenHybrid::VisitStmt_(const Realize *op) { + CHECK(alloc_storage_scope_.count(op->func)); + if (!alloc_storage_scope_[op->func].empty()) { + PrintIndent(); + stream << GetTensorID(op->func, op->value_index) << " = allocate(("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + if (i) stream << ", "; + stream << PrintExpr(op->bounds[i]->extent); + } + if (op->bounds.size() == 1) stream << ", "; + stream << "), '"; + PrintType(op->type, stream); + stream << "', '"; + stream << alloc_storage_scope_[op->func] << "')\n"; + } + PrintStmt(op->body); +} + +void CodeGenHybrid::VisitStmt_(const AssertStmt* op) { + PrintIndent(); + stream << "assert "; + PrintExpr(op->condition, stream); + stream << ", "; + PrintExpr(op->message, stream); + stream << "\n"; + PrintStmt(op->body); +} + +void CodeGenHybrid::VisitStmt_(const Provide* op) { + PrintIndent(); + stream << GetTensorID(op->func, op->value_index); + stream << "["; + for (size_t i = 0; i < op->args.size(); ++i) { + if (i) stream << ", "; + PrintExpr(op->args[i], stream); + } + stream << "] = "; + PrintExpr(op->value, stream); + stream << "\n"; +} + +void CodeGenHybrid::VisitStmt_(const For* op) { + std::string extent = PrintExpr(op->extent); + PrintIndent(); + std::string vid = GetVarID(op->loop_var.get()); + stream << "for " << vid << " in " << "range(" << extent << "):\n"; + indent_ += tab_; + PrintStmt(op->body); + indent_ -= tab_; +} + +bool is_noop(const Stmt &stmt) { + if (!stmt.defined()) + return true; + if (auto eval = stmt.as()) + return is_const(eval->value); + return false; +} + +void CodeGenHybrid::VisitStmt_(const IfThenElse* op) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + stream << "if " << cond << ":\n"; + indent_ += tab_; + PrintStmt(op->then_case); + indent_ -= tab_; + + if (!is_noop(op->else_case)) { + PrintIndent(); + stream << "else:\n"; + indent_ += tab_; + PrintStmt(op->else_case); + indent_ -= tab_; + } +} + +void CodeGenHybrid::VisitStmt_(const Block *op) { + PrintStmt(op->first); + if (op->rest.defined()) PrintStmt(op->rest); +} + +void CodeGenHybrid::VisitStmt_(const Evaluate *op) { + if (is_const(op->value)) return; + std::string str = PrintExpr(op->value); + if (!str.empty()) + stream << str << "\n"; +} + +void CodeGenHybrid::VisitStmt_(const ProducerConsumer *op) { + PrintStmt(op->body); +} + +void CodeGenHybrid::PrintIndent() { + stream << std::string(indent_, ' '); +} + +std::string CodeGenHybrid::GetVarID(const Variable *v) { + if (binds_.count(v)) + return binds_[v]; + auto key = std::make_pair(v->GetNodePtr().get(), 0); + if (id_map_.count(key)) { + return id_map_[key]; + } + return id_map_[key] = GetUniqueName(v->name_hint); +} + +std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) { + auto key = std::make_pair(func.get(), value_index); + if (id_map_.count(key)) { + return id_map_[key]; + } + std::string name_hint = func->func_name(); + if (func->num_outputs() > 1) { + name_hint += "_v" + std::to_string(value_index); + } + return id_map_[key] = GetUniqueName(name_hint); +} + +void CodeGenHybrid::ReserveKeywords() { + GetUniqueName("def"); + GetUniqueName("for"); + GetUniqueName("in"); + GetUniqueName("range"); + GetUniqueName("unroll"); + GetUniqueName("const_range"); + GetUniqueName("parallel"); + GetUniqueName("vectorize"); + GetUniqueName("bind"); + GetUniqueName("threadIdx.x"); + GetUniqueName("threadIdx.y"); + GetUniqueName("threadIdx.z"); + GetUniqueName("blockIdx.x"); + GetUniqueName("blockIdx.y"); + GetUniqueName("blockIdx.z"); + GetUniqueName("vthread"); + GetUniqueName("allocate"); + GetUniqueName("output_tensor"); + GetUniqueName("sqrt"); + GetUniqueName("log"); + GetUniqueName("tanh"); + GetUniqueName("power"); + GetUniqueName("exp"); + GetUniqueName("sigmoid"); + GetUniqueName("popcount"); + GetUniqueName("likely"); + GetUniqueName("int8"); + GetUniqueName("int16"); + GetUniqueName("int32"); + GetUniqueName("int64"); + GetUniqueName("uint8"); + GetUniqueName("uint16"); + GetUniqueName("uint32"); + GetUniqueName("uint64"); + GetUniqueName("float16"); + GetUniqueName("float32"); + GetUniqueName("float64"); + GetUniqueName("ceil_div"); +} + +void CodeGenHybrid::DumpStmt(const Stmt &stmt, + const Array &inputs, + const Array &outputs, + const std::string &name) { + ReserveKeywords(); + GetUniqueName(name); + + stream << "def " << name << "("; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i) stream << ", "; + if (auto tensor = inputs[i].as()) { + stream << GetTensorID(tensor->op, tensor->value_index); + } else { + auto var = inputs[i].as(); + CHECK(var) << "Input should either be a tensor or a variable!"; + stream << GetVarID(var); + } + } + stream << "):\n"; + indent_ += tab_; + for (size_t i = 0; i < outputs.size(); ++i) { + PrintIndent(); + stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) + << " = output_tensor(("; + for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { + if (j) stream << ", "; + PrintExpr(outputs[i]->shape[j], stream); + } + if (outputs[i]->shape.size() == 1) + stream << ", "; + stream << "), '" << outputs[i]->dtype << "')\n"; + } + PrintStmt(stmt); + PrintIndent(); + stream << "return "; + for (size_t i = 0; i < outputs.size(); ++i) { + if (i) stream << ", "; + stream << GetTensorID(outputs[i]->op, outputs[i]->value_index); + } + stream << "\n"; +} + +TVM_REGISTER_GLOBAL("hybrid._Dump") +.set_body([](TVMArgs args, TVMRetValue* rv) { + CodeGenHybrid codegen; + if (args.size() == 4) + codegen.DumpStmt(args[0], args[1], args[2], args[3]); + else + codegen.DumpStmt(args[0], args[1], args[2]); + *rv = codegen.Finish(); + }); +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h new file mode 100644 index 000000000000..cdd6b85b9f9e --- /dev/null +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -0,0 +1,162 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file codegen_hybrid.h + * \brief Common utilities to generated C style code. + */ +#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ +#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace contrib { + +using namespace ir; +/*! + * \brief A base class to generate Hybrid Script. + * + * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3. + * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. + */ +class CodeGenHybrid : + public ExprFunctor, + public StmtFunctor { + public: + /*! + * \brief Dump the given function body to hybrid script. + * \param stmt The function body to be dumped to hybrid script. + * \param inputs Input tensors of this schedule. + * \param outputs Output tensors of this schedule. + * \param name The name of the function. + */ + void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, + const std::string &name = "hybrid_func"); + /*! + * \brief Finalize the compilation and return the code. + * \return The code. + */ + std::string Finish(); + /*! \brief Reserve keywords in avoid of name conflict. */ + void ReserveKeywords(); + /*! + * \brief Print the Stmt n to CodeGenHybrid->stream + * \param n The statement to be printed. + */ + void PrintStmt(const Stmt &n) { + this->VisitStmt(n); + } + /*! + * \brief Print the expression n(or its ssa id if in ssa mode) into os + * \param n The expression to be printed. + * \param os The output stream + */ + void PrintExpr(const Expr &n, std::ostream &os) { + this->VisitExpr(n, os); + } + /*! + * \brief Same as PrintExpr, but simply returns result string + * \param n The expression to be printed. + */ + std::string PrintExpr(const Expr &n) { + std::ostringstream os; + PrintExpr(n, os); + return os.str(); + } + // expression + void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) + // statment + void VisitStmt_(const LetStmt* op) override; + void VisitStmt_(const Store* op) override; + void VisitStmt_(const Provide* op) override; + void VisitStmt_(const For* op) override; + void VisitStmt_(const IfThenElse* op) override; + void VisitStmt_(const Allocate* op) override; + void VisitStmt_(const Realize* op) override; + void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const AssertStmt* op) override; + void VisitStmt_(const Evaluate* op) override; + void VisitStmt_(const Block* op) override; + void VisitStmt_(const ProducerConsumer* op) override; + /*! + * \brief Print Type represetnation of type t. + * \param t The type representation. + * \param os The stream to print the ctype into + */ + virtual void PrintType(Type t, std::ostream& os); // NOLINT(*) + + private: + /*! \brief The current indent of the code dump. */ + int indent_{0}; + /*! \brief The tab size of code indent. */ + const int tab_{4}; + /*! \brief Print the current indent spaces. */ + inline void PrintIndent(); + /*! \brief Keys are ids allocated, and values are the suffix to prevent double-name. */ + std::map ids_allocated_; + /*! + * \brief Keys are either (tensors, value_index) or (variables, 0). + * Values are the corresponding IDs.*/ + std::map, std::string> id_map_; + /*! \brief Variables (keys) binded to the threads (values). */ + std::map binds_; + /*! + * \brief Find an unallocated name for the given prefix. + * \param prefix The given prefix. + */ + std::string GetUniqueName(std::string prefix); + /*! \brief The output code string builder. */ + std::stringstream stream; + /*! + * \brief Get or allocate the ID for the given variable. + * \param v The given variable. + */ + std::string GetVarID(const Variable *v); + /*! + * \brief Get or allocate the ID for the given tensor. + * \param func The tensor to allocate a name. + * \param value_index The value index of the given tensor. + */ + std::string GetTensorID(const FunctionRef &func, int value_index); + /*! \brief the storage scope of allocation */ + std::map alloc_storage_scope_; +}; + +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 26daefa76d7f..0268498c7db2 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -173,25 +173,28 @@ Stmt HybridOpNode::BuildProvide( rmap[outputs[i]] = stage->op.output(i); } auto n = make_node(*this); - /* - * These two lines of codes replace tensors' reads & writes. + /* This is a story little bit complicated. + * The following two lines of codes replace output tensors' usage. * This is the simplest way I (@were) can come up with to glue - * hybrid scripts to the structure of TVM op. - * NAMING CONFLICT: In hybrid script all the tensors have their own - * names specified by the users. However, In TVM op, all the output - * tensors' names are the same as the op's name. I cannot change the - * name to the op's name in the function body after the op node is - * formed, because: - * 1. Output tensors all point to the corresponding op node. - * 2. Once OpNode is wrapped up by an Operation node, it can - * no longer be changed. + * hybrid operation node to TVM op system. + * In hybrid script all the tensors, especially the output tensors, + * have their own names defined by the users. However, In TVM + * conventional ops: + * 1. Output tensors refer the corresponding op node so that the output + * tensors have the same names as the operation produces them. + * 2. Once OpNode is wrapped up by an Operation node, it is finalized. + * Later access will be from a const OpNode*. * This is a chiken-egg paradox. It is impossible to put the output * tensors into the function body without forming the op node. The * function body is immutable after the node is formed. * * Finally, I decided to resolve this issue "lazily". During the - * pipeline of compilation, these tensors will be replaced when - * forming the function body and passing to next stage of compilation. + * pipeline of compilation, this stage is a very preliminary stage. + * Technically, it is before Phase 0. The actual tensors will be replaced + * here. + * Thus, the operation body is slightly different from the Phase 0 body. + * This is a major difference that HybridOpNode is NOT the same as + * ExternOpNode. * */ ret = op::ReplaceTensor(ret, rmap); ret = op::ReplaceProvideTensor(ret, rmap); diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index a54fec3a7bf7..405577b05b3b 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -1,6 +1,7 @@ -import tvm, inspect, sys, traceback, numpy, nose, types +import tvm, inspect, sys, traceback, numpy, nose, types, os +from tvm.contrib import util from tvm.hybrid import script -from tvm.hybrid.intrin import HYBRID_GLOBALS +from tvm.hybrid.runtime import HYBRID_GLOBALS @nose.tools.nottest def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): @@ -59,6 +60,11 @@ def tvm_val_2_py_val(val): for nd, np in zip(out_tensors, ref_data): tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) + module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + h_module = tvm.hybrid.build(sch, module_args, module_outs) + + return h_module, module_args, module_outs @script def outer_product(n, m, a, b): @@ -69,6 +75,7 @@ def outer_product(n, m, a, b): c = output_tensor((n, m), a.dtype) for i in range(n): for j in range(m): + assert i < n and j < m, "index out of range!" c[i, j] = a[i] * b[j] return c @@ -100,6 +107,10 @@ def test_outer_product(): assert ibody.extent.name == 'm' #Check loop body jbody = ibody.body + assert isinstance(jbody, tvm.stmt.AssertStmt) + assert isinstance(jbody.message, tvm.expr.StringImm) + assert jbody.message.value == "index out of range!" + jbody = jbody.body assert isinstance(jbody, tvm.stmt.Provide) assert jbody.func.name == 'c' assert len(jbody.args) == 2 @@ -111,8 +122,13 @@ def test_outer_product(): assert mul.a.name == 'a' assert mul.b.name == 'b' - - run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101}) + func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101}) + temp = util.tempdir() + path = temp.relpath('%s.py' % func.name) + func.save(path) + func_ = tvm.hybrid.HybridModule() + func_.load(path) + run_and_check(func_, ins, {n: 99, m: 101}, outs=outs) for key, _ in HYBRID_GLOBALS.items(): assert key not in globals().keys() @@ -197,7 +213,8 @@ def fanout(n, a): assert len(write.value.args) == 1 assert write.value.args[0].value == 0 - run_and_check(fanout, [n, a], {n: 10}) + func, ins, outs = run_and_check(fanout, [n, a], {n: 10}) + run_and_check(func, ins, {n: 10}, outs=outs) def test_looptype(): @@ -229,7 +246,8 @@ def looptype(a, b, c): assert jloop.for_type == tvm.stmt.For.Vectorized assert kloop.for_type == tvm.stmt.For.Unrolled - run_and_check(looptype, [a, b, c]) + func, ins, outs = run_and_check(looptype, [a, b, c]) + run_and_check(func, ins, outs=outs) def test_if(): @@ -248,7 +266,8 @@ def if_then_else(a): a = tvm.placeholder((10, ), dtype='int32', name='a') - run_and_check(if_then_else, [a]) + func, ins, outs = run_and_check(if_then_else, [a]) + run_and_check(func, ins, outs=outs) @script def if_triple_condition(a): @@ -260,7 +279,8 @@ def if_triple_condition(a): b[i] = a[i] + 1 return b - run_and_check(if_triple_condition, [a]) + func, ins, outs = run_and_check(if_triple_condition, [a]) + run_and_check(func, ins, outs=outs) @script def if_and(a): @@ -272,7 +292,8 @@ def if_and(a): b[i] = a[i] + 1 return b - run_and_check(if_and, [a]) + func, ins, outs = run_and_check(if_and, [a]) + run_and_check(func, ins, outs=outs) def test_bind(): @@ -288,7 +309,8 @@ def vec_add(a, b): a = tvm.placeholder((1000, ), dtype='float32', name='a') b = tvm.placeholder((1000, ), dtype='float32', name='b') - run_and_check(vec_add, [a, b], target='cuda') + func, ins, outs = run_and_check(vec_add, [a, b], target='cuda') + run_and_check(func, ins, outs=outs, target='cuda') @script def raw(a, b): @@ -301,7 +323,8 @@ def raw(a, b): sch = tvm.create_schedule(c.op) x = tvm.thread_axis('threadIdx.x') sch[c].bind(c.op.axis[0], x) - run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda') + func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda') + run_and_check(func, ins, outs=outs, target='cuda') # Test loop binds @tvm.hybrid.script @@ -318,7 +341,8 @@ def goo(a, b): b = [1, 2, 3, 4, 5] c = goo(a, tvm.convert(b)) sch = tvm.create_schedule(c.op) - run_and_check(goo, [a, b], sch=sch, outs=[c]) + func, ins, outs = run_and_check(goo, [a, b], sch=sch, outs=[c]) + run_and_check(func, ins, outs=outs) def test_math_intrin(): @script @@ -379,7 +403,8 @@ def blur(a): return b a = tvm.placeholder((32, 32), 'float32', 'a') - run_and_check(blur, [a]) + func, ins, outs = run_and_check(blur, [a]) + run_and_check(func, ins, outs=outs) @tvm.hybrid.script def triangle(a, b): @@ -392,7 +417,8 @@ def triangle(a, b): a = tvm.placeholder((10, ), dtype='float32', name='a') b = tvm.placeholder((10, ), dtype='float32', name='b') - run_and_check(triangle, [a, b]) + func, ins, outs = run_and_check(triangle, [a, b]) + run_and_check(func, ins, outs=outs) def test_allocate(): @tvm.hybrid.script @@ -408,7 +434,10 @@ def blur2d(a): return b a = tvm.placeholder((32, 32), 'float32', 'a') - run_and_check(blur2d, [a]) + b = blur2d(a) + sch = tvm.create_schedule(b.op) + func, ins, outs = run_and_check(blur2d, [a]) + run_and_check(func, ins, outs=outs) if tvm.gpu().exist: @tvm.hybrid.script @@ -426,7 +455,8 @@ def share_vec_add(a, b): a = tvm.placeholder((256, ), dtype='float32', name='a') b = tvm.placeholder((256, ), dtype='float32', name='b') - run_and_check(share_vec_add, [a, b], target='cuda') + func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda') + run_and_check(func, ins, outs=outs, target='cuda') else: print('[Warning] No GPU found! Skip shared mem test!') @@ -562,7 +592,8 @@ def foo(a, b): a = tvm.placeholder((10, ), name='a') b = tvm.placeholder((10, ), name='b') - run_and_check(foo, [a, b]) + func, ins, outs = run_and_check(foo, [a, b]) + run_and_check(func, ins, outs=outs) def test_bool(): @tvm.hybrid.script @@ -576,27 +607,29 @@ def foo(a): b[i] = 0.0 return b a = tvm.placeholder((10, ), name='a') - run_and_check(foo, [a]) + func, ins, outs = run_and_check(foo, [a]) + run_and_check(func, ins, outs=outs) def test_const_range(): @tvm.hybrid.script def foo(a, b): c = output_tensor(a.shape, a.dtype) - d = output_tensor(a.shape, a.dtype) + d = output_tensor(a.shape, 'int32') for i in const_range(2): for j in const_range(5): - c[i, j] = a[i, j] + b[i, j] + c[i, j] = float32(int32(a[i, j]) + b[i, j]) for i in const_range(len(b)): for j in const_range(len(b[0])): - d[i, j] = a[i, j] + b[i, j] + d[i, j] = int32(a[i, j] + b[i, j]) return c, d - a = tvm.placeholder((2, 5), name='a', dtype='int32') + a = tvm.placeholder((2, 5), name='a', dtype='float32') b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]] - run_and_check(foo, [a, b]) + func, ins, outs = run_and_check(foo, [a, b]) + run_and_check(func, ins, outs=outs) @tvm.hybrid.script def goo(a, b): @@ -612,7 +645,8 @@ def goo(a, b): b = [1, 2, 3, 4, 5] c = goo(a, tvm.convert(b)) sch = tvm.create_schedule(c.op) - run_and_check(goo, [a, b]) + func, ins, outs = run_and_check(goo, [a, b]) + run_and_check(func, ins, outs=outs) @tvm.hybrid.script def hoo(a, b): @@ -626,7 +660,8 @@ def hoo(a, b): return c a = tvm.placeholder((5, ), name='a', dtype='int32') b = [1, 2, 3, 4, 5] - run_and_check(hoo, [a, b]) + func, ins, outs = run_and_check(hoo, [a, b]) + run_and_check(func, ins, outs=outs) def test_schedule(): @script @@ -668,7 +703,8 @@ def outer_product(a, b): assert isinstance(ir, tvm.stmt.For) assert ir.loop_var.name == 'j.outer.inner' ir = ir.body - run_and_check(outer_product, [a, b], sch=sch, outs=[c]) + func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c]) + run_and_check(func, ins, outs=outs) # Test fuse sch = tvm.create_schedule(c.op) @@ -680,13 +716,15 @@ def outer_product(a, b): ir = ir.body assert isinstance(ir, tvm.stmt.For) assert ir.loop_var.name == 'i.j.fused' - run_and_check(outer_product, [a, b], sch=sch, outs=[c]) + func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c]) + run_and_check(func, ins, outs=outs) # Test imperfect loop split sch = tvm.create_schedule(c.op) sch[c].split(c.op.axis[0], 3) ir = tvm.lower(sch, [a, b, c], simple_mode=True) - run_and_check(outer_product, [a, b], sch=sch, outs=[c]) + func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c]) + run_and_check(func, ins, outs=outs) # Test loop binds