From 11b1ff8e0f94f7396d93051532a716919bc36937 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 1 Jun 2022 17:53:02 -0700 Subject: [PATCH] python methods (#38) * ir builder in python * `ForFrame`s in python * python methods * rename and add @staticmethod for current builder * POC demo in python * apply code review suggestions * apply code review suggestions * apply code review suggestions * apply code review suggestions --- python/tvm/script/__init__.py | 1 + python/tvm/script/builder/__init__.py | 22 +++++++ python/tvm/script/builder/_ffi_api.py | 20 +++++++ python/tvm/script/builder/builder.py | 58 +++++++++++++++++++ python/tvm/script/builder/frame.py | 38 ++++++++++++ python/tvm/script/builder/tir/__init__.py | 33 +++++++++++ python/tvm/script/builder/tir/_ffi_api.py | 22 +++++++ python/tvm/script/builder/tir/axis.py | 37 ++++++++++++ python/tvm/script/builder/tir/base.py | 26 +++++++++ python/tvm/script/builder/tir/block_frame.py | 31 ++++++++++ python/tvm/script/builder/tir/for_frame.py | 56 ++++++++++++++++++ .../tvm/script/builder/tir/prim_func_frame.py | 40 +++++++++++++ python/tvm/script/builder/tir/var.py | 26 +++++++++ src/script/builder/builder.cc | 10 ++++ src/script/builder/frame.cc | 4 ++ src/script/builder/frame.h | 2 +- src/script/builder/tir/block_frame.cc | 10 ++++ src/script/builder/tir/for_frame.cc | 14 +++++ src/script/builder/tir/prim_func_frame.cc | 14 +++++ src/script/builder/tir/var.cc | 2 + .../python/unittest/test_tvmscript_builder.py | 40 +++++++++++++ 21 files changed, 505 insertions(+), 1 deletion(-) create mode 100644 python/tvm/script/builder/__init__.py create mode 100644 python/tvm/script/builder/_ffi_api.py create mode 100644 python/tvm/script/builder/builder.py create mode 100644 python/tvm/script/builder/frame.py create mode 100644 python/tvm/script/builder/tir/__init__.py create mode 100644 python/tvm/script/builder/tir/_ffi_api.py create mode 100644 python/tvm/script/builder/tir/axis.py create mode 100644 python/tvm/script/builder/tir/base.py create mode 100644 python/tvm/script/builder/tir/block_frame.py create mode 100644 python/tvm/script/builder/tir/for_frame.py create mode 100644 python/tvm/script/builder/tir/prim_func_frame.py create mode 100644 python/tvm/script/builder/tir/var.py create mode 100644 tests/python/unittest/test_tvmscript_builder.py diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 62279b46c18c..a6b24f46c317 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -19,4 +19,5 @@ from . import tir from . import relax +from .builder import Builder from .parser import ir_module, from_source diff --git a/python/tvm/script/builder/__init__.py b/python/tvm/script/builder/__init__.py new file mode 100644 index 000000000000..999bfb1b6930 --- /dev/null +++ b/python/tvm/script/builder/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Namespace for the TVMScript Builder API.""" + + +from .builder import Builder, def_, def_many +from .frame import Frame, IRModuleFrame diff --git a/python/tvm/script/builder/_ffi_api.py b/python/tvm/script/builder/_ffi_api.py new file mode 100644 index 000000000000..ec20ad798f80 --- /dev/null +++ b/python/tvm/script/builder/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.builder""" +import tvm._ffi + +tvm._ffi._init_api("script.builder", __name__) diff --git a/python/tvm/script/builder/builder.py b/python/tvm/script/builder/builder.py new file mode 100644 index 000000000000..3d449ef1975a --- /dev/null +++ b/python/tvm/script/builder/builder.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script IR Builder""" +from typing import List +from tvm._ffi import register_object as _register_object +from .frame import Frame + +from tvm.runtime import Object + +from . import _ffi_api + +from typing import TypeVar + + +@_register_object("script.builder.Builder") +class Builder(Object): + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.Builder) + + def __enter__(self) -> "Builder": + _ffi_api.BuilderEnter(self) + return self + + def __exit__(self, ptype, value, trace) -> None: + _ffi_api.BuilderExit(self) + + @staticmethod + def current(self) -> "Builder": + return _ffi_api.BuilderCurrent(self) + + def get(self) -> Frame: + return _ffi_api.BuilderGet(self) + + +DefType = TypeVar("DefType", bound=Object) + + +def def_(name: str, var: DefType) -> DefType: + return _ffi_api.Def(name, var) + + +def def_many(names: List[str], vars: List[DefType]) -> List[DefType]: + assert len(names) == len(vars) + return [def_(name, var) for name, var in zip(names, vars)] diff --git a/python/tvm/script/builder/frame.py b/python/tvm/script/builder/frame.py new file mode 100644 index 000000000000..7f6ac8972fcf --- /dev/null +++ b/python/tvm/script/builder/frame.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script Frames""" +from tvm._ffi import register_object as _register_object + +from tvm.runtime import Object + +from . import _ffi_api + + +@_register_object("script.builder.Frame") +class Frame(Object): + def __enter__(self) -> "Frame": + _ffi_api.FrameEnter(self) + return self + + def __exit__(self, ptype, value, trace) -> None: + _ffi_api.FrameExit(self) + + +@_register_object("script.builder.IRModuleFrame") +class IRModuleFrame(Frame): + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.IRModuleFrame) diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py new file mode 100644 index 000000000000..c206e5e84059 --- /dev/null +++ b/python/tvm/script/builder/tir/__init__.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Namespace for the TVMScript TIR Builder API.""" + +from .base import TIRFrame +from .for_frame import ( + ForFrame, + serial, + parallel, + vectorized, + unroll, + thread_binding, + grid, +) +from .prim_func_frame import prim_func, arg +from .block_frame import block +from .var import Buffer +from . import axis diff --git a/python/tvm/script/builder/tir/_ffi_api.py b/python/tvm/script/builder/tir/_ffi_api.py new file mode 100644 index 000000000000..df97ad7ae7f2 --- /dev/null +++ b/python/tvm/script/builder/tir/_ffi_api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.builder""" +import tvm._ffi + +from .. import _ffi_api as _base_ffi_api + +tvm._ffi._init_api("script.builder.tir", __name__) diff --git a/python/tvm/script/builder/tir/axis.py b/python/tvm/script/builder/tir/axis.py new file mode 100644 index 000000000000..9bb3e75650b5 --- /dev/null +++ b/python/tvm/script/builder/tir/axis.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Axis""" + +from . import _ffi_api +from tvm.ir import Range +from tvm.tir import IterVar + + +def spatial(dom, binding, dtype="int32") -> IterVar: + if not isinstance(dom, Range): + dom = Range(0, dom) + return _ffi_api.AxisSpatial(dom, binding, dtype) + + +def reduce(dom, binding, dtype="int32") -> IterVar: + if not isinstance(dom, Range): + dom = Range(0, dom) + return _ffi_api.AxisReduce(dom, binding, dtype) + + +def remap(kinds, bindings, dtype="int32") -> IterVar: + return _ffi_api.AxisRemap(kinds, bindings, dtype) diff --git a/python/tvm/script/builder/tir/base.py b/python/tvm/script/builder/tir/base.py new file mode 100644 index 000000000000..e19c47b0a478 --- /dev/null +++ b/python/tvm/script/builder/tir/base.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Frame""" +from tvm._ffi import register_object as _register_object + +from . import _ffi_api +from ..frame import Frame + + +@_register_object("script.builder.tir.TIRFrame") +class TIRFrame(Frame): + pass diff --git a/python/tvm/script/builder/tir/block_frame.py b/python/tvm/script/builder/tir/block_frame.py new file mode 100644 index 000000000000..c90b0fac87c8 --- /dev/null +++ b/python/tvm/script/builder/tir/block_frame.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Block Frame""" +from tvm._ffi import register_object as _register_object +from .base import TIRFrame + + +from . import _ffi_api + + +@_register_object("script.builder.tir.BlockFrame") +class BlockFrame(TIRFrame): + pass + + +def block(name) -> BlockFrame: + return _ffi_api.BlockFrame(name) diff --git a/python/tvm/script/builder/tir/for_frame.py b/python/tvm/script/builder/tir/for_frame.py new file mode 100644 index 000000000000..13b2599ae233 --- /dev/null +++ b/python/tvm/script/builder/tir/for_frame.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR For Frame""" +from tvm._ffi import register_object as _register_object + +from tvm.tir import Var + +from . import _ffi_api +from ._ffi_api import _base_ffi_api +from .base import TIRFrame +from typing import List + + +@_register_object("script.builder.tir.ForFrame") +class ForFrame(TIRFrame): + def __enter__(self) -> List[Var]: + _base_ffi_api.FrameEnter(self) + return self.vars + + +def serial(min_val, extent, attrs) -> ForFrame: + return _ffi_api.Serial(min_val, extent, attrs) + + +def parallel(min_val, extent, attrs) -> ForFrame: + return _ffi_api.Parallel(min_val, extent, attrs) + + +def vectorized(min_val, extent, attrs) -> ForFrame: + return _ffi_api.Vectorized(min_val, extent, attrs) + + +def unroll(min_val, extent, attrs) -> ForFrame: + return _ffi_api.Unroll(min_val, extent, attrs) + + +def thread_binding(min_val, extent, attrs) -> ForFrame: + return _ffi_api.ThreadBinding(min_val, extent, attrs) + + +def grid(*extents) -> ForFrame: + return _ffi_api.Grid(extents) diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py new file mode 100644 index 000000000000..4a223af55f19 --- /dev/null +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Prim Func Frame""" +from tvm._ffi import register_object as _register_object + +from tvm.tir.expr import Var +from tvm.tir.buffer import Buffer + + +from . import _ffi_api +from .base import TIRFrame + +from typing import Union + + +@_register_object("script.builder.tir.PrimFuncFrame") +class PrimFuncFrame(TIRFrame): + pass + + +def prim_func(name) -> PrimFuncFrame: + return _ffi_api.PrimFuncFrame(name) + + +def arg(name, arg) -> Union[Var, Buffer]: + return _ffi_api.Arg(name, arg) diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py new file mode 100644 index 000000000000..fa06ee63c14a --- /dev/null +++ b/python/tvm/script/builder/tir/var.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Buffer""" +from tvm._ffi import register_object as _register_object + +from tvm.tir.buffer import Buffer + +from . import _ffi_api + + +def Buffer(shape, dtype, name="buffer", storage_scope="") -> Buffer: + return _ffi_api.Buffer(shape, dtype, name, storage_scope) diff --git a/src/script/builder/builder.cc b/src/script/builder/builder.cc index f13f90f6953c..b9c5b9848608 100644 --- a/src/script/builder/builder.cc +++ b/src/script/builder/builder.cc @@ -18,6 +18,8 @@ */ #include "./builder.h" +#include + namespace tvm { namespace script { namespace builder { @@ -80,6 +82,14 @@ ObjectRef DefImpl(String name, ObjectRef obj) { TVM_REGISTER_NODE_TYPE(BuilderNode); +TVM_REGISTER_GLOBAL("script.builder.Builder").set_body_typed([]() { return Builder(); }); +TVM_REGISTER_GLOBAL("script.builder.BuilderEnter").set_body_method(&Builder::EnterWithScope); +TVM_REGISTER_GLOBAL("script.builder.BuilderExit").set_body_method(&Builder::ExitWithScope); +TVM_REGISTER_GLOBAL("script.builder.BuilderCurrent").set_body_typed(Builder::Current); +TVM_REGISTER_GLOBAL("script.builder.BuilderGet") + .set_body_method(&BuilderNode::Get); +TVM_REGISTER_GLOBAL("script.builder.Def").set_body_typed(Def); + } // namespace builder } // namespace script } // namespace tvm diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc index 9359868ef0e6..56280a0b5ec5 100644 --- a/src/script/builder/frame.cc +++ b/src/script/builder/frame.cc @@ -61,6 +61,10 @@ void IRModuleFrameNode::ExitWithScope() { TVM_REGISTER_NODE_TYPE(FrameNode); TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); +TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method(&FrameNode::EnterWithScope); + +TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method(&FrameNode::ExitWithScope); + } // namespace builder } // namespace script } // namespace tvm diff --git a/src/script/builder/frame.h b/src/script/builder/frame.h index 0f86f326dafe..e3465a9e30b8 100644 --- a/src/script/builder/frame.h +++ b/src/script/builder/frame.h @@ -46,7 +46,7 @@ class FrameNode : public runtime::Object { class Frame : public runtime::ObjectRef { public: virtual ~Frame() = default; - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); protected: Frame() = default; diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index 379bc0dd113b..d892df8a1b90 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -18,6 +18,8 @@ */ #include "./block_frame.h" +#include + #include "./for_frame.h" namespace tvm { @@ -144,6 +146,14 @@ Array Remap(String kinds, Array bindings, DataType TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_); + +TVM_REGISTER_GLOBAL("script.builder.tir.AxisSpatial").set_body_typed(axis::Spatial); + +TVM_REGISTER_GLOBAL("script.builder.tir.AxisReduce").set_body_typed(axis::Reduce); + +TVM_REGISTER_GLOBAL("script.builder.tir.AxisRemap").set_body_typed(axis::Remap); + } // namespace tir } // namespace builder } // namespace script diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc index f22d818cc673..f242453395ee 100644 --- a/src/script/builder/tir/for_frame.cc +++ b/src/script/builder/tir/for_frame.cc @@ -18,6 +18,8 @@ */ #include "./for_frame.h" +#include + namespace tvm { namespace script { namespace builder { @@ -92,6 +94,18 @@ ForFrame Grid(Array extents) { TVM_REGISTER_NODE_TYPE(ForFrameNode); +TVM_REGISTER_GLOBAL("script.builder.tir.Serial").set_body_typed(Serial); + +TVM_REGISTER_GLOBAL("script.builder.tir.Parallel").set_body_typed(Parallel); + +TVM_REGISTER_GLOBAL("script.builder.tir.Vectorized").set_body_typed(Vectorized); + +TVM_REGISTER_GLOBAL("script.builder.tir.Unroll").set_body_typed(Unroll); + +TVM_REGISTER_GLOBAL("script.builder.tir.ThreadBinding").set_body_typed(ThreadBinding); + +TVM_REGISTER_GLOBAL("script.builder.tir.Grid").set_body_typed(Grid); + } // namespace tir } // namespace builder } // namespace script diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 039a6ecdef56..70ba8c9adae7 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -74,6 +74,20 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); +TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc_); + +TVM_REGISTER_GLOBAL("script.builder.tir.Arg") + .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { + using namespace tvm::tir; + if (const auto* var = obj.as()) { + return Arg(name, GetRef(var)); + } else if (const auto* buffer = obj.as()) { + return Arg(name, GetRef(buffer)); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR Arg."; + } + }); + } // namespace tir } // namespace builder } // namespace script diff --git a/src/script/builder/tir/var.cc b/src/script/builder/tir/var.cc index 01ea3a01aad8..e3d9c2367b66 100644 --- a/src/script/builder/tir/var.cc +++ b/src/script/builder/tir/var.cc @@ -63,6 +63,8 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); +TVM_REGISTER_GLOBAL("script.builder.tir.Buffer").set_body_typed(Buffer_); + } // namespace tir } // namespace builder } // namespace script diff --git a/tests/python/unittest/test_tvmscript_builder.py b/tests/python/unittest/test_tvmscript_builder.py new file mode 100644 index 000000000000..7fb4cc1a0eed --- /dev/null +++ b/tests/python/unittest/test_tvmscript_builder.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.script.builder import Builder, def_, def_many +from tvm.script.builder import tir as T + + +def test_builder_basic(): + b = Builder() + with b: + with T.prim_func(name="main"): + A = T.arg("A", T.Buffer((128, 128, 128), "float32")) + B = T.arg("B", T.Buffer((128, 128, 128), "float32")) + with T.grid(128, 128, 128) as (i, j, k): + def_many(["i", "j", "k"], [i, j, k]) + with T.block(name="block"): + vi = def_("vi", T.axis.spatial(128, i)) + vj = def_("vj", T.axis.spatial(128, j)) + vk = def_("vk", T.axis.reduce(128, k)) + print(b.get().script()) + tvm._ffi.get_global_func("test_poc")() + + +if __name__ == "__main__": + test_builder_basic()