Skip to content

Commit

Permalink
[TIR] Enable Host Func Attribute for PrimFunc (apache#14020)
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh authored and yongwww committed Feb 27, 2023
1 parent afb444c commit caf7aa1
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func";
*/
constexpr const char* kIsGlobalFunc = "tir.is_global_func";

/*!
* \brief Mark the function as run on the host, mutually exclusive with kTarget.
*
* Type: Integer
*/
constexpr const char* kIsHostFunc = "tir.is_host_func";

} // namespace attr
} // namespace tir
} // namespace tvm
Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/primfunc_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ namespace tir {
namespace transform {
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
tvm::attr::kTarget, target->host.value_or(Target("llvm")));
}
return WithAttr(std::move(f), tvm::attr::kTarget, target);
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
Expand Down
79 changes: 79 additions & 0 deletions tests/python/unittest/test_tir_host_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.meta_schedule.testing import te_workload

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
# fmt: off

@I.ir_module
class Module:
@T.prim_func
def main(
A: T.Buffer((729, 729), "float32"),
B: T.Buffer((729, 729), "float32"),
C: T.Buffer((729, 729), "float32"),
):
T.func_attr(
{
"global_symbol": "test",
"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
"tir.noalias": True,
}
)
# with T.block("root"):
for i, j, k in T.grid(729, 729, 729):
with T.block("C"):
v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
T.reads(A[v_i, v_k], B[v_k, v_j])
T.writes(C[v_i, v_j])
with T.init():
C[v_i, v_j] = T.float32(0)
C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring


def test_host_func():
"""Test that host functions are not split."""
# te schedule copied from test_tir_transform_split_host_device.py

func = tvm.te.create_prim_func(
te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32")
)
mod = tvm.ir.IRModule({"main": func})
target = tvm.target.Target("cuda")
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
"global_symbol": "test",
"tir.is_host_func": 1,
}
)
)(mod)
mod = tvm.tir.transform.BindTarget(target)(mod)
tvm.ir.assert_structural_equal(mod, Module)
assert (
"tir.is_host_func" not in mod["main"].attrs
), """Target and is_host_func attributes should be mutually exclusive"""


if __name__ == "__main__":
test_host_func()

0 comments on commit caf7aa1

Please sign in to comment.