Skip to content

Commit

Permalink
[TIR] Bugfix for zero number arguments tir functions. (apache#8515)
Browse files Browse the repository at this point in the history
* [TIR] Bugfix for zero number arguments tir functions.


Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 93ede50 commit a206d41
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
5 changes: 3 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,15 @@ def LowerCustomDatatypes():
return _ffi_api.LowerCustomDatatypes() # type: ignore


def MakePackedAPI(num_unpacked_params: int = 0):
def MakePackedAPI(num_unpacked_params: int = -1):
"""Transform the PrimFuncs in the module to a packed func API.
Parameters
----------
num_unpacked_params : int
Number of parameters that we hope to directly pass via normal arguments
following the PackedFunc input signature.
following the PackedFunc input signature. If it is specified as -1 or it
is less than the number of arguments, the pass will packed arguments still.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
}

mixed_pass_list.push_back(tir::transform::SplitHostDevice());
Expand Down
15 changes: 10 additions & 5 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
ICHECK_LE(num_unpacked_args, num_args);

bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
if (num_unpacked_args == -1) {
// reset to zero
num_unpacked_args = 0;
}
ICHECK_GE(num_unpacked_args, 0);
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
Expand Down Expand Up @@ -154,11 +159,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
}
return res;
};

// ---------------------------
// start of logics
// add signiture for packed arguments.
if (num_packed_args != 0) {
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
Expand Down Expand Up @@ -214,13 +218,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
}

// allow return value if the function is packed.
if (num_packed_args != 0) {
if (pack_args) {
args.push_back(v_out_ret_value);
args.push_back(v_out_ret_tcode);
args.push_back(v_resource_handle);
}

size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 6 : 0);
size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
ICHECK_EQ(args.size(), expected_nargs);

// Arg definitions are defined before buffer binding to avoid the use before
Expand Down Expand Up @@ -282,6 +286,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
namespace transform {

Pass MakePackedAPI(int num_unpacked_args) {
// packed arguments anyway while `num_unpacked_args` is -1
auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_tir_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ def test_scalar_add():
assert out == 3.0


def test_ret_const():
a = tir.const(0)
b = tir.ret(a)
b = tir.Evaluate(b)
func = tir.PrimFunc([], b)
func = build_tir_func(func)
out = func()
assert out == 0


def test_control_flow_jump():
ib = tvm.tir.ir_builder.create()
a = tir.Var("a", "float32")
Expand All @@ -57,4 +67,5 @@ def test_control_flow_jump():

if __name__ == "__main__":
test_scalar_add()
test_ret_const()
test_control_flow_jump()

0 comments on commit a206d41

Please sign in to comment.