Skip to content

Commit

Permalink
[TIR] Fix perf regression of tir refactor (apache#5258)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and Trevor Morris committed Apr 16, 2020
1 parent 4f92cb2 commit c99aaa7
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def lower(sch,
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
if cfg.restricted_func:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI()(mod)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)

Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ IRModule lower(te::Schedule sch,
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));

if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.no_alias", Integer(1));
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
return tir::transform::MakePackedAPI(0)(mod);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_target_codegen_static_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)

Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_target_codegen_vm_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_makeapi():

num_unpacked_args = 2
f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
"tir.no_alias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
"tir.noalias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
mod = tvm.IRModule.from_expr(f)
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7)
Expand Down

0 comments on commit c99aaa7

Please sign in to comment.