diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 38fe9009dd61..aa2386e7f1e4 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -435,6 +435,313 @@ class RealizeFrame : public TIRFrame { public: TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); }; + +/*! + * \brief A frame represents the allocate. + * + * \sa AllocateFrame + */ +class AllocateFrameNode : public TIRFrameNode { + public: + /*! \brief The extents of the allocate. */ + Array extents; + /*! \brief The data type of the buffer. */ + DataType dtype; + /*! \brief The storage scope. */ + String storage_scope; + /*! \brief The condition. */ + PrimExpr condition; + /*! \brief Additional annotation hints. */ + Map annotations; + /*! \brief The buffer. */ + tvm::tir::Buffer buffer; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("extents", &extents); + v->Visit("dtype", &dtype); + v->Visit("storage_scope", &storage_scope); + v->Visit("condition", &condition); + v->Visit("annotations", &annotations); + v->Visit("buffer", &buffer); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to AllocateFrameNode. + * + * \sa AllocateFrameNode + */ +class AllocateFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); +}; + +/*! + * \brief A frame represents the allocate constant. + * + * \sa AllocateConstFrame + */ +class AllocateConstFrameNode : public TIRFrameNode { + public: + /*! \brief The data type of the buffer. */ + DataType dtype; + /*! \brief The extents of the allocate. */ + Array extents; + /*! \brief The data associated with the constant. */ + tvm::runtime::NDArray data; + /*! \brief The buffer */ + tvm::tir::Buffer buffer; + /*! \brief Additional annotations about the allocation. */ + Map annotations; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("dtype", &dtype); + v->Visit("extents", &extents); + v->Visit("data", &data); + v->Visit("buffer", &buffer); + v->Visit("annotations", &annotations); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to AllocateConstFrameNode. + * + * \sa AllocateConstFrameNode + */ +class AllocateConstFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, + AllocateConstFrameNode); +}; +/*! + * \brief A frame that represents attribute node. + * + * \sa AttrFrame + */ +class AttrFrameNode : public TIRFrameNode { + public: + /*! \brief The node to annotate the attribute. */ + ObjectRef node; + /*! \brief Attribute type key. */ + String attr_key; + /*! \brief The value of the attribute. */ + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("node", &node); + v->Visit("attr_key", &attr_key); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to AttrFrameNode. + * + * \sa AttrFrameNode + */ +class AttrFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); +}; + +/*! + * \brief A frame that represents while loop. + * + * \sa WhileFrame + */ +class WhileFrameNode : public TIRFrameNode { + public: + /*! \brief The termination condition of while. */ + PrimExpr condition; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to WhileFrameNode. + * + * \sa WhileFrameNode + */ +class WhileFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); +}; + +/*! + * \brief A frame that represents if statement. + * + * \sa IfFrame + */ +class IfFrameNode : public TIRFrameNode { + public: + /*! \brief The condition of the if statement. */ + PrimExpr condition; + /*! \brief The statements in the true branch. */ + Optional> then_stmts; + /*! \brief The stetements in the false branch. */ + Optional> else_stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_stmts", &then_stmts); + v->Visit("else_stmts", &else_stmts); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ +class IfFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); +}; + +/*! + * \brief A frame that represents then. + * + * \sa ThenFrame + */ +class ThenFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ +class ThenFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); +}; + +/*! + * \brief A frame that represents else. + * + * \sa ElseFrame + */ +class ElseFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ +class ElseFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); +}; + +class DeclBufferFrameNode : public TIRFrameNode { + public: + tvm::tir::Buffer buffer; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("buffer", &buffer); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; +}; + +class DeclBufferFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode); +}; + } // namespace tir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index ec1f7f3753d1..dd289b691502 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -28,6 +28,7 @@ namespace script { namespace ir_builder { namespace tir { +using tvm::runtime::NDArray; using tvm::tir::Buffer; using tvm::tir::Var; @@ -317,6 +318,87 @@ LetFrame Let(Var var, PrimExpr value); */ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); +/*! + * \brief The allocate node. + * \param extents The extents of the allocate. + * \param dtype The data type of the buffer. + * \param storage_scope The storage scope. + * \param condition The condition. + * \param annotations Additional annotation hints. + * \return The created AllocateFrame. + */ +AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", + Optional condition = NullOpt, + Optional> annotations = NullOpt); + +/*! + * \brief The allocate constant node. + * \param data The data associated with the constant. + * \param dtype The data type of the buffer. + * \param extents The extents of the allocate. + * \param annotations Additional annotation hints. + * \return The created AllocateConstFrame. + */ +AllocateConstFrame AllocateConst( + NDArray data, DataType dtype, Array extents, + Map annotations = NullValue>()); + +/*! + * \brief Create an attribute. + * \param node The node to annotate the attribute. + * \param attr_key Attribute type key. + * \param value The value of the attribute. + * \return The result AttrFrame. + */ +AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value); + +/*! + * \brief Create a while loop. + * \param condition The termination condition of the loop. + * \return The result WhileFrame. + */ +WhileFrame While(PrimExpr condition); + +/*! + * \brief Create an if statement. + * \param condition The condition of if statement. + * \return The result IfFrame. + */ +IfFrame If(PrimExpr condition); + +/*! + * \brief Create a then. + * \return The result ThenFrame. + */ +ThenFrame Then(); + +/*! + * \brief Create an else. + * \return The result ElseFrame. + */ +ElseFrame Else(); + +/*! + * \brief The buffer declaration frame. + * \param shape The type of the buffer prior to flattening. + * \param dtype The data type in the content of the buffer. + * \param buffer_name The name of the buffer. + * \param data The pointer to the head of the data. + * \param strides The strides of each dimension. + * \param elem_offset The offset in terms of number of dtype elements (including lanes). + * \param storage_scope The optional storage scope of buffer data pointer. + * \param align The alignment requirement of data pointer in bytes. + * \param offset_factor The factor of elem_offset field. + * \param buffer_type The buffer type. + * \param axis_separators The separators between input axes when generating flattened output axes. + * \return The declared buffer. + */ +DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, + Optional data, Optional> strides, + Optional elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type, + Optional> axis_separators); + /*! * \brief Launch a thread. * \param var The iteration variable. @@ -332,6 +414,21 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); */ Var EnvThread(String thread_tag); +/*! + * \brief Store data in a buffer. + * \param buffer The buffer. + * \param value The value to be stored. + * \param indices The indices location to be stored. + */ +void BufferStore(Buffer buffer, PrimExpr value, Array indices); + +/*! + * \brief The prefetch hint for a buffer + * \param buffer The buffer to be prefetched. + * \param bounds The bounds to be prefetched. + */ +void Prefetch(Buffer buffer, Array bounds); + /*! * \brief Evaluate the input expression. * \param value The input expression to evaluate. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 69bc5bfc9676..b9b50dfa9876 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -18,7 +18,7 @@ from typing import List, Union from tvm._ffi import register_object as _register_object -from tvm.tir import Var +from tvm.tir import Buffer, Var from ..base import IRBuilderFrame @@ -65,6 +65,52 @@ class RealizeFrame(TIRFrame): ... +@_register_object("script.ir_builder.tir.AllocateFrame") +class AllocateFrame(TIRFrame): + def __enter__(self) -> Buffer: + super().__enter__() + return self.buffer + + +@_register_object("script.ir_builder.tir.AllocateConstFrame") +class AllocateConstFrame(TIRFrame): + def __enter__(self) -> Buffer: + super().__enter__() + return self.buffer + + +@_register_object("script.ir_builder.tir.AttrFrame") +class AttrFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.WhileFrame") +class WhileFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.IfFrame") +class IfFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.ThenFrame") +class ThenFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.ElseFrame") +class ElseFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.DeclBufferFrame") +class DeclBufferFrame(TIRFrame): + def __enter__(self) -> Buffer: + super().__enter__() + return self.buffer + + @_register_object("script.ir_builder.tir.LaunchThreadFrame") class LaunchThreadFrame(TIRFrame): ... diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 6db8f40c32c8..625e1291ff20 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -19,8 +19,10 @@ from numbers import Integral from typing import Any, Dict, List, Optional, Union, Tuple +import numpy as np # type: ignore from tvm.ir import Range, Type +from tvm.runtime import convert, ndarray from tvm.tir import ( Buffer, BufferLoad, @@ -32,6 +34,7 @@ StringImm, Var, ) +from tvm.tir import Ramp as ramp from . import _ffi_api, frame @@ -890,6 +893,217 @@ def realize( ) +def allocate( + extents: List[PrimExpr], + dtype: str, + scope: str = "", + condition: PrimExpr = None, + annotations=None, +) -> frame.AllocateFrame: + """Allocate node. + + Parameters + ---------- + extents : List[PrimExpr] + The extents of the allocate. + + dtype : str + The data type of the buffer. + + scope : str + The storage scope. + + condition : PrimExpr + The condition. + + annotations: Optional[Mapping[str, Object]] + Additional annotation hints. + """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member + extents, dtype, scope, condition, annotations + ) + + +def allocate_const( + data: List[PrimExpr], + dtype: str, + extents: List[PrimExpr], + annotations=None, +) -> frame.AllocateConstFrame: + """Allocate constant node. + + Parameters + ---------- + data : List[PrimExpr] + The data associated with the constant. + + dtype : str + The data type of the buffer. + + extents : List[PrimExpr] + The extents of the allocate. + + annotations : Optional[Map] + Additional annotations about the allocation. + """ + + return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member + ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations + ) + + +def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: + """Create an attribute node. + + Parameters + ---------- + node : Any + The node to annotate the attribute. + + attr_key : str + Attribute type key. + + value : Union[PrimExpr, str] + The value of the attribute. + + Returns + ------- + res : frame.AttrFrame + The result AttrFrame. + """ + node = convert(node) + value = convert(value) + return _ffi_api.Attr(node, attr_key, value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name + """Create a while node. + + Parameters + ---------- + condition : PrimExpr + The termination condition of the loop. + + Returns + ------- + res : frame.WhileFrame + The result WhileFrame. + """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.While(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if node. + + Parameters + ---------- + condition : PrimExpr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + if isinstance(condition, bool): + condition = IntImm("bool", condition) + return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then. + + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # type: ignore[attr-defined] # pylint: disable=no-member + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else. + + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, +) -> frame.DeclBufferFrame: + """Create a buffer declaration node. + + Parameters + ---------- + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : frame.DeclBufferFrame + The result DeclBufferFrame. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member + shape, + dtype, + "", + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + def launch_thread( iter_var: IterVar, # pylint: disable=redefined-outer-name extent: PrimExpr, @@ -939,6 +1153,53 @@ def env_thread(thread_tag: str) -> IterVar: return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member +def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None: + """Buffer store node. + + Parameters + ---------- + buffer : Buffer + The buffer. + + value : PrimExpr + The value to be stored. + + indices : List[Union[PrimExpr, slice]] + The indices location to be stored. + """ + from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel + + expr_indices = [] + for index in indices: + if isinstance(index, slice): + step = 1 if index.step is None else index.step + lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step) + if lanes == 1: + expr_indices.append(index.start) + else: + expr_indices.append(ramp(index.start, step, int(lanes))) + else: + expr_indices.append(index) + if isinstance(value, bool) and buffer.dtype == "bool": + value = IntImm("bool", value) + return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member + buffer, value, expr_indices + ) + + +def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None: + """The prefetch hint for a buffer. + + Parameters + ---------- + buffer : Buffer + The buffer to be prefetched. + indices : List[PrimExpr] + The indices of the buffer to extract. + """ + return _ffi_api.Prefetch(buffer, indices) # type: ignore[attr-defined] # pylint: disable=no-member + + def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -1288,8 +1549,18 @@ def var(dtype, name="") -> Var: "Assert", "let", "realize", + "allocate", + "allocate_const", + "attr", + "While", + "If", + "Then", + "Else", + "decl_buffer", "launch_thread", "env_thread", + "buffer_store", + "prefetch", "evaluate", "int8", "int16", diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 6c9459e6389c..aa9efa653f71 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -115,6 +115,76 @@ void LaunchThreadFrameNode::ExitWithScope() { AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); } +void AllocateFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition, + AsStmt(stmts), annotations)); +} + +void AllocateConstFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent( + tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts), annotations)); +} +void AttrFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts))); +} + +void WhileFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::While(condition, AsStmt(stmts))); +} + +void IfFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + if (!stmts.empty()) { + LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame"; + } + if (!then_stmts.defined()) { + LOG(FATAL) << "IfThenElse frame should have at least one then branch"; + } + AddToParent(tvm::tir::IfThenElse( + condition, AsStmt(then_stmts.value()), + else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr))); +} + +void ThenFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("T.then_"); + if (frame->then_stmts.defined()) { + LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is " + << frame->then_stmts.value(); + } + TIRFrameNode::EnterWithScope(); +} + +void ThenFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + FindIfFrame("T.then_")->then_stmts = stmts; +} + +void ElseFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("T.else_"); + if (!frame->then_stmts.defined()) { + LOG(FATAL) << "The else branch should follow then branch"; + } + if (frame->else_stmts.defined()) { + LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is " + << frame->else_stmts.value(); + } + TIRFrameNode::EnterWithScope(); +} + +void ElseFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + FindIfFrame("T.else_")->else_stmts = stmts; +} + +void DeclBufferFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts))); +} + TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); TVM_REGISTER_NODE_TYPE(BlockFrameNode); @@ -124,6 +194,14 @@ TVM_REGISTER_NODE_TYPE(AssertFrameNode); TVM_REGISTER_NODE_TYPE(LetFrameNode); TVM_REGISTER_NODE_TYPE(RealizeFrameNode); TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode); +TVM_REGISTER_NODE_TYPE(AllocateFrameNode); +TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode); +TVM_REGISTER_NODE_TYPE(AttrFrameNode); +TVM_REGISTER_NODE_TYPE(WhileFrameNode); +TVM_REGISTER_NODE_TYPE(IfFrameNode); +TVM_REGISTER_NODE_TYPE(ThenFrameNode); +TVM_REGISTER_NODE_TYPE(ElseFrameNode); +TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode); } // namespace tir } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 5951af298f62..28c3d69861fa 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -444,6 +444,63 @@ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, return RealizeFrame(n); } +AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope, + Optional condition, Optional> annotations) { + ObjectPtr n = make_object(); + n->extents = extents; + n->dtype = dtype; + n->storage_scope = storage_scope; + n->condition = condition.value_or(tvm::Bool(true)); + n->annotations = annotations.value_or(Map()); + n->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0, + "default", NullOpt); + return AllocateFrame(n); +} + +AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, + Array extents, Map annotations) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->extents = extents; + n->data = data; + n->annotations = annotations; + n->buffer = + BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt); + return AllocateConstFrame(n); +} + +AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) { + ObjectPtr n = make_object(); + n->node = node; + n->attr_key = attr_key; + n->value = value; + return AttrFrame(n); +} + +WhileFrame While(PrimExpr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + return WhileFrame(n); +} + +IfFrame If(PrimExpr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_stmts = NullOpt; + n->else_stmts = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + Var EnvThread(String thread_tag) { IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex, thread_tag); @@ -456,6 +513,25 @@ Var EnvThread(String thread_tag) { return var; } +void BufferStore(Buffer buffer, PrimExpr value, Array indices) { + AddToParent(tvm::tir::BufferStore(buffer, value, indices)); +} + +void Prefetch(Buffer buffer, Array bounds) { + AddToParent(tvm::tir::Prefetch(buffer, bounds)); +} + +DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, + Optional data, Optional> strides, + Optional elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type, + Optional> axis_separators) { + ObjectPtr n = make_object(); + n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, + align, offset_factor, buffer_type, axis_separators); + return DeclBufferFrame(n); +} + void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } using tvm::script::ir_builder::details::Namer; @@ -540,10 +616,20 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread); TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index c29fae1c65e9..733c975fad7e 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -88,6 +88,21 @@ inline BlockFrame FindBlockFrame(const String& method) { throw; } +/*! + * \brief Check whether the top frame in IRBuilder frame stack is IfFrame. + * \param method The method name to be printed when throwing exception. + * \return The top frame of IfFrame. + */ +inline IfFrame FindIfFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under T.if_()"; + } + throw; +} + /*! * \brief Convert BufferLoad to BufferRegion. * \param buffer_load The BufferLoad. diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 7f2e6e1a4706..40e13a2fbe2f 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -17,9 +17,11 @@ # pylint: disable=invalid-name, missing-docstring """Unittests for tvm.script.ir_builder.tir""" import pytest -import tvm.testing +import numpy as np import tvm +import tvm.testing from tvm import tir +from tvm.runtime import ndarray from tvm.script.ir_builder import tir as T from tvm.script.ir_builder import IRBuilder from tvm.ir.base import assert_structural_equal @@ -29,6 +31,7 @@ def test_ir_builder_tir_primfunc_base(): with IRBuilder() as ib: with T.prim_func(): T.evaluate(0) + # the prim_func generated by IRBuilder prim_func_actual = ib.get() @@ -41,6 +44,7 @@ def test_ir_builder_tir_primfunc_base(): preflattened_buffer_map=None, attrs=None, ) + # Check if the generated ir is expected assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True) @@ -58,6 +62,7 @@ def test_ir_builder_tir_primfunc_complete(): buffer_d = T.match_buffer(d, (64, 64), "int64") T.preflattened_buffer(e, (32, 32), "int8", data=e.data) T.evaluate(0) + # the prim_func generated by IRBuilder prim_func_actual = ib.get() @@ -83,6 +88,7 @@ def test_ir_builder_tir_primfunc_complete(): }, attrs=tvm.ir.make_node("DictAttrs", key="value"), ) + # Check if the generated ir is expected assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True) @@ -91,6 +97,7 @@ def test_ir_builder_tir_block_base(): with IRBuilder() as ib: with T.block("block"): T.evaluate(0) + # the block generated by IRBuilder block_realize_actual = ib.get() @@ -110,6 +117,7 @@ def test_ir_builder_tir_block_base(): predicate=True, block=block_expected, ) + # Check if the generated ir is expected assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) @@ -131,6 +139,7 @@ def test_ir_builder_tir_block_complete(): T.match_buffer(e[0:32, 0:32], (32, 32), "float32") T.axis.spatial(128, f) T.evaluate(0) + # the block generated by IRBuilder block_realize_actual = ib.get() @@ -158,6 +167,7 @@ def test_ir_builder_tir_block_complete(): predicate=var_a > 1, block=block_expected, ) + # Check if the generated ir is expected assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) @@ -201,6 +211,7 @@ def test_ir_builder_tir_axis(): predicate=True, block=block_expected, ) + # Check if the generated ir is expected assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) @@ -256,6 +267,7 @@ def test_ir_builder_tir_for(): kind=tir.ForKind.SERIAL, body=parallel_expected, ) + # Check if the generated ir is expected assert_structural_equal(for_actual, for_expected, map_free_vars=True) @@ -271,20 +283,9 @@ def test_ir_builder_tir_assert(): assert_expected = tir.AssertStmt( T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0) ) - # Check if the generated ir is expected - assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) - -def test_ir_builder_tir_evaluate(): - with IRBuilder() as ib: - T.evaluate(0) - # the evaluate generated by IRBuilder - eval_actual = ib.get() - - # the expected evaluate - eval_expected = tir.Evaluate(0) # Check if the generated ir is expected - assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) + assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) def test_ir_builder_tir_let(): @@ -296,6 +297,8 @@ def test_ir_builder_tir_let(): # the expected Let statement let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0)) + + # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) @@ -304,6 +307,8 @@ def test_ir_builder_tir_realize(): with IRBuilder() as ib: with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True): T.evaluate(0) + + # the buffer realization generated by IRBuilder realize_actual = ib.get() # the expected buffer realization @@ -313,6 +318,8 @@ def test_ir_builder_tir_realize(): expected_realize = tir.AttrStmt( buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize ) + + # Check if the generated ir is expected assert_structural_equal(realize_actual, expected_realize, map_free_vars=True) @@ -322,12 +329,152 @@ def test_ir_builder_tir_thread(): brow = T.env_thread("blockIdx.y") with T.launch_thread(brow, 1): T.evaluate(0) + + # the prim_func generated by IRBuilder ir_actual = ib.get() + + # the expected prim_func iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0)) func = tir.PrimFunc([], attr_stmt) + + # Check if the generated ir is expected assert_structural_equal(ir_actual, func, map_free_vars=True) +def test_ir_builder_tir_allocate(): + with IRBuilder() as ib: + with T.allocate([10], "float32", scope="local"): + T.evaluate(1) + + # the allocate generated by IRBuilder + ir_actual = ib.get() + + # the expected allocate + buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local")) + ir_expected = tir.Allocate( + buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1) + ) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_allocate_const(): + data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + with IRBuilder() as ib: + with T.allocate_const(data, "int32", [10]): + T.evaluate(1) + + # the allocate const generated by IRBuilder + ir_actual = ib.get() + + # the expected allocate const + buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32"))) + ir_expected = tir.AllocateConst( + buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1) + ) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_while(): + with IRBuilder() as ib: + with T.While(T.var("int32", "x") > 0): + T.evaluate(0) + + # the while generated by IRBuilder + ir_actual = ib.get() + + # the expected while + ir_expected = tir.While(tir.Var("x", "int32") > 0, tir.Evaluate(0)) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_if_then_else(): + with IRBuilder() as ib: + with T.If(T.var("int32", "c") < 12): + with T.Then(): + T.evaluate(T.int32(0)) + with T.Else(): + T.evaluate(T.int32(1)) + + # the if_then_else generated by IRBuilder + ir_actual = ib.get() + + # the expected if_then_else + ir_expected = tir.IfThenElse( + tir.Var("c", "int32") < 12, + tir.Evaluate(tir.IntImm("int32", 0)), + tir.Evaluate(tir.IntImm("int32", 1)), + ) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_buffer_store(): + buffer_a = T.buffer_decl((10, 10), "float32") + i = T.var("int32", "x") + with IRBuilder() as ib: + T.buffer_store(buffer_a, 0.1, [0, i]) + + # the buffer store generated by IRBuilder + ir_actual = ib.get() + + # the expected buffer store + ir_expected = tir.BufferStore(buffer_a, 0.1, [0, i]) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_prefetch(): + with IRBuilder() as ib: + buffer_a = T.buffer_decl((128, 128), "float32") + T.prefetch(buffer_a, []) + + # the prefetch generated by IRBuilder + ir_actual = ib.get() + + # the expected prefetch + ir_expected = tir.Prefetch(buffer_a, []) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + +def test_ir_builder_tir_evaluate(): + with IRBuilder() as ib: + T.evaluate(0) + # the evaluate generated by IRBuilder + eval_actual = ib.get() + + # the expected evaluate + eval_expected = tir.Evaluate(0) + + # Check if the generated ir is expected + assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) + + +def test_ir_builder_tir_decl_buffer(): + with IRBuilder() as ib: + with T.decl_buffer([128, 128], "float32"): + T.evaluate(0) + + # the decl_buffer generated by IRBuilder + ir_actual = ib.get() + + # the expected decl_buffer + buffer = T.buffer_decl((128, 128), "float32") + ir_expected = tir.DeclBuffer(buffer, tir.Evaluate(0)) + + # Check if the generated ir is expected + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main()