introduce blockbuilder call_te (#110)
Hzfengsy authored and junrushao committed Feb 5, 2023
1 parent 68f1480 commit 1a70cbe
Showing 2 changed files with 105 additions and 46 deletions.
119 changes: 73 additions & 46 deletions python/tvm/relax/block_builder.py
@@ -292,6 +292,78 @@ def emit(self, expr: Expr) -> Var:
"""
return _ffi_api.BlockBuilderEmit(self, expr)

def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
"""Generate a call node according to the te function.
This function converts arguments from Relax expressions to te tensors.
The callback func should return a te tensor or a list of te tensors.
Please see the detailed example in emit_te.
Parameters
----------
func : Callable
A function that returns a te tensor or a list of te tensors.
args : Any, optional
The arguments passed to the function.
kwargs : Any, optional
The keyword arguments passed to the function.
Note that the key "primfunc_name_hint" is reserved for passing a name hint
to the PrimFunc that gets generated.
Returns
-------
ret : tvm.relax.Call
A newly created call node
"""

primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
new_args, te_arg_list = self._convert_te_arg(args)
new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)

te_args = te_arg_list + te_kwarg_list

te_out = func(*new_args, **new_kwargs)
assert isinstance(te_out, tvm.te.tensor.Tensor) or (
isinstance(te_out, (tuple, list, tvm.ir.Array))
and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
), "only support te.tensor or tuple/list/Array of te.tensor as function output"

if isinstance(te_out, (tuple, list, tvm.ir.Array)) and len(te_out) == 1:
te_out = te_out[0]

outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)

inputs = [*te_args] + outs
tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars)

if primfunc_name_hint:
gvar = self.add_func(tir_func, primfunc_name_hint)
else:
gvar = self.add_func(tir_func, func.__name__)

call_args = [x.op.value for x in te_args]

output_shape = (
outs[0].shape
if isinstance(te_out, tvm.te.tensor.Tensor)
else Tuple([ShapeExpr(x.shape) for x in outs])
)

output_dtype = (
te_out.dtype if isinstance(te_out, tvm.te.tensor.Tensor) else [x.dtype for x in outs]
)

# add arguments for extra parameters from unbound var
if len(unbound_tir_vars) > 0:
call = call_tir(
gvar, call_args, output_shape, output_dtype, tir_vars=ShapeExpr(unbound_tir_vars)
)
else:
call = call_tir(gvar, call_args, output_shape, output_dtype)
return call
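As a usage sketch (modeled on the test_call_te test added by this commit; te_add is a hypothetical element-wise TE callback, not part of the change), call_te builds the call_tir node and leaves binding it to the caller:

from tvm import relax as rx, te, tir

dtype = rx.DynTensorType(rank=2, dtype="float32")
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", [n, m], dtype)
y = rx.Var("y", [n, m], dtype)

def te_add(A, B):
    # Hypothetical TE callback: element-wise addition over the placeholder shape.
    return te.compute(A.shape, lambda i, j: A[i, j] + B[i, j], name="add")

bb = rx.BlockBuilder()
with bb.function("rx_func", [x, y]):
    with bb.dataflow():
        # call_te converts x and y into te placeholders, invokes te_add, lowers the
        # result through create_prim_func, and returns a relax call_tir node;
        # emit_output then binds that node to a dataflow output var.
        out = bb.emit_output(bb.call_te(te_add, x, y, primfunc_name_hint="te_add"))
    bb.emit_func_output(out)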

def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
"""Emit a call node according to the te function.
This function converts arguments from Relax expressions to te tensors,
@@ -414,52 +486,7 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"])
gv = relax.call_tir(te_func, (y,), ((n + 1),), (n,), dtype="float32")
return gv
"""
primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
new_args, te_arg_list = self._convert_te_arg(args)
new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)

te_args = te_arg_list + te_kwarg_list

te_out = func(*new_args, **new_kwargs)
assert isinstance(te_out, tvm.te.tensor.Tensor) or (
isinstance(te_out, (tuple, list, tvm.ir.Array))
and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
), "only support te.tensor or tuple/list/Array of te.tensor as function output"

if isinstance(te_out, (tuple, list, tvm.ir.Array)) and len(te_out) == 1:
te_out = te_out[0]

outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)

inputs = [*te_args] + outs
tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars)

if primfunc_name_hint:
gvar = self.add_func(tir_func, primfunc_name_hint)
else:
gvar = self.add_func(tir_func, func.__name__)

call_args = [x.op.value for x in te_args]

output_shape = (
outs[0].shape
if isinstance(te_out, tvm.te.tensor.Tensor)
else Tuple([ShapeExpr(x.shape) for x in outs])
)

output_dtype = (
te_out.dtype if isinstance(te_out, tvm.te.tensor.Tensor) else [x.dtype for x in outs]
)

# add arguments for extra parameters from unbound var
if len(unbound_tir_vars) > 0:
call = call_tir(
gvar, call_args, output_shape, output_dtype, tir_vars=ShapeExpr(unbound_tir_vars)
)
else:
call = call_tir(gvar, call_args, output_shape, output_dtype)
return self.emit(call)
return self.emit(self.call_te(func, *args, **kwargs))
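With this refactor emit_te is a thin wrapper that binds the node produced by call_te, so inside a dataflow block the two spellings below are equivalent (a minimal sketch reusing the hypothetical te_add callback from the example above):

lv0 = bb.emit_te(te_add, x, y)           # build the call_tir node and bind it in one step
lv1 = bb.emit(bb.call_te(te_add, x, y))  # same result: build with call_te, bind with emit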

def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
"""Emit a MatchShape.
32 changes: 32 additions & 0 deletions tests/python/relax/test_blockbuilder.py
@@ -275,6 +275,38 @@ def test_normalize():
assert add_call.shape[1] == n


def test_call_te():
bb = rx.BlockBuilder()
dtype = rx.DynTensorType(rank=2, dtype="float32")
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", [n, m], dtype)
y = rx.Var("y", [n, m], dtype)
z = rx.Var("z", [n, m], dtype)

def te_func(args, args_dict, msg):
A, B = args
C = args_dict["C"]
D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j])
return E

with bb.function("rx_func", [x, y, z]):
with bb.dataflow():
out = bb.emit_output(bb.call_te(te_func, [x, y], {"C": z}, msg="hello"))
bb.emit_func_output(out)

mod = bb.get()
rx_func = mod["rx_func"]

assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert rx_func.params[2] == z
assert rx_func.name.name_hint == "rx_func"
assert rx_func.body.body == out
assert len(rx_func.body.blocks) == 1
assert len(rx_func.body.blocks[0].bindings) == 1
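A hedged follow-up check, not part of the committed test: because call_te lowers te_func through create_prim_func and registers it via add_func under the callback's __name__, the lowered PrimFunc should be retrievable from the module by that name (assuming tvm is imported in the test module):

assert isinstance(mod["te_func"], tvm.tir.PrimFunc)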


def test_emit_te():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
