Skip to content

Commit

Permalink
Merge branch 'master' into ansor-gpu-tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Sep 18, 2020
2 parents e602929 + de0c3a4 commit 2b3ef79
Show file tree
Hide file tree
Showing 11 changed files with 1,208 additions and 1,230 deletions.
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def init_utvm(args):
args : argparse.Namespace
parsed args from command-line invocation
"""
from tvm import micro
from tvm import micro # pylint: disable=import-outside-toplevel

if args.utvm_dev_config and args.utvm_dev_id:
raise RuntimeError("only one of --utvm-dev-config and --utvm-dev-id allowed")
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/hybrid/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@
"""FFI APIs for tvm.hybrid"""
import tvm._ffi


tvm._ffi._init_api("tir.hybrid", __name__)
tvm._ffi._init_api("hybrid", __name__)
108 changes: 70 additions & 38 deletions python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,114 +23,146 @@
from .registry import register_intrin


@register_intrin
@register_intrin()
def bool(imm):
return tvm.tir.const(imm.value, "bool")
return tvm.tir.const(imm, "bool")


@register_intrin
@register_intrin()
def int8(imm):
return tvm.tir.const(imm.value, "int8")
return tvm.tir.const(imm, "int8")


@register_intrin
@register_intrin()
def int16(imm):
return tvm.tir.const(imm.value, "int16")
return tvm.tir.const(imm, "int16")


@register_intrin
@register_intrin()
def int32(imm):
return tvm.tir.const(imm.value, "int32")
return tvm.tir.const(imm, "int32")


@register_intrin
@register_intrin()
def int64(imm):
return tvm.tir.const(imm.value, "int64")
return tvm.tir.const(imm, "int64")


@register_intrin
@register_intrin()
def uint8(imm):
return tvm.tir.const(imm.value, "uint8")
return tvm.tir.const(imm, "uint8")


@register_intrin
@register_intrin()
def uint16(imm):
return tvm.tir.const(imm.value, "uint16")
return tvm.tir.const(imm, "uint16")


@register_intrin
@register_intrin()
def uint32(imm):
return tvm.tir.const(imm.value, "uint32")
return tvm.tir.const(imm, "uint32")


@register_intrin
@register_intrin()
def uint64(imm):
return tvm.tir.const(imm.value, "uint64")
return tvm.tir.const(imm, "uint64")


@register_intrin
@register_intrin()
def float8(imm):
return tvm.tir.const(imm.value, "float8")
return tvm.tir.const(imm, "float8")


@register_intrin
@register_intrin()
def float16(imm):
return tvm.tir.const(imm.value, "float16")
return tvm.tir.const(imm, "float16")


@register_intrin
@register_intrin()
def float32(imm):
return tvm.tir.const(imm.value, "float32")
return tvm.tir.const(imm, "float32")


@register_intrin
@register_intrin()
def float64(imm):
return tvm.tir.const(imm.value, "float64")
return tvm.tir.const(imm, "float64")


@register_intrin
@register_intrin()
def floordiv(x, y):
return tvm.tir.floordiv(x, y)


@register_intrin
@register_intrin()
def floormod(x, y):
return tvm.tir.floormod(x, y)


@register_intrin
@register_intrin()
def load(dtype, var, index, predicate=True):
return tvm.tir.Load(dtype, var, index, predicate)


@register_intrin
def cast(dtype, value):
@register_intrin()
def cast(value, dtype):
return tvm.tir.Cast(dtype, value)


@register_intrin
@register_intrin()
def ramp(base, stride, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Ramp(base, stride, lanes)


@register_intrin
@register_intrin()
def broadcast(value, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Broadcast(value, lanes)


@register_intrin
@register_intrin()
def evaluate(value):
return tvm.tir.Evaluate(value)


@register_intrin
@register_intrin()
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)


@register_intrin
@register_intrin()
def iter_var(var, dom, iter_type, thread_tag):
iter_type = getattr(tvm.tir.IterVar, iter_type)
return tvm.tir.IterVar(dom, var, iter_type, thread_tag)


@register_intrin()
def max(a, b): # pylint: disable=redefined-builtin
return tvm.tir.Max(a, b)


def get_axis(begin, end, iter_type):
ana = tvm.arith.Analyzer()
extent = ana.simplify(end - begin)
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type])


@register_intrin()
def range(begin, end):
return get_axis(begin, end, "data_par")


@register_intrin()
def reduce_axis(begin, end):
return get_axis(begin, end, "reduce")


@register_intrin()
def scan_axis(begin, end):
return get_axis(begin, end, "scan")


@register_intrin()
def opaque_axis(begin, end):
return get_axis(begin, end, "opaque")
Loading

0 comments on commit 2b3ef79

Please sign in to comment.