Skip to content

Commit

Permalink
[Hybrid Script] Add max_num_threads (apache#2672)
Browse files Browse the repository at this point in the history
* i think it works for now?

* fix lint

* fix 2/3 compat

* fix py2 again

* fine, i gave up
  • Loading branch information
were authored and wweic committed Mar 9, 2019
1 parent f96b27b commit dfffc83
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 38 deletions.
14 changes: 13 additions & 1 deletion python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from ..container import Array
from .. import ir_pass
from ..stmt import For
Expand Down Expand Up @@ -123,11 +124,22 @@ def ceil_div(func_id, args):
_internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
_internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
a, b = args[0], args[1]
return (a + b - 1) / b
return (a + b - 1) // b


def likely(func_id, args):
    """Intrinsic handler for ``likely``: mark a condition as probably true.

    Dispatched by the hybrid parser via the function name; not meant to be
    called directly by user code (hence the ``func_id`` check).

    Parameters
    ----------
    func_id : str
        Name under which the intrinsic was invoked; must be ``"likely"``.
    args : list of Expr
        Exactly one expression, the condition to annotate.

    Returns
    -------
    Expr
        A ``likely`` pure-intrinsic call wrapping the condition.
    """
    # len(args) instead of args.__len__(): never call dunders directly.
    _internal_assert(len(args) == 1,
                     "Only one expression can be likely")
    _internal_assert(func_id == "likely", "This function cannot be directly invoked!")
    return call_pure_intrin(args[0].dtype, 'likely', *args)


def max_num_threads(func_id, args):
    """Intrinsic handler for ``max_num_threads``.

    With no argument, returns the current target's ``max_num_threads``.
    With one boolean immediate argument, that flag is forwarded to
    ``current_target`` (presumably as ``allow_none``). Dispatched by the
    hybrid parser via the function name; not meant to be called directly.

    Parameters
    ----------
    func_id : str
        Name under which the intrinsic was invoked; must be
        ``"max_num_threads"``.
    args : list of Expr
        Zero or one argument; if present it must be a boolean immediate
        (represented as ``UIntImm`` in TVM).

    Returns
    -------
    Expr
        The target's max thread count, converted to a TVM node.
    """
    _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
    # len(args) instead of args.__len__(): never call dunders directly.
    _internal_assert(len(args) <= 1, "At most one argument accepted!")
    if not args:
        res = _tgt.current_target().max_num_threads
    else:
        _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
        res = _tgt.current_target(args[0].value).max_num_threads
    return _api.convert(res)
27 changes: 17 additions & 10 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def visit_Expr(self, node):

def visit_Name(self, node):
name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(eval(name)) #pylint: disable=eval-used
ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
Expand Down Expand Up @@ -248,6 +250,10 @@ def visit_Num(self, node):
return _api.const(node.n, dtype)


def visit_NameConstant(self, node):
    """Convert a Python 3 name constant (``True``/``False``/``None``)
    to the corresponding TVM node."""
    constant = node.value
    return _api.convert(constant)


def visit_AugAssign(self, node):
buf = self.visit(node.target)
rhs = self.visit(node.value)
Expand Down Expand Up @@ -450,17 +456,18 @@ def visit_Call(self, node):

func_id = node.func.id
args = [self.visit(i) for i in node.args]
try:
# Intrinsics'
if hasattr(calls, func_id):
return getattr(calls, func_id)(func_id, args)
except AttributeError:
_internal_assert(func_id in self.symbols.keys(), \
"The function called is not in the context either!")
ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, \
"Are you sure what you call is a function?!")
outs = entry(*args)
op = outs.op if isinstance(outs, Tensor) else outs[0].op
return op
# Contexts'
_internal_assert(func_id in self.symbols.keys(), \
"The function called (%s) is not in the context either!" % func_id)
ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, \
"Are you sure what you call is a function?!")
outs = entry(*args)
op = outs.op if isinstance(outs, Tensor) else outs[0].op
return op


def visit_For(self, node):
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/hybrid/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def visit_AugAssign(self, node):


def visit_Name(self, node):
# If it is True or False, we do not worry about it!
if sys.version_info[0] == 2 and node.id in ['True', 'False']:
return
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
return
Expand Down
61 changes: 34 additions & 27 deletions python/tvm/hybrid/runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""

import numpy
from .. import target


class bind(object): #pylint: disable=invalid-name
Expand Down Expand Up @@ -72,34 +73,40 @@ def sigmoid(x):
return 1 / (1 + numpy.exp(-x))


def max_num_threads(allow_none=True):
    """Pure-Python emulation of the ``max_num_threads`` intrinsic:
    query the max thread count of the current (GPU) target."""
    current = target.current_target(allow_none)
    return current.max_num_threads


# Mapping from hybrid-script intrinsic names to pure-Python emulation
# counterparts, injected as globals when a hybrid function is run as
# plain Python. NOTE: the scraped span contained both the pre- and
# post-change copies of these entries (the old copy ended without a
# trailing comma and used true division `/` for ceil_div, which yields a
# float under Python 3); only the corrected post-change entries are kept.
HYBRID_GLOBALS = {
    'unroll'         : range,
    'vectorize'      : range,
    'parallel'       : range,
    'const_range'    : range,
    'bind'           : bind,
    'allocate'       : allocate,
    'output_tensor'  : allocate,
    'sqrt'           : numpy.sqrt,
    'log'            : numpy.log,
    'tanh'           : numpy.tanh,
    'power'          : numpy.power,
    'exp'            : numpy.exp,
    'sigmoid'        : sigmoid,
    'popcount'       : popcount,
    'likely'         : lambda cond: cond,
    'uint8'          : numpy.uint8,
    'uint16'         : numpy.uint16,
    'uint32'         : numpy.uint32,
    'uint64'         : numpy.uint64,
    'int8'           : numpy.int8,
    'int16'          : numpy.int16,
    'int32'          : numpy.int32,
    'int64'          : numpy.int64,
    'float16'        : numpy.float16,
    'float32'        : numpy.float32,
    'float64'        : numpy.float64,
    # Floor division keeps the result integral, matching the compiled intrinsic.
    'ceil_div'       : lambda a, b: (a + b - 1) // b,
    'max_num_threads': max_num_threads
}


Expand Down
3 changes: 3 additions & 0 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("for");
GetUniqueName("in");
GetUniqueName("range");
GetUniqueName("True");
GetUniqueName("False");
GetUniqueName("unroll");
GetUniqueName("const_range");
GetUniqueName("parallel");
Expand Down Expand Up @@ -434,6 +436,7 @@ void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("float32");
GetUniqueName("float64");
GetUniqueName("ceil_div");
GetUniqueName("max_num_threads");
}

void CodeGenHybrid::DumpStmt(const Stmt &stmt,
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_hybrid_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,22 @@ def foo(a):
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')

# Hybrid-script GPU kernel: writes a[k] + a[k] into b[k], distributing
# the flat index space over CUDA threads and blocks.
@tvm.hybrid.script
def max_threads(a):
    b = output_tensor(a.shape, a.dtype)
    n = a.shape[0]
    # Query the current target's thread limit; the True argument is
    # presumably allow_none for current_target — TODO confirm.
    m = max_num_threads(True)
    for i in bind('threadIdx.x', m):
        for j in bind('blockIdx.x', ceil_div(n, m)):
            # Guard against the tail where grid size overshoots n.
            # NOTE(review): flat index is i * m + j (thread index scaled by
            # thread count, block index as offset); with j < ceil_div(n, m)
            # this does not obviously cover every k in [0, n) — verify the
            # intended scheme (commonly j * m + i).
            if i * m + j < n:
                b[i * m + j] = a[i * m + j] + a[i * m + j]
    return b

a = tvm.placeholder((10000, ), 'float32')
with tvm.target.create('cuda'):
func, ins, outs = run_and_check(max_threads, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')


def test_math_intrin():
@script
Expand Down

0 comments on commit dfffc83

Please sign in to comment.