diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index b9e81c0a55..9de444e67e 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -101,7 +101,7 @@ class IntSet : public ObjectRef {
   /*!
    * \brief Try to match IntSet with range r.
    *
-   * \note It is guanrateed that IntSet::FromRange(r).MatchRange(r) == true
+   * \note It is guaranteed that IntSet::FromRange(r).MatchRange(r) == true
    * \return true if we can prove they are the same.
    */
   bool MatchRange(const tvm::Range& r) const;
diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 51e39eb8b6..dab6594951 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -191,7 +191,7 @@ class IterSplitExpr : public IterMapExpr {
    */
   TVM_DLL explicit IterSplitExpr(IterMark source);
   /*!
-   * \brief constructor from just source.
+   * \brief constructor from source and scale.
    * \param source The source expression.
    * \param scale The additional scaling factor.
    */
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index a26acb1f7a..bb7a9de1da 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
 
 /*!
- * \brief Auto detect the block read/write region according to body stmt
- * It will detect the read/write region as an array in order of appearance in AST
+ * \brief Auto detect the block access region according to its body stmt
+ * It will detect the access region as an array in order of appearance in AST
  * \param block The block to be detected
  * \param buffer_var_map The outside buffers which may be accessed by the block.
  *        It is a map from buffer var to the buffer.
@@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
  *  - second: write regions
  *  - third: opaque regions
  */
-Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
-                                                const Map<Var, Buffer>& buffer_var_map);
+TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
+                                                        const Map<Var, Buffer>& buffer_var_map);
+
+/*!
+ * \brief Auto detect the block read/write region according to its body stmt. An opaque access will
+ *        be counted as both a read and a write access
+ * \param block The block to be detected
+ * \param buffer_var_map The outside buffers which may be accessed by the block.
+ *        It is a map from buffer var to the buffer
+ * \return An array only consisting of the read regions and write regions of the input block
+ */
+TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
+                                                           const Map<Var, Buffer>& buffer_var_map);
 
 /*!
  * \brief Calculate the expresion complexity based on number of symbols it contains.
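A quick sketch of how the two detectors differ from the Python side (the binding names follow the `analysis.py` changes later in this patch; `blk` and `buffer_var_map` stand in for a real block and its buffer map):

```python
from tvm import tir

def compare_region_detectors(blk: tir.Block, buffer_var_map):
    # Three lists: read regions, write regions, opaque regions.
    reads, writes, opaque = tir.analysis.get_block_access_region(blk, buffer_var_map)
    # Two lists only: every opaque region is folded into both reads and writes.
    rw_reads, rw_writes = tir.analysis.get_block_read_write_region(blk, buffer_var_map)
    assert len(rw_reads) <= len(reads) + len(opaque)
    assert len(rw_writes) <= len(writes) + len(opaque)
```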
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 081c05bd69..cb256bde2d 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -671,6 +671,11 @@ TVM_DLL const Op& atomic_add();
  */
 TVM_DLL const Op& tvm_memcpy_async();
 
+/*!
+ * \brief tvm intrinsic for mfma instruction
+ */
+TVM_DLL const Op& tvm_mfma_sync();
+
 /*!
  * \brief The kind of structure field info used in intrinsic
  */
 enum TVMStructFieldKind : int {
   // array head address
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 47f8477dc1..fd601a7a09 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -23,6 +23,9 @@
 #include
 
 namespace tvm {
+
+class Target;
+
 namespace tir {
 
 using TRandState = support::LinearCongruentialEngine::TRandState;
@@ -239,6 +242,18 @@ class ScheduleNode : public runtime::Object {
    */
   virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
                                           Optional<Array<Integer>> decision = NullOpt) = 0;
+  /*!
+   * \brief Sample the factors to tile a list of LoopRVs
+   * \param loop_rvs The loops to be tiled
+   * \param ns The number of tiles each loop is split into
+   * \param target The target whose constraints guide the sampling
+   * \param max_innermost_factor The maximum factor in the innermost loop, -1 if disabled
+   * \param decision The sampling decision
+   * \return An array of arrays of random variables, the result of sampling
+   */
+  virtual Array<Array<ExprRV>> SampleShapeGenericTiles(
+      const Array<LoopRV>& loop_rvs, const std::vector<int>& ns, const Target& target,
+      int max_innermost_factor, Optional<Array<Array<Integer>>> decision = NullOpt) = 0;
   /*!
    * \brief Sample an integer given the probability distribution
    * \param candidates The candidates
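For contrast with SamplePerfectTile, a small worked example of what the generic variant may return; the numbers are purely illustrative and not taken from the sampler:

```python
# SamplePerfectTile(loop of extent 128, n=3): factors must multiply back exactly,
# e.g. [4, 8, 4] since 4 * 8 * 4 == 128.
#
# SampleShapeGenericTiles(loops (i, j) of extents (100, 60), ns=[2, 2], target=cuda):
# samples one tile shape across several loops at once, and the target's limits
# (threads per block, vector width, ...) constrain the joint shape, e.g.
#   [[25, 4], [20, 3]]   # hypothetical draw; per-loop products match 100 and 60
```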
diff --git a/include/tvm/tir/sparse/block.h b/include/tvm/tir/sparse/block.h
new file mode 100644
index 0000000000..9061449de9
--- /dev/null
+++ b/include/tvm/tir/sparse/block.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/sparse/block.h
+ * \brief Sparse Block in Sparse TIR.
+ */
+
+#ifndef TVM_TIR_SPARSE_BLOCK_H_
+#define TVM_TIR_SPARSE_BLOCK_H_
+
+#include
+
+namespace tvm {
+
+namespace tir {
+
+namespace sparse {
+
+/*!
+ * \brief Class of sparse block.
+ * \example
+ *   with tir.sp.block([i, j, k], [False, False, True]) as [vi, vj, vk]:
+ *     pass
+ *   with tir.sp.block([i, j, k], [False, False, True], [(0, 1), (2,)]) as [vi, vj, vk]:
+ *     pass
+ */
+class SparseBlockNode : public Object {
+ public:
+  AxisRef root;
+  Array<Axis> axes;
+  Array<Array<Integer>> fused_groups;
+  Array<Bool> is_reduce_axis;
+  static constexpr const char* _type_key = "tir.sp.SparseBlockNode";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SparseBlockNode.
+ * \sa SparseBlockNode
+ */
+class SparseBlock : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, ObjectRef, SparseBlockNode);
+};
+
+}  // namespace sparse
+
+}  // namespace tir
+
+}  // namespace tvm
+
+#endif  // TVM_TIR_SPARSE_BLOCK_H_
diff --git a/include/tvm/tir/sparse/buffer.h b/include/tvm/tir/sparse/buffer.h
new file mode 100644
index 0000000000..694de5f04a
--- /dev/null
+++ b/include/tvm/tir/sparse/buffer.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/sparse/buffer.h
+ * \brief Sparse buffer data structure in Sparse TIR.
+ */
+#ifndef TVM_TIR_SPARSE_BUFFER_H_
+#define TVM_TIR_SPARSE_BUFFER_H_
+
+#include
+#include
+
+namespace tvm {
+
+namespace tir {
+
+namespace sparse {
+
+/*!
+ * \brief Class of sparse buffer.
+ */
+class SparseBufferNode : public Object {
+ public:
+  /* Root of Axis Dependency Tree. */
+  AxisRef root;
+  /* Axes */
+  Array<Axis> axes;
+  /* Number of dimensions */
+  int ndim;
+  /* Buffer corresponding to flattened value */
+  Buffer data;
+  /* Buffer corresponding to indices pointer */
+  Array<Buffer> indptr;
+  /* Buffer of column indices */
+  Array<Buffer> indices;
+
+  static constexpr const char* _type_key = "tir.sp.SparseBufferNode";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SparseBufferNode.
+ * \sa SparseBufferNode
+ */
+class SparseBuffer : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
+};
+
+}  // namespace sparse
+
+}  // namespace tir
+
+}  // namespace tvm
+
+#endif  // TVM_TIR_SPARSE_BUFFER_H_
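To make the SparseBuffer fields concrete, here is how a CSR matrix would populate them; this mapping is an illustration of the structure above (using scipy for the arrays), not code from the patch:

```python
import scipy.sparse as sp

mat = sp.random(4, 6, density=0.3, format="csr", dtype="float32")

# SparseBufferNode field  <-  CSR array
#   ndim                  <-  2 (a dense row axis, a sparse column axis)
#   data                  <-  mat.data      (flattened non-zero values)
#   indptr                <-  [mat.indptr]  (one pointer buffer per variable axis)
#   indices               <-  [mat.indices] (one index buffer per sparse axis)
print(mat.indptr, mat.indices, mat.data)
```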
diff --git a/include/tvm/tir/sparse/format.h b/include/tvm/tir/sparse/format.h
new file mode 100644
index 0000000000..da145815df
--- /dev/null
+++ b/include/tvm/tir/sparse/format.h
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/sparse/format.h
+ * \brief Sparse format in Sparse TIR.
+ */
+
+#ifndef TVM_TIR_SPARSE_FORMAT_H_
+#define TVM_TIR_SPARSE_FORMAT_H_
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace tvm {
+
+namespace tir {
+
+namespace sparse {
+
+/*!
+ * \brief Base type for axis in sparse formats.
+ */
+class AxisNode : public Object {
+ public:
+  /* name of current axis. */
+  String name;
+  /* length of current axis. For sparse axis, length refers to the upper bound of
+   * the current axis. */
+  PrimExpr length;
+  static constexpr const char* _type_key = "Axis";
+  TVM_DECLARE_BASE_OBJECT_INFO(AxisNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AxisNode.
+ * \sa AxisNode
+ */
+class Axis : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(Axis, ObjectRef, AxisNode);
+};
+
+/*!
+ * \brief Root of Axis Dependency Tree.
+ */
+class RootAxisNode : public Object {
+ public:
+  static constexpr const char* _type_key = "RootAxis";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RootAxisNode, Object);
+};
+
+/*!
+ * \brief Managed reference to RootAxisNode.
+ * \sa RootAxisNode
+ */
+class RootAxis : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(RootAxis, ObjectRef, RootAxisNode);
+};
+
+/*!
+ * \brief Dense axis whose column indices are consecutive.
+ */
+class DenseAxisNode : public AxisNode {
+ public:
+  static constexpr const char* _type_key = "DenseAxis";
+  TVM_DECLARE_BASE_OBJECT_INFO(DenseAxisNode, AxisNode);
+};
+
+/*!
+ * \brief Managed reference to DenseAxisNode.
+ * \sa DenseAxisNode
+ */
+class DenseAxis : public Axis {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DenseAxis, Axis, DenseAxisNode);
+};
+
+/*!
+ * \brief Dense axis with fixed length per row.
+ */
+class DenseFixedAxisNode : public DenseAxisNode {
+ public:
+  static constexpr const char* _type_key = "DenseFixedAxis";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
+};
+
+/*!
+ * \brief Managed reference to DenseFixedAxisNode.
+ * \sa DenseFixedAxisNode
+ */
+class DenseFixedAxis : public DenseAxis {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
+};
+
+/*!
+ * \brief Dense axis whose length is dependent on its predecessors on the axis
+ *        dependency tree.
+ */
+class DenseVariableAxisNode : public DenseAxisNode {
+ public:
+  static constexpr const char* _type_key = "DenseVariableAxis";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
+};
+
+/*!
+ * \brief Managed reference to DenseVariableAxisNode.
+ * \sa DenseVariableAxisNode
+ */
+class DenseVariableAxis : public DenseAxis {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
+                                DenseVariableAxisNode);
+};
+
+/*!
+ * \brief Sparse axis whose column indices are not consecutive.
+ */
+class SparseAxisNode : public AxisNode {
+ public:
+  static constexpr const char* _type_key = "SparseAxis";
+  TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode);
+};
+
+/*!
+ * \brief Managed reference to SparseAxisNode.
+ * \sa SparseAxisNode
+ */
+class SparseAxis : public Axis {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode);
+};
+
+/*!
+ * \brief Sparse axis with fixed number of non-zero columns per row.
+ */
+class SparseFixedAxisNode : public SparseAxisNode {
+ public:
+  /* (fixed) number of non-zero columns of current sparse axis. */
+  PrimExpr num_cols;
+  static constexpr const char* _type_key = "SparseFixedAxis";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
+};
+
+/*!
+ * \brief Managed reference to SparseFixedAxisNode.
+ * \sa SparseFixedAxisNode
+ */
+class SparseFixedAxis : public SparseAxis {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
+                                SparseFixedAxisNode);
+};
+
+/*!
+ * \brief Sparse axis with variable number of non-zero columns per row.
+ */
+class SparseVariableAxisNode : public SparseAxisNode {
+ public:
+  static constexpr const char* _type_key = "SparseVariableAxis";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
+};
+
+/*!
+ * \brief Managed reference to SparseVariableAxisNode.
+ * \sa SparseVariableAxisNode
+ */
+class SparseVariableAxis : public SparseAxis {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
+                                SparseVariableAxisNode);
+};
+
+/*!
+ * \brief Reference of Axis on Axis Dependency Tree.
+ */
+class AxisRefNode : public Object {
+ public:
+  // parent refers to the parent axis of current axis tree.
+  Optional<ObjectRef> parent;
+  Axis axis;
+  Array<ObjectRef> children;
+  static constexpr const char* _type_key = "tir.sp.AxisRefNode";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AxisRefNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AxisRefNode.
+ * \sa AxisRefNode
+ */
+class AxisRef : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(AxisRef, ObjectRef, AxisRefNode);
+};
+
+}  // namespace sparse
+
+}  // namespace tir
+
+}  // namespace tvm
+
+#endif  // TVM_TIR_SPARSE_FORMAT_H_
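An informal reading of how familiar formats decompose into the four concrete axis kinds defined above (my interpretation of the class hierarchy, not text from the patch):

```python
# format        axis 0          axis 1
# dense 2-D     DenseFixed      DenseFixed
# ragged rows   DenseFixed      DenseVariable   (row length depends on the parent axis)
# ELL           DenseFixed      SparseFixed     (num_cols non-zeros per row)
# CSR           DenseFixed      SparseVariable  (indptr + indices)
```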
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 62d51f9d38..4707b85aa0 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -534,11 +534,22 @@ TVM_DLL Pass FlattenBuffer();
 TVM_DLL Pass UnifyThreadBinding();
 
 /*!
- * \brief Lower lower logical layout into physical layout.
+ * \brief Lower logical layout into physical layout.
  * \return The pass.
  */
 TVM_DLL Pass LowerLogicalLayout();
 
+/*!
+ * \brief Transform annotated loops into pipelined loops that overlap the execution of producers
+ *        and consumers.
+ * \return The IR transform pass.
+ */
+TVM_DLL Pass InjectSoftwarePipeline();
+
+/*!
+ * \brief Lower logical intrinsics into physical intrinsics.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerLogicalIntrin();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
index ae3e9d885f..44c92b792f 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/context_maintainer.py
@@ -57,7 +57,7 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 
         # match_buffers of the block,
         # which bind a sub-region of source buffer into a new buffer
-        D = tir.match_buffer_region(C[vi, vj])
+        D = tir.match_buffer(C[vi, vj], ())
 
         # init part of the block, executed when all reduce axes are the beginning value
         with tir.init():
@@ -65,13 +65,13 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 
         # block body
         CC[0, 0] = A[vi, vk] * B[vj, vk]
-        D[0, 0] += CC[0, 0]  # The same as C[vi, vj] += CC[0, 0]
+        D[()] += CC[0, 0]  # The same as C[vi, vj] += CC[0, 0]
 
     alloc_buffers: List[Buffer] = []
     """List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
 
     match_buffers: List[MatchBufferRegion] = []
-    """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
+    """List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
 
     iter_bindings: Mapping[Var, PrimExpr] = {}
     """Mapping[Var, PrimExpr]: map of block iter var to its values"""
 
     reads: Optional[List[BufferSlice]] = None
diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py
index cfbc668946..d5c6dffc8d 100644
--- a/python/tvm/script/node.py
+++ b/python/tvm/script/node.py
@@ -96,11 +96,13 @@ def check_index(index: Union[int, PrimExpr]):
             if index < 0:
                 report_error("Negative index is not allowed during buffer access", span)
         elif isinstance(index, PrimExpr):
-            if index.dtype != "int32":
-                report_error(
-                    "index expected an int32 type PrimExpr but got " + str(index.dtype),
-                    index.span,
-                )
+            # FIXME(vinx13): Ramp is allowed when registering logical intrinsic implementations
+            # if index.dtype != "int32":
+            #     report_error(
+            #         "index expected an int32 type PrimExpr but got " + str(index.dtype),
+            #         index.span,
+            #     )
+            pass
         else:
             report_error(
                 "Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 9acf21b6ba..50da4e8f13 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -176,7 +176,7 @@ def transform(self, node):
         self.current_lineno = self.base_lineno + node.lineno - 1
         if hasattr(node, "col_offset"):
             self.current_col_offset = node.col_offset
-
+
         method = "transform_" + node.__class__.__name__
         visitor = getattr(self, method, self.generic_visit)
         transform_res = visitor(node)
@@ -850,7 +850,6 @@ def transform_Attr(self, node):
         2. All other names `tvm.something` are lookup up in this current python namespace.
         """
-
         if isinstance(node.object, ast.Var):
             if node.object.id.name == "tir":
                 func_name = "tir." + node.field.name
@@ -867,6 +866,25 @@ def transform_Attr(self, node):
                     )
                 else:
                     raise e
+        elif isinstance(node.object, ast.Attr):
+            func_name = node.field.name
+            node = node.object
+            if isinstance(node.object, ast.Var):
+                if node.object.id.name == "tir" and node.field.name == "sp":
+                    func_name = "tir.sp." + func_name
+                    res = Registry.lookup(func_name)
+                    if res is not None:
+                        return res
+                    try:
+                        return tvm.ir.op.Op.get(func_name)
+                    except TVMError as e:
+                        # Check if we got an attribute error
+                        if e.args[0].find("AttributeError") != -1:
+                            self.report_error(
+                                f"Unregistered function `{func_name}`.", node.field.span
+                            )
+                        else:
+                            raise e
 
         symbol = self.transform(node.object)
         if symbol is None:
diff --git a/python/tvm/script/sparse.py b/python/tvm/script/sparse.py
new file mode 100644
index 0000000000..8eea7e8297
--- /dev/null
+++ b/python/tvm/script/sparse.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TVM Script Parser for Sparse Dialect Classes"""
+# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
+# pylint: disable=relative-beyond-top-level
+import synr
+from synr import ast
+
+import tvm.tir
+
+from .special_stmt import SpecialStmt
+from .registry import register
+
+
+@register
+class MatchSparseBuffer(SpecialStmt):
+    """Special stmt binding a function parameter to a sparse buffer.
+
+    The buffer construction itself is not implemented yet; this is a stub.
+    """
+
+    def __init__(self):
+        def match_buffer(
+            param,
+            shape,
+            fmt,
+            idtype="int32",
+            dtype="float32",
+            span=None,
+        ):
+            # TODO: construct and bind the sparse buffer here.
+            pass
+
+        super().__init__(match_buffer, def_symbol=True)
diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py
index 5716e2274f..25af763574 100644
--- a/python/tvm/script/special_stmt.py
+++ b/python/tvm/script/special_stmt.py
@@ -113,7 +113,7 @@ class MatchBuffer(SpecialStmt):
     Match buffer from Buffer subregion
 
     .. code-block:: python
 
-        A = tir.match_buffer(, (128, 128), dtype="float32")
+        A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
     """
 
     def __init__(self):
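Putting the parser hook and the MatchSparseBuffer stub together, the intended TVMScript surface syntax would look roughly like this; the body is hypothetical since the special stmt is still a placeholder:

```python
import tvm
from tvm.script import ty

@tvm.script.tir
def spmv(a: ty.handle, x: ty.handle, y: ty.handle) -> None:
    # Hypothetical: bind `a` to a 128x128 CSR-format sparse buffer.
    A = tir.sp.match_buffer(a, (128, 128), "csr", idtype="int32", dtype="float32")
    # ... iteration over A's axes would go here once tir.sp.block lands ...
```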
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index aac5ac6ff2..5cb9f87c41 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -140,7 +140,29 @@ def get_block_access_region(
         - second: write regions
         - third: opaque regions
     """
-    return _ffi_api.get_block_access_region(block, buffer_var_map)  # type: ignore
+    return _ffi_api.GetBlockAccessRegion(block, buffer_var_map)  # type: ignore
+
+
+def get_block_read_write_region(
+    block: Block, buffer_var_map: Dict[Var, Buffer]
+) -> List[List[BufferRegion]]:
+    """Auto detect the block read/write region according to its body stmt.
+    An opaque access will be counted as both a read and a write access
+
+    Parameters
+    ----------
+    block: tvm.tir.Block
+        The block in which we are detecting read/write regions.
+
+    buffer_var_map : Dict[Var, Buffer]
+        The outside buffers which may be accessed by the block.
+        Mapping from buffer var to the buffer
+
+    Returns
+    -------
+    result : List[List[BufferRegion]]
+        An array only consisting of the read regions and write regions of the input block
+    """
+    return _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map)  # type: ignore
 
 
 def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 3003f59d93..1ece8466a1 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -22,6 +22,7 @@
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
+from tvm.target import Target
 from tvm.tir import Block, For, IntImm, PrimFunc, IterVar, TensorIntrin
 
 from . import _ffi_api_schedule
@@ -327,6 +328,23 @@ def sample_perfect_tile(
             decision,
         )
 
+    def sample_shape_generic_tiles(
+        self,
+        loops: List[LoopRV],
+        ns: List[int],
+        target: Target,
+        max_innermost_factor: int = 16,
+        decision: Optional[List[List[int]]] = None,
+    ) -> List[List[ExprRV]]:
+        """Sample tile shapes for the given loops, guided by the target's constraints."""
+        return _ffi_api_schedule.ScheduleSampleShapeGenericTiles(  # pylint: disable=no-member
+            self,
+            loops,
+            ns,
+            target,
+            max_innermost_factor,
+            decision,
+        )
+
     def sample_categorical(
         self,
         candidates: List[int],
diff --git a/python/tvm/tir/sparse/__init__.py b/python/tvm/tir/sparse/__init__.py
new file mode 100644
index 0000000000..75790a36e4
--- /dev/null
+++ b/python/tvm/tir/sparse/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from .buffer import Buffer
+from .format import Format
+from .block import Block
\ No newline at end of file
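A sketch of driving the new sampler from a schedule; commented out because it only runs once the FFI registration in the C++ half of this patch is in place, and the loop structure is made up:

```python
from tvm.target import Target

# sch = tir.Schedule(mod)
# i, j = sch.get_loops(sch.get_block("C"))
# tiles = sch.sample_shape_generic_tiles([i, j], ns=[4, 2], target=Target("cuda"),
#                                        max_innermost_factor=16)
# i_factors, j_factors = tiles  # lists of ExprRV, one per requested split
```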
diff --git a/python/tvm/tir/sparse/block.py b/python/tvm/tir/sparse/block.py
new file mode 100644
index 0000000000..24d48a47fc
--- /dev/null
+++ b/python/tvm/tir/sparse/block.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Abstraction for Sparse Block Structures."""
+import tvm._ffi
+
+from tvm.runtime import Object
+
+
+# NOTE: the key must match the C++ _type_key in include/tvm/tir/sparse/block.h
+@tvm._ffi.register_object("tir.sp.SparseBlockNode")
+class Block(Object):
+    """The sparse block object in Sparse TIR."""
diff --git a/python/tvm/tir/sparse/buffer.py b/python/tvm/tir/sparse/buffer.py
new file mode 100644
index 0000000000..c6b93ff351
--- /dev/null
+++ b/python/tvm/tir/sparse/buffer.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Abstraction for sparse data structures."""
+from numbers import Integral
+
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.runtime import Object, convert
+from tvm.ir import PrimExpr, PointerType, PrimType
+from . import _ffi_api
+
+
+# NOTE: the key must match the C++ _type_key in include/tvm/tir/sparse/buffer.h
+@tvm._ffi.register_object("tir.sp.SparseBufferNode")
+class Buffer(Object):
+    """Symbolic sparse data buffer in TVM.
+
+    Buffer provides a way to represent the layout specialization
+    of sparse data structures in TVM.
+
+    See Also
+    --------
+    decl_buffer : Declare a buffer
+    """
+ """ + pass \ No newline at end of file diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index a8cd4b1d7a..43cff16a86 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -754,7 +754,7 @@ def LowerMatchBuffer(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerMatchBuffer() + return _ffi_api.LowerMatchBuffer() # type: ignore def FlattenBuffer(): @@ -812,3 +812,14 @@ def LowerLogicalLayout(): The result pass """ return _ffi_api.LowerLogicalLayout() + + +def LowerLogicalIntrin(): + """Lower logical intrinsics to physical intrinsics. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerLogicalIntrin() \ No newline at end of file diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 783779bd5d..58f452a748 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -224,6 +224,9 @@ Array CreatePassList(bool disable_loop_partition, bool for pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LowerLogicalLayout()); + pass_list.push_back(tir::transform::LowerLogicalIntrin()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::UnifyThreadBinding()); } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 118563f125..9be431bae2 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -581,8 +581,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { << Print(alloc_buf->shape) << ")" << Doc::NewLine(); } for (const auto& match_buf : block_op->match_buffers) { - body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source) - << ")" << Doc::NewLine(); + body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" + << Doc::NewLine(); } if (block_op->init.defined()) { Doc init_block; diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index a610f9d939..90aaa35d60 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -26,6 +26,7 @@ #include #include +#include "../transforms/ir_utils.h" namespace tvm { namespace tir { @@ -109,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) { ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { const Var& target_var = match_buffer->buffer->data; - match_buffers_[target_var.get()] = match_buffer; - buffer_var_map_.Set(target_var, match_buffer->buffer); + const Var& source_var = match_buffer->source->buffer->data; + if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) { + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } } StmtExprVisitor::operator()(stmt); } @@ -204,7 +208,7 @@ std::vector BlockReadWriteDetector::ConvertMatchedRegion( region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); } - region = match_buffer.ConvertRegion(region); + region = ConvertRegion(match_buffer, region); std::vector result; result.reserve(region.size()); @@ -281,7 +285,39 @@ 
Array> GetBlockAccessRegion(const Block& block, return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()}; } -TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion); +Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map) { + // Step 1. Get all the read/write/opaque accesses in the input block. + Array> access_regions = GetBlockAccessRegion(block, buffer_var_map); + // Step 2. Collect all the buffers that are opaquely accessed. + std::unordered_set opaque_accessed_buffers; + for (const BufferRegion& opaque_access : access_regions[2]) { + opaque_accessed_buffers.insert(opaque_access->buffer.get()); + } + // Step 3. Create new arrays of read/write regions. + Array new_read_regions; + Array new_write_regions; + new_read_regions.reserve(access_regions[0].size() + access_regions[2].size()); + new_write_regions.reserve(access_regions[1].size() + access_regions[2].size()); + for (const BufferRegion& read_access : access_regions[0]) { + if (!opaque_accessed_buffers.count(read_access->buffer.get())) { + new_read_regions.push_back(read_access); + } + } + for (const BufferRegion& write_access : access_regions[1]) { + if (!opaque_accessed_buffers.count(write_access->buffer.get())) { + new_write_regions.push_back(write_access); + } + } + for (const BufferRegion& opaque_access : access_regions[2]) { + new_read_regions.push_back(opaque_access); + new_write_regions.push_back(opaque_access); + } + return {new_read_regions, new_write_regions}; +} + +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 8f39de4c96..e680d68973 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -86,23 +86,16 @@ class LCADetector : public StmtExprVisitor { buffer_var_map_.emplace(buf->data.get(), buf.get()); } + const ScopeInfo* parent_scope = ancestor_scopes_.back(); + auto* current_scope = arena_.make(parent_scope, op, n); + + ancestor_scopes_.push_back(current_scope); // Update match_buffers for (const MatchBufferRegion& match_buffer : op->match_buffers) { - const Buffer& target_buffer = match_buffer->buffer; - buffer_var_map_.emplace(target_buffer->data.get(), target_buffer.get()); - - const Buffer& source_buffer = match_buffer->source->buffer; - auto it = match_buffers_.find(source_buffer.get()); - if (it != match_buffers_.end()) { - match_buffers_[target_buffer.get()] = it->second; - } else { - match_buffers_[target_buffer.get()] = source_buffer.get(); - } + UpdateBufferLCA(match_buffer->source->buffer.get()); + match_buffers_.insert(match_buffer->buffer.get()); } - const ScopeInfo* parent_scope = ancestor_scopes_.back(); - auto* current_scope = arena_.make(parent_scope, op, n); - ancestor_scopes_.push_back(current_scope); StmtExprVisitor::VisitStmt_(op); ancestor_scopes_.pop_back(); } @@ -144,12 +137,11 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { - auto it = match_buffers_.find(buffer); - if (it != match_buffers_.end()) { - buffer = it->second; + if (match_buffers_.find(buffer) == match_buffers_.end()) { + // Ingore buffer created by block match_buffer + const ScopeInfo*& lca = buffer_lca_[buffer]; + lca = 
LowestCommonAncestor(lca, ancestor_scopes_.back()); } - const ScopeInfo*& lca = buffer_lca_[buffer]; - lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); } static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { @@ -184,7 +176,7 @@ class LCADetector : public StmtExprVisitor { /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; /*! \brief The match buffers inside blocks. */ - std::unordered_map match_buffers_ = {}; + std::unordered_set match_buffers_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c15b3bb47b..f265a8ae2b 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,15 +36,12 @@ namespace tir { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map, bool contain_root) - : buffer_var_map_(buffer_var_map), contain_root_(contain_root) {} + explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Whether the stmt contains at least one block. */ bool contains_block = false; private: Map* buffer_var_map_; - bool contain_root_; - bool visited_root_ = false; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; Stmt body = StmtMutator::VisitStmt_(op); @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator { } Stmt VisitStmt_(const BlockNode* op) override { - bool is_root_block = contain_root_ && !visited_root_; - visited_root_ = true; // Buffers allocated in the block can be accessed by its body. for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->Set(target_buffer->data, target_buffer); + } Block block = Downcast(StmtMutator::VisitStmt_(op)); // Remove buffers allocated inside block to detect its access region for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->erase(target_buffer->data); + } // Get access detection mask // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write int mask = 0; @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator { } // ignore root block or blocks which already has reads/writes regions if (mask != 0) { - if (op->iter_vars.empty()) { - // non-root opaque block is not allowed - CHECK(is_root_block) - << "ValueError: Can not auto detect buffer access region for an opaque block. 
Please " - "annotate the access region manually."; - return std::move(block); - } auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); const Array& reads = access_region[0]; const Array& writes = access_region[1]; @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } bool contain_root = root_allocates.empty() && func->body->IsInstance() && Downcast(func->body)->block->iter_vars.empty(); - ScriptCompleter script_completer(&buffer_var_map, contain_root); + ScriptCompleter script_completer(&buffer_var_map); // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index f4b1d87969..67f1a83825 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -802,7 +802,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->buffer->name << " = match_buffer_region("; + p->stream << op->buffer->name << " = match_buffer("; p->Print(op->source); p->stream << ")\n"; }); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 84b9af42af..8b2f5a2e4c 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -271,6 +271,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_pipeline_consumer_wait) TIR_DEFINE_BUILTIN_FUNC(tvm_pipeline_consumer_release) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(tvm_mfma_sync) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index b18c7c450f..255a30dc4b 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -225,6 +225,28 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); } +Array> +ConcreteScheduleNode::SampleShapeGenericTiles(const Array& loop_rvs, + const std::vector& ns, + const Target& target, + int max_innermost_factor, + Optional>> decision) { + TVM_TIR_SCHEDULE_BEGIN(); + Array stmt_srefs; + for (const LoopRV& loop_rv : loop_rvs) { + stmt_srefs.push_back(GetSRef(loop_rv)); + } + std::vector> result = + tir::SampleShapeGenericTiles(state_, &this->rand_state_, stmt_srefs, ns, target, max_innermost_factor, + &decision); + Array> result_rvs; + for (const std::vector& sampled : result) { + result_rvs.push_back(CreateRV(sampled)); + } + return result_rvs; + TVM_TIR_SCHEDULE_END("sample-shape-generic-tile", this->error_render_level_); +} + ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index a3bac8558d..1b76f51864 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -91,6 +91,11 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Sampling ********/ Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; + Array> SampleShapeGenericTiles(const Array& loop_rvs, + const std::vector& ns, + const Target& target, + int max_innermost_factor, + Optional>> decision = NullOpt) override; ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = 
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index a3bac8558d..1b76f51864 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -91,6 +91,11 @@ class ConcreteScheduleNode : public ScheduleNode {
   /******** Schedule: Sampling ********/
   Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
                                   Optional<Array<Integer>> decision = NullOpt) override;
+  Array<Array<ExprRV>> SampleShapeGenericTiles(
+      const Array<LoopRV>& loop_rvs, const std::vector<int>& ns, const Target& target,
+      int max_innermost_factor, Optional<Array<Array<Integer>>> decision = NullOpt) override;
   ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
                            Optional<Integer> decision = NullOpt) override;
   LoopRV SampleComputeLocation(const BlockRV& block_rv,
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 6ef25889c4..9bdea788d0 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -128,6 +128,13 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, tir::TRandState* rand_state,
                                                const tir::StmtSRef& loop_sref, int n,
                                                int max_innermost_factor,
                                                Optional<Array<Integer>>* decision);
+TVM_DLL std::vector<std::vector<int64_t>> SampleShapeGenericTiles(
+    tir::ScheduleState self, tir::TRandState* rand_state, const Array<tir::StmtSRef>& loop_srefs,
+    const std::vector<int>& ns, const Target& target, int max_innermost_factor,
+    Optional<Array<Array<Integer>>>* decision);
 TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, tir::TRandState* rand_state,
                                   const Array<Integer>& candidates, const Array<FloatImm>& probs,
                                   Optional<Integer>* decision);
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index 71cdcb9653..629afef313 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -375,13 +375,28 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
       bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); });
   bool inner = UsesVar(
       bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); });
+  bool is_var = bindings[i]->IsInstance<VarNode>();
   if (outer && !inner) {
-    arith::IterMark outer(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent);
+    arith::IterMark outer{nullptr};
+    if (is_var) {
+      outer = arith::IterMark(
+          arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)),
+          iter_vars[i]->dom->extent);
+    } else {
+      outer = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent);
+    }
     arith::IterMark inner(arith::IterSumExpr({}, 0), 1);
     res.push_back(Array<arith::IterMark>({outer, inner}));
   } else if (inner && !outer) {
+    arith::IterMark inner{nullptr};
+    if (is_var) {
+      inner = arith::IterMark(
+          arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)),
+          iter_vars[i]->dom->extent);
+    } else {
+      inner = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent);
+    }
     arith::IterMark outer(arith::IterSumExpr({}, 0), 1);
-    arith::IterMark inner(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent);
     res.push_back(Array<arith::IterMark>({outer, inner}));
   } else if (!outer && !inner) {
     arith::IterMark outer(arith::IterSumExpr({}, 0), 1);
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index f90370714e..f7d7a02201 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -409,7 +409,7 @@ class BaseInliner : public StmtExprMutator {
     Array<BufferRegion> reads = std::move(block->reads);
     Array<BufferRegion> writes = std::move(block->writes);
     if (!is_scope_root) {
-      Array<Array<BufferRegion>> inspected = GetBlockAccessRegion(block, buffer_var_map_);
+      Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_);
       reads = std::move(inspected[0]);
       writes = std::move(inspected[1]);
     }
diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc
index 02a16008c9..6f6490798f 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -606,6 +606,29 @@ std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, TRandState* rand_state,
   return result;
 }
 
+std::vector<std::vector<int64_t>> SampleShapeGenericTiles(
+    tir::ScheduleState self, TRandState* rand_state, const Array<StmtSRef>& loop_srefs,
+    const std::vector<int>& ns, const Target& target, int max_innermost_factor,
+    Optional<Array<Array<Integer>>>* decision) {
+  std::vector<int64_t> extents;
+  for (const StmtSRef& loop_sref : loop_srefs) {
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    extents.push_back(GetLoopIntExtent(loop));
+  }
+  std::vector<std::vector<int64_t>> sampled_tiles =
+      SampleShapeGenericTiles(rand_state, ns, extents, target, max_innermost_factor);
+  std::vector<std::vector<int64_t>> result;
+  *decision = Array<Array<Integer>>();
+  for (const std::vector<int64_t>& sampled : sampled_tiles) {
+    result.emplace_back(sampled.begin(), sampled.end());
+    decision->value().push_back(AsArray(result.back()));
+  }
+  return result;
+}
+
 int64_t SampleCategorical(tir::ScheduleState self, TRandState* rand_state,
                           const Array<Integer>& candidates, const Array<FloatImm>& probs,
                           Optional<Integer>* decision) {
@@ -701,6 +724,43 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits<SamplePerfectTileTraits> {
 
+struct SampleShapeGenericTilesTraits : public UnpackedInstTraits<SampleShapeGenericTilesTraits> {
+  static constexpr const char* kName = "SampleShapeGenericTiles";
+  static constexpr bool kIsPure = true;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 3;
+  static constexpr size_t kNumDecisions = 1;
+
+  static Array<Array<ExprRV>> UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
+                                                      Array<Integer> ns, Target target,
+                                                      Integer max_innermost_factor,
+                                                      Optional<Array<Array<Integer>>> decision) {
+    std::vector<int> n_splits;
+    for (const Integer& n : ns) {
+      n_splits.push_back(n->value);
+    }
+    return sch->SampleShapeGenericTiles(loop_rvs, n_splits, target, max_innermost_factor->value,
+                                        decision);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs, Array<Integer> ns,
+                                 Target target, Integer max_innermost_factor,
+                                 Optional<Array<Array<Integer>>> decision) {
+    PythonAPICall py("sample_shape_generic_tiles");
+    py.Input("loops", loop_rvs);
+    py.Input("ns", ns);
+    py.Input("max_innermost_factor", max_innermost_factor->value);
+    py.Input("target", target->str());
+    py.Decision(decision);
+    py.OutputList(outputs);
+    return py.Str();
+  }
+
+  friend struct UnpackedInstTraits;
+};
+
 struct SampleCategoricalTraits : public UnpackedInstTraits<SampleCategoricalTraits> {
   static constexpr const char* kName = "SampleCategorical";
   static constexpr bool kIsPure = true;
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 9f7ad74baa..e6549c89c4 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -63,6 +63,31 @@ Array<ExprRV> TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
   return results;
 }
 
+Array<Array<ExprRV>> TracedScheduleNode::SampleShapeGenericTiles(
+    const Array<LoopRV>& loop_rvs, const std::vector<int>& ns, const Target& target,
+    int max_innermost_factor, Optional<Array<Array<Integer>>> decision) {
+  Array<StmtSRef> stmt_srefs;
+  for (const LoopRV& loop_rv : loop_rvs) {
+    stmt_srefs.push_back(GetSRef(loop_rv));
+  }
+  std::vector<std::vector<int64_t>> result = tir::SampleShapeGenericTiles(
+      state_, &this->rand_state_, stmt_srefs, ns, target, max_innermost_factor, &decision);
+  Array<Array<ExprRV>> result_rvs;
+  for (const std::vector<int64_t>& sampled : result) {
+    result_rvs.push_back(CreateRV(sampled));
+  }
+  static const InstructionKind& kind = InstructionKind::Get("SampleShapeGenericTiles");
+  trace_->Append(Instruction(kind,
+                             {loop_rvs},
+                             {AsArray(ns), target, Integer(max_innermost_factor)},
+                             {result_rvs.begin(), result_rvs.end()}),
+                 decision);
+  return result_rvs;
+}
+
 ExprRV TracedScheduleNode::SampleCategorical(const Array<Integer>& candidates,
                                              const Array<FloatImm>& probs,
                                              Optional<Integer> decision) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 8b69700254..4c40783f71 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -53,6 +53,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   /******** Schedule: Sampling ********/
   Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
                                   Optional<Array<Integer>> decision = NullOpt) final;
+  Array<Array<ExprRV>> SampleShapeGenericTiles(
+      const Array<LoopRV>& loop_rvs, const std::vector<int>& ns, const Target& target,
+      int max_innermost_factor, Optional<Array<Array<Integer>>> decision = NullOpt) final;
   ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
                            Optional<Integer> decision = NullOpt) final;
   LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional<Integer> decision = NullOpt) final;
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index d030fdddab..e926bbaf8c 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -39,7 +39,7 @@ namespace tir {
 using namespace support;
 
 /*!
- * \brief return the region collected by NDIntSet. return the oroginal buffer shape if the
+ * \brief return the region collected by NDIntSet. return the original buffer shape if the
  * int_set is empty.
 */
 Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc
index 4c5e1dd512..1ca16780e9 100644
--- a/src/tir/transforms/convert_blocks_to_opaque.cc
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -78,7 +78,7 @@ class OpaqueBlockConverter : public StmtExprMutator {
     return std::move(new_realize);
   }
 
-  /*! \brief The map from block vars to thier binding values. */
+  /*! \brief The map from block vars to their binding values. */
   std::unordered_map<const VarNode*, PrimExpr> var_substitutes_;
 };
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index 87a7681729..acd1a2f0c3 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -99,7 +99,7 @@ class BufferFlattener : public StmtExprMutator {
     for (const auto& annotation : op->annotations) {
       const String& ann_key = annotation.first;
       const ObjectRef& ann_value = annotation.second;
-      if (attr::IsPragmaKey(ann_key)) {
+      if (attr::IsPragmaKey(ann_key) || ann_key == attr::pipeline_scope) {
         body = AttrStmt(op->loop_var, ann_key, Downcast<PrimExpr>(ann_value), std::move(body));
       }
     }
@@ -108,7 +108,13 @@ class BufferFlattener : public StmtExprMutator {
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    return store->buffer.vstore(store->indices, store->value);
+    Array<PrimExpr> indices = store->indices;
+    if (indices.size()) {
+      if (const auto* ramp = indices.back().as<RampNode>()) {
+        indices.Set(indices.size() - 1, ramp->base);
+      }
+    }
+    return store->buffer.vstore(indices, store->value);
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
@@ -127,7 +133,15 @@ class BufferFlattener : public StmtExprMutator {
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    return load->buffer.vload(load->indices, load->dtype);
+    Array<PrimExpr> indices = load->indices;
+    DataType dtype = load->dtype;
+    if (indices.size()) {
+      if (const auto* ramp = indices.back().as<RampNode>()) {
+        dtype = dtype.with_lanes(ramp->lanes);
+        indices.Set(indices.size() - 1, ramp->base);
+      }
+    }
+    return load->buffer.vload(indices, dtype);
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
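The FlattenBuffer hunks above peel a trailing Ramp index into its base and fold the lanes into the dtype, so vectorized block-level accesses survive flattening. A plain-arithmetic illustration of that rewrite:

```python
# Before: B[i, tir.ramp(j, 1, 4)] on a (16, 128) float32 buffer (4 lanes).
# After flattening, only the ramp base enters the flat index:
#   flat_index = i * 128 + j
#   load dtype = "float32x4"   # lanes moved from the index into the dtype
```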
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index e873524250..562699efb3 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -60,69 +60,6 @@ using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
 template <class K>
 using SSet = std::unordered_set<K, ObjectPtrHash, ObjectPtrEqual>;
 
-/*! \brief Information about a software pipeline that will be used in the transformation */
-struct PipelineInfo {
-  // Buffers written by the producers. These buffers can only be read by the consumers.
-  SSet<Var> producer_buffers;
-  // Producers of the pipeline.
-  Array<Stmt> producers;
-  // Consumers of the pipeline.
-  Array<Stmt> consumers;
-  // Storage scope of the pipeline. The scope is the same as the storage scope of the producer
-  // buffers. Producer buffers are required to have the same storage scope.
-  String scope;
-  // Number of stages of the pipeline.
-  Integer num_stages;
-  // The loop variable of the pipelined loop.
-  Var loop_var;
-  // Buffer allocations that need to be relocated outside of the pipeline after the transformation.
-  Array<Var> buffer_allocs;
-
-  PipelineInfo(const SSet<Var>& producer_buffers, const Array<Stmt>& producers,
-               const Array<Stmt>& consumers, const String& scope, const Integer& num_stages,
-               const Var& loop_var, const Array<Var>& buffer_allocs)
-      : producer_buffers(producer_buffers),
-        producers(producers),
-        consumers(consumers),
-        scope(scope),
-        num_stages(num_stages),
-        loop_var(loop_var),
-        buffer_allocs(buffer_allocs) {}
-};
-
-/* \brief Information about a buffer allocation.
- * \note In TIR, a buffer allocation is consist of one or more AttrStmt followed by Allocate.
- * This structure holds reference of these statements so that it can be used to rebuild the buffer
- * allocation during the software pipeline transformaion.
- */
-struct BufferInfo {
-  // The first AttrStmt related to the buffer.
-  Stmt annotation;
-  // The Allocate statement of the buffer.
-  Allocate allocate;
-  // The storage scope of the buffer.
-  String scope;
-  BufferInfo(const Stmt& annotation, const Allocate& allocate, const String& scope)
-      : annotation(annotation), allocate(allocate), scope(scope) {}
-};
-
-/*!
- * \brief Strips AttrStmt of the buffer and get the closest nested Allocate.
- * \param attr_node The AttrStmt related to the buffer.
- */
-static Allocate GetBufferAllocate(const AttrStmtNode* attr_node) {
-  while (attr_node) {
-    ICHECK(attr_node->attr_key == tir::attr::storage_scope ||
-           attr_node->attr_key == tir::attr::double_buffer_scope);
-    if (attr_node->body.as<AllocateNode>()) {
-      return Downcast<Allocate>(attr_node->body);
-    }
-    attr_node = attr_node->body.as<AttrStmtNode>();
-  }
-  ICHECK(false) << "unreachable";
-  throw;
-}
-
 struct BufferAccess {
   // Buffer variables being written.
   SSet<Var> writes;
@@ -136,22 +73,12 @@ struct BufferAccess {
 BufferAccess GetBufferAccess(const Stmt& stmt) {
   BufferAccess access;
   PreOrderVisit(stmt, [&access](const ObjectRef& obj) {
-    if (const auto* store = obj.as<BufferStoreNode>()) {
-      access.writes.insert(store->buffer_var);
-    } else if (const auto* load = obj.as<BufferLoadNode>()) {
-      access.reads.insert(load->buffer_var);
-    } else if (const auto* call = obj.as<CallNode>()) {
-      if (call->op.same_as(builtin::tvm_access_ptr())) {
-        ICHECK(call->args.size() == 5U);
-        Var buffer_var = Downcast<Var>(call->args[1]);
-        int64_t rw_mask = Downcast<IntImm>(call->args[4])->value;
-        if (rw_mask & 1) {
-          access.reads.insert(buffer_var);
-        }
-        if (rw_mask & 2) {
-          access.writes.insert(buffer_var);
-        }
-        return false;
+    if (const auto* block = obj.as<BlockNode>()) {
+      for (const auto& read : block->reads) {
+        access.reads.insert(read->buffer->data);
+      }
+      for (const auto& write : block->writes) {
+        access.writes.insert(write->buffer->data);
       }
     }
     return true;
@@ -159,14 +86,72 @@ BufferAccess GetBufferAccess(const Stmt& stmt) {
   return access;
 }
 
+struct PipelineBufferInfo {
+  Buffer new_buffer;
+  Var loop_var;
+  PipelineBufferInfo(Buffer new_buffer, Var loop_var)
+      : new_buffer(std::move(new_buffer)), loop_var(std::move(loop_var)) {}
+};
+
 /*!
- * \brief Detect the annotated pipeline loop and generate information that will be used for the
- * software pipeline transformation later.
+ * \brief Transform annotated loops into software pipelines.
+ *
+ * Given a for-loop annotated with pipeline_scope, this pass does the following transformation.
+ *
+ * Input:
+ * \code
+ * for ax in range(min, min + extent, annotations={pipeline_scope: num_stages}):
+ *   buffer allocations;
+ *   producers(ax);  // producers(ax) denotes the ax-th iteration of the producers
+ *   consumers(ax);  // consumers(ax) denotes the ax-th iteration of the consumers
+ * \endcode
+ *
+ * Output:
+ * \code
+ *
+ * buffer allocations;
+ *
+ * // prologue
+ * for ax in range(min, min + shift_extent):
+ *   producers(ax);
+ *
+ * // main loop
+ * for ax in range(min, min + extent + shift_extent, annotations={pipeline_scope: 1}):
+ *   producers(ax + shift_extent);
+ *   consumers(ax);
+ *
+ * // epilogue
+ * for ax in range(min, min + shift_extent):
+ *   consumers(ax + extent - shift_extent);
+ *
+ * where shift_extent = num_stages - 1
+ * \endcode
+ *
+ * Synchronizations and native pipeline API calls are inserted if needed. The main loop is
+ * annotated with AttrStmt so that the `ThreadStorageSync` pass will skip it, which prevents
+ * unnecessary synchronizations from being inserted.
+ *
+ * Since producers are executed ahead of the consumers by `shift_extent` iterations, buffers
+ * written by the producers need to be enlarged `num_stages` times. During iterations, results of
+ * the producers up to `num_stages` iterations will be kept in the buffer. This reduces the
+ * synchronizations needed between the producers and the consumers so that they can be executed
+ * concurrently.
+ */
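A concrete trace of the transformation described above, for num_stages = 3 (so shift_extent = 2) and a loop extent of 5; illustrative only:

```python
# prologue : P0, P1
# main loop: (P2, C0), (P3, C1), (P4, C2)   # producers run 2 iterations ahead
# epilogue : C3, C4
# At most num_stages = 3 producer results are live at once, which is why the
# producer buffers get a leading dimension of size num_stages.
```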
*/ -class PipelineDetector : public StmtVisitor { +class PipelineInjector : public StmtExprMutator { public: - SMap buffer_info_; - SMap pipeline_info_; + static Stmt Inject(bool use_native_pipeline, const PrimFunc& func) { + PipelineInjector injector(use_native_pipeline, func); + Stmt new_stmt = injector(func->body); + return ConvertSSA(new_stmt); + } + + PipelineInjector(bool use_native_pipeline, const PrimFunc& func) : use_native_pipeline_(use_native_pipeline) { + DetectNativePipeline(); + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + } private: /*! @@ -198,250 +183,146 @@ class PipelineDetector : public StmtVisitor { } } - /*! - * \brief Make the plan for the pipeline transformation for a AST subtree. - * \param pipeline_scope The AttrStmt that annotates the for-loop for software pipelining. - * - * This function analyzes the dependencies among the children of the software pipelined for-loop, - * generates and stores the information of the pipeline in `pipeline_info_`. - */ - - void PlanPipeline(const AttrStmtNode* pipeline_scope) { - CHECK(current_pipeline_scope_ == nullptr) << "ValueError: Nested pipeline is not allowed."; - current_pipeline_scope_ = pipeline_scope; - StmtVisitor::VisitStmt_(pipeline_scope); - current_pipeline_scope_ = nullptr; - - Integer num_stages = Downcast(pipeline_scope->value); - CHECK_GE(num_stages->value, 2) << "ValueError: Pipeline should have at least two stages."; - - const ForNode* op = TVM_TYPE_AS(op, pipeline_scope->body, ForNode); - // The body of the annotated pipeline for-loop should be optional buffer allocations followed by - // SeqStmt. - Array buffer_allocs; - Stmt stmt = GetRef(op); - const auto* attr_node = op->body.as(); - while (attr_node) { - Allocate alloc = GetBufferAllocate(attr_node); - buffer_allocs.push_back(alloc->buffer_var); - stmt = alloc->body; - attr_node = stmt.as(); - } - const SeqStmtNode* body = stmt.as(); - CHECK(body) << "ValueError: The body of the pipeline should be SeqStmt."; - + std::pair, Array> GetPipelineProducerConsumers(const SeqStmt& seq) { // Build the dependency graph from buffer accesses. - // A map from a Stmt to its buffer access info. SMap buffer_access; // A map from a Stmt to its dependants. SMap> dep_src2dst; // A map from a Stmt to its dependencies. SMap> dep_dst2src; - BuildDependencyGraph(body, &buffer_access, &dep_src2dst, &dep_dst2src); + BuildDependencyGraph(seq.get(), &buffer_access, &dep_src2dst, &dep_dst2src); // analyze dependencies among direct children of the pipeline loop Array producers, consumers; - for (const auto& stmt : body->seq) { + for (const auto& stmt : seq->seq) { if (!dep_src2dst.count(stmt)) { consumers.push_back(stmt); } else { producers.push_back(stmt); } } - // Find buffers that are written by producers and read by consumers. - // These buffers need to be resized. - SSet producer_buffers; - for (const Stmt& consumer : consumers) { - for (const Stmt& dependency : dep_dst2src[consumer]) { - for (const Var& read : buffer_access.at(consumer).reads) { - if (buffer_access.at(dependency).writes.count(read)) { - producer_buffers.insert(read); - } - } - } - } - - CHECK(!producers.empty()) << "ValueError: Producer not found in the pipeline."; - CHECK(!consumers.empty()) << "ValueError: Consumer not found in the pipeline."; - CHECK(!producer_buffers.empty()) << "ValueError: Producer buffer not found in the pipeline."; - - // Check the consistency of pipeline scope.
- String scope = buffer_info_.at(*producer_buffers.begin()).scope; - for (const Var& buffer : producer_buffers) { - CHECK_EQ(buffer_info_.at(buffer).scope, scope) << "ValueError: Inconsistent scopes among " - "buffers of pipeline producers"; - } - pipeline_info_.emplace(GetRef(pipeline_scope), - PipelineInfo{producer_buffers, producers, consumers, scope, num_stages, - op->loop_var, buffer_allocs}); + return {producers, consumers}; } - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::pipeline_scope) { - PlanPipeline(op); - return; - } - - StmtVisitor::VisitStmt_(op); - AttrStmt attr_stmt = GetRef(op); - if (op->attr_key == tir::attr::storage_scope) { - Allocate allocate = Downcast(op->body); - buffer_info_.emplace(allocate->buffer_var, - BufferInfo{attr_stmt, allocate, Downcast(op->value)->value}); - } else if (op->attr_key == tir::attr::double_buffer_scope) { - buffer_info_.at(Downcast(op->node)).annotation = attr_stmt; + Buffer RewriteAllocBuffer(const Buffer& buffer, int num_stages) { + ObjectPtr new_buffer = make_object(*(buffer.get())); + new_buffer->shape.insert(new_buffer->shape.begin(), num_stages); + if (new_buffer->strides.size()) { + PrimExpr stride_0 = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), new_buffer->strides); + new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); } + return Buffer(new_buffer); } - const AttrStmtNode* current_pipeline_scope_ = nullptr; -}; - -/*! - * \brief Use the pipeline information produced by PipelineDetector to transform the IR. - * - * Given a for-loop annotated with pipeline_scope, this pass does the following transformation. - * - * Input: - * \code - * AttrStmt(pipeline_scope, num_stages) - * for ax in range(min, min + extent): - * buffer allocations; - * producers(ax); // producers(ax) denotes ax-th iteration of the producers - * consumers(ax); // consumers(ax) denotes ax-th iteration of the consumers - * \endcode - * - * Output: - * \code - * - * buffer allocations; - * - * // prologue - * for ax in range(min, min + shift_extent): - * producers(ax); - * - * // main loop - * AttrStmt(pipeline_scope, 1) - * for ax in range(min, min + extent + shift_extent): - * producers(ax + shift_extent); - * consumers(ax); - * - * // epilogue - * for ax in range(min, min + shift_extent): - * consumers(ax + extent - shift_extent); - * - * where shift_extent = num_stages - 1 - * \endcode - * - * Synchronizatons and native pipeline API calls are inserted if needed. The main loop is annotated - * with AttrStmt so that `ThreadStorageSync` pass will skip this loop which prevents unnecessary - * synchronizations being inserted. - * - * Since producers are executed ahead of the consumers by `shift_extent` iterations, buffers written - * by the producers need to be enlarged by `num_stages` times. During iterations, results of the - * producers up to `num_stages` iterations will be kept in the buffer. This reduces synchronizations - * needed between the producers and the consumers so that they can be executed concurrently. 
- */ -class PipelineInjector : public StmtExprMutator { - public: - static Stmt Inject(bool use_native_pipeline, const Stmt& stmt) { - PipelineDetector detector; - detector(stmt); - PipelineInjector injector(use_native_pipeline, detector.pipeline_info_, detector.buffer_info_); - Stmt new_stmt = injector(stmt); - return ConvertSSA(new_stmt); - } - - PipelineInjector(bool use_native_pipeline, const SMap& pipeline_info, - const SMap& buffer_info) - : pipeline_info_(pipeline_info), - buffer_info_(buffer_info), - use_native_pipeline_(use_native_pipeline) { - DetectNativePipeline(); - for (const auto& kv : pipeline_info_) { - for (const auto& buffer : kv.second.producer_buffers) { - skip_allocs.emplace(buffer_info_.at(buffer).annotation); - } - } - } - - private: - Stmt BuildPipeline(const AttrStmt& pipeline_scope) { - const PipelineInfo* pipeline_info = &pipeline_info_.at(pipeline_scope); - std::swap(pipeline_info, current_pipeline_); - - For pipeline_loop = Downcast(pipeline_scope->body); - PrimExpr shift_extent = Integer(current_pipeline_->num_stages->value - 1); + Stmt RewritePipelineBody(Stmt stmt, const For& pipeline_loop, int num_stages, + const String& scope) { + Array producers, consumers; + CHECK(stmt->IsInstance()) + << "ValueError: The body of the pipeline should be SeqStmt."; + std::tie(producers, consumers) = GetPipelineProducerConsumers(Downcast(stmt)); + CHECK(!producers.empty()) << "ValueError: Producer not found in the pipeline."; + CHECK(!consumers.empty()) << "ValueError: Consumer not found in the pipeline."; + PrimExpr shift_extent = Integer(num_stages - 1); // Step 1: Initialize pipeline_var for native pipeline, which will be used in the native // pipeline API calls - if (use_native_pipeline_) { - pipeline_var_ = Var("pipeline", DataType::Handle()); + bool use_native_pipeline = use_native_pipeline_ && scope == "shared"; + if (use_native_pipeline) { + CHECK(!pipeline_var_.defined()) << "ValueError: Nested native pipeline not supported."; + pipeline_var_ = Var("pipeline", PrimType(DataType::Handle())); } // Step 2: Mutate children to rewrite pipeline buffer access. 
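RewriteAllocBuffer above and the buffer access rewriting in Step 2 work as a pair: the allocation gains a leading dimension of num_stages (plus a matching leading stride when the buffer has explicit strides), and every access to the buffer gains a leading index of loop_var % num_stages. A rough NumPy analogy (hypothetical shapes and values, not the TVM API):

import numpy as np

num_stages = 2
# RewriteAllocBuffer: alloc_buffer([16, 16]) effectively becomes alloc_buffer([2, 16, 16])
buf = np.empty((num_stages, 16, 16), dtype="float16")

# Access rewrite: buf[i, j] inside iteration `ax` becomes buf[ax % num_stages, i, j],
# so the producer working on iteration ax + 1 writes a different slice than the one
# the consumer of iteration ax is still reading.
for ax in range(4):
    buf[ax % num_stages, :, :] = float(ax)          # producer writes its slot
    assert buf[ax % num_stages, 0, 0] == float(ax)  # consumer reads the same slot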
- Array producers, consumers; - for (const auto& stmt : current_pipeline_->producers) { - producers.push_back(VisitStmt(stmt)); - } - for (const auto& stmt : current_pipeline_->consumers) { - consumers.push_back(VisitStmt(stmt)); - } + producers.MutateByApply(std::bind(&PipelineInjector::VisitStmt, this, std::placeholders::_1)); + consumers.MutateByApply(std::bind(&PipelineInjector::VisitStmt, this, std::placeholders::_1)); // Step 3: Build each part of the pipeline - Stmt prologue = BuildPrologue(producers, pipeline_loop, shift_extent); - Stmt epilogue = BuildEpilogue(consumers, pipeline_loop, shift_extent); - Stmt main_loop = BuildMainLoop(producers, consumers, pipeline_loop, shift_extent); - // Annotate the main loop so that thread_storage_sync will skip this part - main_loop = AttrStmt(Stmt(), tir::attr::pipeline_scope, Integer(1), main_loop); + Stmt prologue = BuildPrologue(producers, pipeline_loop, shift_extent, use_native_pipeline); + Stmt epilogue = + BuildEpilogue(consumers, pipeline_loop, shift_extent, scope, use_native_pipeline); + Stmt main_loop = BuildMainLoop(producers, consumers, pipeline_loop, shift_extent, num_stages, + scope, use_native_pipeline); Array pipeline_seq; - if (use_native_pipeline_) { + if (use_native_pipeline) { pipeline_seq = {prologue, main_loop, epilogue}; } else { - pipeline_seq = {prologue, GetPipelineSync(), main_loop, epilogue}; + pipeline_seq = {prologue, GetPipelineSync(scope), main_loop, epilogue}; } - Stmt pipeline = SeqStmt(pipeline_seq); + Stmt pipeline = SeqStmt::Flatten(pipeline_seq); // Step 4: Create the native pipeline object if necessary - if (use_native_pipeline_) { + if (use_native_pipeline) { PrimExpr create_pipeline = Call(DataType::Handle(), builtin::tvm_create_pipeline(), {}); pipeline = LetStmt(pipeline_var_.value(), create_pipeline, pipeline); + pipeline_var_ = NullOpt; } - // Step 5: Add buffer allocation - std::vector allocs; - Stmt no_op = Evaluate(0); - for (const Var& buffer_var : current_pipeline_->buffer_allocs) { - Stmt stmt = buffer_info_.at(buffer_var).annotation; - while (const auto* attr_node = stmt.as()) { - allocs.push_back(AttrStmt(attr_node->node, attr_node->attr_key, attr_node->value, no_op)); - stmt = attr_node->body; - } - const auto* alloc_node = TVM_TYPE_AS(alloc_node, stmt, AllocateNode); - if (current_pipeline_->producer_buffers.count(buffer_var)) { - ICHECK(alloc_node->extents.size() == 1U); - PrimExpr new_extent = alloc_node->extents[0] * current_pipeline_->num_stages; - Allocate new_alloc(alloc_node->buffer_var, alloc_node->dtype, {new_extent}, - alloc_node->condition, no_op); - allocs.push_back(new_alloc); - } else { - Allocate new_alloc(alloc_node->buffer_var, alloc_node->dtype, alloc_node->extents, - alloc_node->condition, no_op); - allocs.push_back(new_alloc); - } + return pipeline; + } + + String GetPipelineScope(const Array& producer_buffers) { + CHECK(producer_buffers.size()) << "ValueError: Cannot find producer buffers."; + String scope = GetPtrStorageScope(producer_buffers[0]->data); + for (size_t i = 1; i < producer_buffers.size(); i++) { + String new_scope = GetPtrStorageScope(producer_buffers[i]->data); + CHECK_EQ(scope, new_scope) << "ValueError: Inconsistent storage scopes of producer buffers " + "of the software pipeline (" + << scope << " vs. 
" << new_scope << ")."; } + return scope; + } + + Stmt InjectPipeline(const ForNode* op) { + // Get and check annotation + Integer num_stages = Downcast(op->annotations.Get(attr::pipeline_scope).value()); + CHECK_GE(num_stages->value, 2) << "ValueError: Pipeline should have at least two stages."; - std::swap(pipeline_info, current_pipeline_); - pipeline_var_ = NullOpt; - return MergeNest(allocs, pipeline); + // Clear the pipeline annotation + For pipeline_loop = GetRef(op); + auto* pipeline_loop_node = pipeline_loop.CopyOnWrite(); + pipeline_loop_node->annotations.erase(attr::pipeline_scope); + + // Resize producer buffers for pipelined accesses + CHECK(pipeline_loop->body->IsInstance()) + << "ValueError: Cannot find buffer allocations inside the pipeline scope."; + + BlockRealize block_realize = Downcast(pipeline_loop->body); + String scope = GetPipelineScope(block_realize->block->alloc_buffers); + Array new_alloc_buffers; + for (const Buffer& alloc_buffer : block_realize->block->alloc_buffers) { + Buffer new_buffer = RewriteAllocBuffer(alloc_buffer, num_stages); + new_alloc_buffers.push_back(new_buffer); + buffer_map_.emplace(alloc_buffer, PipelineBufferInfo(new_buffer, op->loop_var)); + // buffer_data_to_buffer_.Set(new_buffer->data, new_buffer); + } + + CHECK(is_one(block_realize->predicate)) + << "ValueError: The body block of the software pipeline can not have predicates."; + CHECK(block_realize->block->match_buffers.empty()) << "ValueError: Pipeline body with match_buffer is not supported."; + + // Rewrite pipeline body + Stmt pipeline_body = + RewritePipelineBody(block_realize->block->body, pipeline_loop, num_stages, scope); + + auto new_block = Block({}, {}, {}, "", pipeline_body, NullOpt, new_alloc_buffers); + auto access = GetBlockReadWriteRegion(new_block, buffer_data_to_buffer_); + auto* new_block_ptr = new_block.CopyOnWrite(); + new_block_ptr->reads = access[0]; + new_block_ptr->writes = access[1]; + return BlockRealize({}, Bool(true), std::move(new_block)); } - Stmt GetPipelineSync() { - return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), - Array{StringImm(current_pipeline_->scope)})); + Stmt GetPipelineSync(String scope) { + return Evaluate( + Call(DataType::Int(32), builtin::tvm_storage_sync(), Array{StringImm(scope)})); } + Map buffer_data_to_buffer_; + std::unordered_map buffer_map_; + /*! * \brief Wrap a producer statement with native pipeline API calls. 
* @@ -487,21 +368,21 @@ class PipelineInjector : public StmtExprMutator { * tvm_pipeline_consumer_commit(pipeline); * \endcode */ - Stmt WrapNativeConsumer(const Stmt& consumer) { + Stmt WrapNativeConsumer(const Stmt& consumer, const String& scope) { ICHECK(use_native_pipeline_); ICHECK(pipeline_var_.defined()); Stmt consumer_wait = Evaluate( Call(DataType::Handle(), builtin::tvm_pipeline_consumer_wait(), {pipeline_var_.value()})); Stmt consumer_release = Evaluate( Call(DataType::Handle(), builtin::tvm_pipeline_consumer_release(), {pipeline_var_.value()})); - Stmt storage_sync = GetPipelineSync(); + Stmt storage_sync = GetPipelineSync(scope); return SeqStmt::Flatten(consumer_wait, storage_sync, consumer, consumer_release); } - Stmt BuildPrologue(const Array& producers, For pipeline_loop, - const PrimExpr& shift_extent) { + Stmt BuildPrologue(const Array& producers, For pipeline_loop, const PrimExpr& shift_extent, + bool use_native_pipeline) { Stmt producer = SeqStmt::Flatten(producers); - if (use_native_pipeline_) { + if (use_native_pipeline) { producer = WrapNativeProducer(producer); } PrimExpr new_loop_var = @@ -519,11 +400,11 @@ class PipelineInjector : public StmtExprMutator { } } - Stmt BuildEpilogue(const Array& consumers, For pipeline_loop, - const PrimExpr& shift_extent) { + Stmt BuildEpilogue(const Array& consumers, For pipeline_loop, const PrimExpr& shift_extent, + const String& scope, bool use_native_pipeline) { Stmt consumer = SeqStmt::Flatten(consumers); - if (use_native_pipeline_) { - consumer = WrapNativeConsumer(consumer); + if (use_native_pipeline) { + consumer = WrapNativeConsumer(consumer, scope); } PrimExpr new_loop_var = is_one(shift_extent) ? pipeline_loop->min : pipeline_loop->loop_var.copy_with_suffix(""); @@ -541,19 +422,20 @@ class PipelineInjector : public StmtExprMutator { } } - Stmt ScheduleMainLoop(const Array& producers, const Array& consumers) { + Stmt ScheduleMainLoop(const Array& producers, const Array& consumers, int num_stages, + const String& scope, bool use_native_pipeline) { // Schedule the execution of producers and consumers. Producers and consumers are assumed to be // independent and can be executed concurrently. The schedule can be target-dependent. - Stmt storage_sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), - {StringImm(current_pipeline_->scope)})); + Stmt storage_sync = + Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(scope)})); // default case: run producers and consumers sequentially.
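The decision made here can be summarized by a small plain-Python sketch (hypothetical names, not the TVM API) of the statement order emitted for one main-loop iteration:

def schedule_main_loop(producer, consumer, num_stages, use_native_pipeline):
    """Sketch: statement order of one main-loop iteration."""
    stmts = [producer, consumer]
    # Without the native pipeline API, or with plain double buffering
    # (num_stages == 2), a storage sync ends each iteration so producer
    # writes are visible to, and cannot race with, the next iteration.
    if not use_native_pipeline or num_stages == 2:
        stmts.append("tvm_storage_sync(scope)")
    return stmts

assert schedule_main_loop("producers(ax + shift)", "consumers(ax)", 2, False) == [
    "producers(ax + shift)", "consumers(ax)", "tvm_storage_sync(scope)"]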
Stmt producer = SeqStmt::Flatten(producers); Stmt consumer = SeqStmt::Flatten(consumers); - if (use_native_pipeline_) { + if (use_native_pipeline) { producer = WrapNativeProducer(producer); - consumer = WrapNativeConsumer(consumer); + consumer = WrapNativeConsumer(consumer, scope); } - if (!use_native_pipeline_ || current_pipeline_->num_stages->value == 2) { + if (!use_native_pipeline_ || num_stages == 2) { return SeqStmt::Flatten(producer, consumer, storage_sync); } else { return SeqStmt::Flatten(producer, consumer); @@ -561,7 +443,8 @@ class PipelineInjector : public StmtExprMutator { } Stmt BuildMainLoop(const Array& producers, const Array& consumers, For pipeline_loop, - const PrimExpr& shift_extent) { + const PrimExpr& shift_extent, int num_stages, const String& scope, + bool use_native_pipeline) { ForNode* main_loop = pipeline_loop.CopyOnWrite(); main_loop->extent -= shift_extent; @@ -573,54 +456,63 @@ class PipelineInjector : public StmtExprMutator { Stmt shifted_producer = Substitute(producer, subst_map); shifted_producers.push_back(shifted_producer); } - main_loop->body = ScheduleMainLoop(shifted_producers, consumers); + main_loop->body = + ScheduleMainLoop(shifted_producers, consumers, num_stages, scope, use_native_pipeline); + // Annotate the main loop so that thread_storage_sync will skip this part + main_loop->annotations.Set(attr::pipeline_scope, Integer(1)); return pipeline_loop; } - Stmt VisitStmt_(const AttrStmtNode* op) { - // Skip allocate of pipeline buffers in the original TensorIR AST. These buffers should be - // allocated later outside the pipeline scope. - if (skip_allocs.count(GetRef(op))) { - Allocate alloc = GetBufferAllocate(op); - return VisitStmt(alloc->body); + Stmt VisitStmt_(const BufferStoreNode* op) { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_map_.find(op->buffer); + if (it != buffer_map_.end()) { + auto* n = store.CopyOnWrite(); + n->buffer = (*it).second.new_buffer; + n->indices.insert(n->indices.begin(), + indexmod(buffer_map_.at(op->buffer).loop_var, n->buffer->shape[0])); } - AttrStmt attr_stmt = GetRef(op); - if (pipeline_info_.count(attr_stmt)) { - Stmt new_stmt = BuildPipeline(attr_stmt); - return new_stmt; - } - return StmtExprMutator::VisitStmt_(op); + return store; } - /*! - * \brief Rewrite accesses to the producer buffers after they are resized for the pipeline. - * \param buffer_var The buffer variable. - * \param index The index of he buffer access. - * \return The updated index for accessing the resized buffer. 
- */ - PrimExpr RewriteProducerBufferAccess(const Var& buffer_var, const PrimExpr& index) { - const auto& extents = buffer_info_.at(buffer_var).allocate->extents; - ICHECK(extents.size() == 1U); - PrimExpr stride = extents[0]; - return indexmod(current_pipeline_->loop_var, current_pipeline_->num_stages) * stride + index; + PrimExpr VisitExpr_(const BufferLoadNode* op) { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_map_.find(op->buffer); + if (it != buffer_map_.end()) { + auto* n = load.CopyOnWrite(); + n->buffer = (*it).second.new_buffer; + n->indices.insert(n->indices.begin(), + indexmod(buffer_map_.at(op->buffer).loop_var, n->buffer->shape[0])); + } + return load; } - Stmt VisitStmt_(const StoreNode* op) { - Store store = Downcast(StmtExprMutator::VisitStmt_(op)); - if (current_pipeline_ && current_pipeline_->producer_buffers.count(store->buffer_var)) { - PrimExpr new_index = RewriteProducerBufferAccess(store->buffer_var, store->index); - store = Store(store->buffer_var, store->value, new_index, store->predicate); + BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) { + auto it = buffer_map_.find(buffer_region->buffer); + if (it != buffer_map_.end()) { + Region new_region = buffer_region->region; + new_region.insert(new_region.begin(), + Range::FromMinExtent(0, (*it).second.new_buffer->shape[0])); + return BufferRegion((*it).second.new_buffer, new_region); } - return store; + return buffer_region; } - PrimExpr VisitExpr_(const LoadNode* op) { - Load load = Downcast(StmtExprMutator::VisitExpr_(op)); - if (current_pipeline_ && current_pipeline_->producer_buffers.count(load->buffer_var)) { - PrimExpr new_index = RewriteProducerBufferAccess(load->buffer_var, load->index); - load = Load(load->dtype, load->buffer_var, new_index, load->predicate); + Stmt VisitStmt_(const BlockNode* op) { + for (const auto& buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); } - return load; + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + for (const auto& buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + auto* n = block.CopyOnWrite(); + n->reads.MutateByApply( + std::bind(&PipelineInjector::RewritePipelineBufferRegion, this, std::placeholders::_1)); + n->writes.MutateByApply( + std::bind(&PipelineInjector::RewritePipelineBufferRegion, this, std::placeholders::_1)); + + return std::move(block); } PrimExpr VisitExpr_(const CallNode* op) { @@ -629,20 +521,19 @@ class PipelineInjector : public StmtExprMutator { CHECK(pipeline_var_.defined()) << "ValueError: intrinsic tvm_get_pipeline can only be called inside the pipeline scope."; return pipeline_var_.value(); - } else if (call->op.same_as(builtin::tvm_access_ptr())) { - ICHECK(call->args.size() == 5U); - Var buffer_var = Downcast(call->args[1]); - if (current_pipeline_ && current_pipeline_->producer_buffers.count(buffer_var)) { - PrimExpr elem_offset = call->args[2]; - elem_offset = RewriteProducerBufferAccess(buffer_var, elem_offset); - Array new_args(call->args); - new_args.Set(2, elem_offset); - return Call(call->dtype, call->op, new_args); - } } return call; } + Stmt VisitStmt_(const ForNode* op) { + auto it = op->annotations.find(attr::pipeline_scope); + if (it != op->annotations.end()) { + return InjectPipeline(op); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + void DetectNativePipeline() { if (!use_native_pipeline_) { return; @@ -662,15 +553,6 @@ class PipelineInjector : public StmtExprMutator { } } - // 
Information of the current pipeline. - const PipelineInfo* current_pipeline_ = nullptr; - // A map from annotated pipeline statements to the information for the transformation. - const SMap& pipeline_info_; - // A map from buffer variables to their information. - const SMap& buffer_info_; - // Buffer allocations that need to be skipped as they will be regenerated by the pipeline - // transformation. - SSet skip_allocs; // Whether the native pipeline is enabled. bool use_native_pipeline_; // The pipeline object if native pipeline is enabled. @@ -693,7 +575,7 @@ Pass InjectSoftwarePipeline() { cfg = AttrsWithDefaultValues(); } fptr->body = inject_software_pipeline::PipelineInjector::Inject( - cfg.value()->use_native_pipeline, std::move(fptr->body)); + cfg.value()->use_native_pipeline, f); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index f35725c1e4..53bdcc26cb 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -23,6 +23,7 @@ */ #include "ir_utils.h" +#include #include #include @@ -210,5 +211,51 @@ String GetPtrStorageScope(Var buffer_var) { return ptr_type->storage_scope; } +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(indices.size(), target->shape.size()); + + arith::Analyzer analyzer; + Array result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - indices.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& range = source->region[i]; + ICHECK(analyzer.CanProve(range->extent == 1)); + result.push_back(range->min); + } + for (size_t i = 0; i < indices.size(); ++i) { + const Range& range = source->region[i + offset]; + const PrimExpr& index = indices[i]; + result.push_back(range->min + index); + } + return result; +} + +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(region.size(), target->shape.size()); + + arith::Analyzer analyzer; + Region result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - region.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& source_range = source->region[i]; + ICHECK(analyzer.CanProve(source_range->extent == 1)); + result.push_back(Range::FromMinExtent(source_range->min, 1)); + } + for (size_t i = 0; i < region.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const Range& target_range = region[i]; + result.push_back( + Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); + } + return result; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index b5a154b707..79c5f06092 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -197,6 +197,22 @@ Stmt ConvertSSA(Stmt stmt); * \return A string representing the storage scope of this buffer variable. */ String GetPtrStorageScope(Var buffer_var); + +/*! + * \brief Convert match buffer target buffer access indices to original one. + * \param indices The indices of the target buffer + * \return The indices of source buffer. + */ +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices); + +/*! 
+ * \brief Convert match buffer target buffer region to original one. + * \param region The sub-region of the target buffer + * \return The region of source buffer. + */ +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/lower_logical_intrin.cc b/src/tir/transforms/lower_logical_intrin.cc new file mode 100644 index 0000000000..70784be697 --- /dev/null +++ b/src/tir/transforms/lower_logical_intrin.cc @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Lower logical intrinsics + * \file lower_logical_intrin.cc + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +struct LogicalIntrinRegistry { + static Map registry; +}; + +class LogicalIntrinBufferReplacer : public StmtExprMutator { + public: + explicit LogicalIntrinBufferReplacer(Map buffer_var_to_new_buffer) + : buffer_var_to_new_buffer_(std::move(buffer_var_to_new_buffer)) { + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = buffer_var_to_new_buffer_.find(GetRef(op)); + if (it != buffer_var_to_new_buffer_.end()) { + return (*it).second->data; + } + return GetRef(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_var_to_new_buffer_.find(load->buffer->data); + if (it != buffer_var_to_new_buffer_.end()) { + auto* n = load.CopyOnWrite(); + n->buffer = (*it).second; + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_var_to_new_buffer_.find(store->buffer->data); + if (it != buffer_var_to_new_buffer_.end()) { + auto* n = store.CopyOnWrite(); + n->buffer = (*it).second; + } + return store; + } + + private: + Map buffer_var_to_new_buffer_; +}; + +class LogicalIntrinMutator : public StmtMutator { + public: + using FLowerLogicalIntrin = runtime::TypedPackedFunc; + + explicit LogicalIntrinMutator(const PrimFunc& func) { + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + } + + Stmt VisitStmt_(const BlockNode* op) { + for (const auto& buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const EvaluateNode* op) { + static const auto& f_lower_logical_intrin = Op::GetAttrMap("FLowerLogicalIntrin"); + if (const auto* call = op->value.as()) { + if (const auto* call_op = call->op.as()) { + PrimFunc intrin_impl = f_lower_logical_intrin.get(GetRef(call_op), NullValue()); 
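The branch that follows is essentially macro expansion: the call's arguments are substituted for the parameters of the registered intrinsic PrimFunc, buffers bound through its buffer_map are remapped, and the resulting body replaces the Evaluate node. A rough plain-Python analogy of the parameter substitution (hypothetical names and a toy body representation, not the TVM API):

def inline_intrin(intrin_params, intrin_body, call_args):
    """Sketch: substitute call arguments for parameters in an intrinsic body."""
    assert len(intrin_params) == len(call_args)
    subst_map = dict(zip(intrin_params, call_args))
    # the body is modeled as (op, operand_names) tuples; the real pass walks
    # the TIR statement tree and also rewrites buffer objects, not just names
    return [(op, [subst_map.get(name, name) for name in operands])
            for op, operands in intrin_body]

assert inline_intrin(["a", "b"], [("add", ["a", "b"])], ["x", "y"]) == [("add", ["x", "y"])]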
+ if (intrin_impl.defined()) { + // Make an inlined call to intrin_impl + CHECK(intrin_impl->params.size() == call->args.size()); + Map subst_map; + for (size_t i = 0; i < call->args.size(); i++) { + subst_map.Set(intrin_impl->params[i], call->args[i]); + } + Map new_buffer_map; + for (size_t i = 0; i < call->args.size(); i++) { + const auto& param = intrin_impl->params[i]; + if (const auto* var = param.as()) { + if (var->dtype.is_handle()) { + Var buffer_var = Downcast(param); + auto it = intrin_impl->buffer_map.find(buffer_var); + CHECK(it != intrin_impl->buffer_map.end()) << buffer_var; + new_buffer_map.Set((*it).second->data, + buffer_data_to_buffer_.at(Downcast(call->args[i]))); + } + } + } + + auto body = Substitute(intrin_impl->body, subst_map); + return LogicalIntrinBufferReplacer(new_buffer_map)(body); + } + } + } + return StmtMutator::VisitStmt_(op); + } + + Map buffer_data_to_buffer_; +}; + +namespace transform { + +Pass LowerLogicalIntrin() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = LogicalIntrinMutator(f)(std::move(f->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerLogicalIntrin", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerLogicalIntrin").set_body_typed(LowerLogicalIntrin); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index fd20a58bf9..2f8fbe0ea6 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -19,6 +19,7 @@ /*! * \file lower_match_buffer.cc + * \brief The pass for lowering match_buffer. */ #include @@ -28,6 +29,7 @@ #include #include "../ir/functor_common.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -36,7 +38,7 @@ class MatchBufferLower : public StmtExprMutator { explicit MatchBufferLower(const PrimFunc& func) { for (const Var& param : func->params) { // Mark input var as const variable.
- if (!param.dtype().is_handle()) var_map_[param] = param; + if (!param.dtype().is_handle()) var_map_.Set(param, param); } } @@ -74,7 +76,7 @@ class MatchBufferLower : public StmtExprMutator { Var v = GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { - return it->second; + return (*it).second; } else { return std::move(v); } @@ -89,11 +91,11 @@ class MatchBufferLower : public StmtExprMutator { if (it == match_buffers_.end()) { return stmt; } else { - const Buffer& buffer = it->first; - const BufferRegion& source = it->second; + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; auto n = CopyOnWrite(op); - n->indices = MatchBufferRegion(buffer, source).ConvertIndices(op->indices); + n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; return Stmt(n); } @@ -108,21 +110,35 @@ class MatchBufferLower : public StmtExprMutator { if (it == match_buffers_.end()) { return expr; } else { - const Buffer& buffer = it->first; - const BufferRegion& source = it->second; - Array indices = MatchBufferRegion(buffer, source).ConvertIndices(op->indices); + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; + Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); return BufferLoad(source->buffer, indices); } } + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Load from buffer created by match_buffer is not allowed, but got: " << expr; + return expr; + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Store from buffer created by match_buffer is not allowed, but got: " << stmt; + return stmt; + } + BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { const Buffer& buffer = buffer_region->buffer; auto it = match_buffers_.find(buffer); if (it == match_buffers_.end()) { return buffer_region; } else { - const BufferRegion& source = it->second; - Region region = MatchBufferRegion(buffer, source).ConvertRegion(buffer_region->region); + const BufferRegion& source = (*it).second; + Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); return BufferRegion(source->buffer, std::move(region)); } } @@ -156,7 +172,7 @@ class MatchBufferLower : public StmtExprMutator { } // Step.2. Update - match_buffers_[buffer] = source; + match_buffers_.Set(buffer, source); // Step.2.1. Update buffer data Bind(buffer->data, source_buffer->data, buffer->name + ".data"); @@ -206,10 +222,10 @@ class MatchBufferLower : public StmtExprMutator { Var v = Downcast(arg); auto it = var_map_.find(v); if (it == var_map_.end()) { - var_map_[v] = value; + var_map_.Set(v, value); analyzer_.Bind(v, value); } else { - AssertBinding(it->second, value, arg_name); + AssertBinding((*it).second, value, arg_name); } } else { AssertBinding(arg, value, arg_name); @@ -224,9 +240,9 @@ class MatchBufferLower : public StmtExprMutator { private: /*! \brief Buffer region mapping. */ - std::unordered_map match_buffers_; + Map match_buffers_; /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */ - std::unordered_map var_map_; + Map var_map_; /*! 
\brief The analyzer */ arith::Analyzer analyzer_; }; @@ -251,4 +267,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB } // namespace transform } // namespace tir -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 5fbfadf721..cdf6030d20 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -241,7 +241,8 @@ class WarpAccessRewriter : protected StmtExprMutator { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); - return Store(op->buffer_var, op->value, local_index, op->predicate); + PrimExpr new_value = VisitExpr(op->value); + return Store(op->buffer_var, new_value, local_index, op->predicate); } else { return StmtExprMutator::VisitStmt_(op); } @@ -256,6 +257,9 @@ class WarpAccessRewriter : protected StmtExprMutator { << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); + if (analyzer_->CanProveEqual(group, warp_index_)) { + return load_value; + } PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), {mask, load_value, group, width_, warp_size_}); diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 949c955b2d..e929561200 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -73,8 +73,6 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { ICHECK(!op->init.defined()); - bool is_root = is_root_; - is_root_ = false; Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { @@ -83,11 +81,23 @@ class BufferAllocationLocator : public StmtExprMutator { buffer_data_to_buffer_.Set(buf->data, buf); } } + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + const Var& source_var = match_buffer->source->buffer->data; + ICHECK(buffer_data_to_buffer_.count(source_var)); + buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); + } Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - // Ignore buffer allocated inside the block when getting access region. + // No longer consider buffers created by match_buffer inside the block when updating access + // region. + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + buffer_data_to_buffer_.erase(target_var); + } + // No longer consider buffers allocated inside the block when updating access region. if (it != alloc_buffers_.end()) { for (const Buffer& buf : it->second) { buffer_data_to_buffer_.erase(buf->data); @@ -96,12 +106,9 @@ class BufferAllocationLocator : public StmtExprMutator { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); - // The read/write regions of root block are always empty. - if (!is_root) { - // Recalculate block access region - CollectReadWrite(GetRef(op), &n->reads, &n->writes); - } - + // Erase buffer allocated inside the block from access region. 
+ n->reads = RemoveRedundantBufferRegion(n->reads); + n->writes = RemoveRedundantBufferRegion(n->writes); return Stmt(n); } @@ -120,28 +127,28 @@ class BufferAllocationLocator : public StmtExprMutator { /*init=*/NullOpt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); - CollectReadWrite(opaque_block, &n->reads, &n->writes); + Array> access = + GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); + n->reads = access[0]; + n->writes = access[1]; BlockRealize realize({}, Bool(true), Block(n)); return std::move(realize); } - void CollectReadWrite(const Block& block, Array* reads, - Array* writes) { - Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - *reads = access[0]; - *writes = access[1]; - for (const auto& opaque_access : access[2]) { - reads->push_back(opaque_access); - writes->push_back(opaque_access); + Array RemoveRedundantBufferRegion(const Array& region) const { + Array result; + for (const BufferRegion& buffer_region : region) { + if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { + result.push_back(buffer_region); + } } + return result; } /*! \brief The map from stmt to the buffers to be allocated under it. */ std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ Map buffer_data_to_buffer_; - /*! \brief indicate the whether the block is root. */ - bool is_root_{true}; }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { diff --git a/tests/python/integration/test_logical_layout_rocm_tir.py b/tests/python/integration/test_logical_layout_rocm_tir.py new file mode 100644 index 0000000000..1b955df525 --- /dev/null +++ b/tests/python/integration/test_logical_layout_rocm_tir.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument +"""Test workload for lowering and build""" +import tvm +from tvm import tir +from tvm.script import ty +import tvm.testing +import numpy as np + +from tvm.ir import register_intrin_lowering, register_op_attr + + +@tvm.script.tir +def tvm_mfma_sync(d: ty.handle, index_d: ty.int32, a: ty.handle, index_a: ty.int32, b: ty.handle, index_b: ty.int32, c:ty.handle, index_c: ty.int32) -> None: + tx = tir.env_thread("threadIdx.x") + tir.launch_thread(tx, 64) + num_warp_i = tir.var('int32') + num_warp_j = tir.var('int32') + wmma_A1 = tir.match_buffer(a, [64, num_warp_i, 16], dtype='float16') + wmma_B1 = tir.match_buffer(b, [64, num_warp_j, 16], dtype='float16') + wmma_C1 = tir.match_buffer(c, [64, num_warp_i, num_warp_j, 16], dtype='float32') + wmma_D1 = tir.match_buffer(d, [64, num_warp_i, num_warp_j, 16], dtype='float32') + + with tir.block([], 'mfma_sync'): + tir.reads([wmma_A1[0:64,0:num_warp_i,0:16], wmma_B1[0:64,0:num_warp_j,0:16], wmma_C1[0:64,0:num_warp_i,0:num_warp_j,0:16]]) + tir.writes([wmma_D1[0:64,0:num_warp_i,0:num_warp_j,0:16]]) + wmma_D1[tx, index_a, index_b, tir.ramp(0, 1, 4)] = tir.call_llvm_pure_intrin(tir.llvm_lookup_intrinsic_id('llvm.amdgcn.mfma.f32.16x16x16f16'), 6, + wmma_B1[tx, index_b, tir.ramp(0, 1, 4)], wmma_A1[tx, index_a, tir.ramp(0, 1, 4)], wmma_C1[tx, index_a, index_b, tir.ramp(0, 1, 4)], 0, 0, 0, dtype="float32x4") + +register_op_attr("tir.tvm_mfma_sync", "FLowerLogicalIntrin", tvm_mfma_sync); + +def lower_mfma_16x16x16_matrix_a(i, j): + return tir.floordiv(j, 4) * 16 + tir.floormod(i, 16), tir.floordiv(i, 16), tir.floormod(j, 4) + +def lower_mfma_16x16x16_matrix_c(i, j): + return tir.floordiv(tir.floormod(j, 16), 4) * 16 + tir.floormod(i, 16), tir.floordiv(i, 16), tir.floordiv(j, 16), tir.floormod(j, 4) + +tir.LogicalLayout.register('warp.mfma_16x16x16_matrix_a', lower_mfma_16x16x16_matrix_a) +tir.LogicalLayout.register('warp.mfma_16x16x16_matrix_c', lower_mfma_16x16x16_matrix_c) + + +@tvm.script.tir +def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # match buffer + A = tir.match_buffer(a, [1024, 1024], "float16") + B = tir.match_buffer(b, [1024, 1024], "float16") + C = tir.match_buffer(c, [1024, 1024], "float32") + + # body + for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"): + for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"): + with tir.block([16, 8]) as [bx, by]: + tir.bind(bx, blockIdx_x) + tir.bind(by, blockIdx_y) + shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="warp.mfma_16x16x16_matrix_a") + wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="warp.mfma_16x16x16_matrix_a") + wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="warp.mfma_16x16x16_matrix_c") + for ty in tir.thread_binding(0, 2, "threadIdx.y"): + for tz in tir.thread_binding(0, 2, "threadIdx.z"): + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads([]) + tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + C0 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="warp.mfma_16x16x16_matrix_c", + offset_factor=1, + ) + for i1, j1 in tir.grid(16, 16): + with tir.block([16, 16]) as [vii, vjj]: + tir.bind(vii, i1) + 
tir.bind(vjj, j1) + tir.reads([]) + tir.writes(C0[vii : vii + 1, vjj : vjj + 1]) + + C0[vii, vjj] = tir.float32(0) + + for ko in range(0, 32): + # copy data from global to shared + for tx in tir.thread_binding(0, 64, "threadIdx.x"): + for i0, j0 in tir.grid(1, 2): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, bx * 64 + tx + i0) + tir.bind(vj, ko * 32 + tz * 16 + ty * 8 + j0 * 4 + j1) + shared_A[vi, vj] = A[vi, vj] + + for i0, j0 in tir.grid(1, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, by * 128 + ty * 64 + tx + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_B[vi, vj] = B[vi, vj] + + for ki in range(0, 2): + for i in range(0, 2): + with tir.block([64, 64]) as [vi, vk]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16, + ] + ) + tir.writes( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + A0 = tir.match_buffer( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16, + ], + (16, 16), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_A0 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="warp.mfma_16x16x16_matrix_a", + offset_factor=1, + ) + + for i1, j1 in tir.grid(16, 16): + with tir.block([16, 16]) as [vii, vjj]: + tir.bind(vii, i1) + tir.bind(vjj, j1) + tir.reads([A0[vii, vjj]]) + tir.writes(wmma_A0[vii, vjj]) + wmma_A0[vii, vjj] = A0[vii, vjj] + + for j in range(0, 4): + with tir.block([64, 64]) as [vj, vk]: + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16, + ] + ) + tir.writes( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + B0 = tir.match_buffer( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16, + ], + (16, 16), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_B0 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="warp.mfma_16x16x16_matrix_a", + offset_factor=1, + ) + + for i1, j1 in tir.grid(16, 16): + with tir.block([16, 16]) as [vii, vjj]: + tir.bind(vii, i1) + tir.bind(vjj, j1) + tir.reads([B0[vii, vjj]]) + tir.writes([wmma_B0[vii, vjj]]) + wmma_B0[vii, vjj] = B0[vii, vjj] + + for i, j in tir.grid(2, 4): + with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [ + vi, + vj, + vk, + ]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + [ + wmma_A[ + vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_B[ + vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_C[ + vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16 + ], + ] + ) + tir.writes( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16] + ) + wmma_A1 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="warp.mfma_16x16x16_matrix_a", + offset_factor=1, + ) + wmma_B1 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="warp.mfma_16x16x16_matrix_a", + offset_factor=1, + ) + wmma_C1 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 
16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="warp.mfma_16x16x16_matrix_c", + offset_factor=1, + ) + + tir.evaluate( + tir.tvm_mfma_sync( + wmma_C1.data, + i * 4 + j, + wmma_A1.data, + i, + wmma_B1.data, + j, + wmma_C1.data, + i * 4 + j, + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + s0 = tir.var("int32") + s1 = tir.var("int32") + wmma_C2 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="warp.mfma_16x16x16_matrix_c", + offset_factor=1, + ) + C1 = tir.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[s0, s1], + offset_factor=1, + ) + + for tx in tir.thread_binding(0, 64, "threadIdx.x"): + for r in range(0, 4): + with tir.block([16, 16]) as [vii, vjj]: + tir.bind(vii, tir.floormod(tx, 16)) + tir.bind(vjj, tir.floordiv(tx, 16) * 4 + r) + tir.reads([wmma_C2[vii, vjj]]) + tir.writes([C1[vii, vjj]]) + C1[vii, vjj] = wmma_C2[vii, vjj] + + +def test_gemm_tensorcore(): + dev = tvm.device("rocm", 0) + print(tvm.script.asscript(tvm.lower(tensorcore_gemm, simple_mode=True))) + f = tvm.build(tensorcore_gemm, target="rocm", name="dense") + a_np = np.random.uniform(size=(1024, 1024)).astype("float16") + + b_np = np.random.uniform(size=(1024, 1024)).astype("float16") + + c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev) + f(a, b, c) + cc = c.numpy() + np.set_printoptions(threshold=np.inf) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + evaluator = f.time_evaluator(f.entry_name, dev, number=100) + t = evaluator(a, b, c).mean + num_flops = 2 * 1024 * 1024 * 1024 + gflops = num_flops / (t * 1e3) / 1e6 + print("gemm with tensor core: %f ms" % (t * 1e3)) + print("GFLOPS: %f" % gflops) + + +if __name__ == "__main__": + test_gemm_tensorcore() diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 9053b35348..3fa4795870 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -302,6 +302,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: ) +@tvm.testing.requires_cuda def test_gemm_tensorcore(): dev = tvm.device("cuda", 0) a_np = np.random.uniform(size=(1024, 1024)).astype("float16") @@ -310,7 +311,6 @@ def test_gemm_tensorcore(): a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev) - print(tvm.script.asscript(tvm.lower(tensorcore_gemm, simple_mode=True))) f = tvm.build(tensorcore_gemm, target="cuda", name="dense") f(a, b, c) tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) @@ -324,4 +324,4 @@ def test_gemm_tensorcore(): if __name__ == "__main__": - test_gemm_tensorcore() \ No newline at end of file + test_gemm_tensorcore() diff --git a/tests/python/tir/conv_tensorcore_demo.py b/tests/python/tir/conv_tensorcore_demo.py index 37d1b1e43a..2229ddb1a5 100644 --- a/tests/python/tir/conv_tensorcore_demo.py +++ b/tests/python/tir/conv_tensorcore_demo.py @@ -269,10 +269,10 @@ def test_tensorcore(): WF = s.cache_read(Conv, 2, "wmma.matrix_b") ConvF = s.cache_write(Conv, 0, 
"wmma.accumulator") - block_row_warps = 1 - block_col_warps = 1 - warp_row_tiles = 1 - warp_col_tiles = 1 + block_row_warps = 4 + block_col_warps = 2 + warp_row_tiles = 2 + warp_col_tiles = 4 warp_size = 32 chunk = 2 @@ -291,9 +291,9 @@ def test_tensorcore(): # Schedule local computation s.compute_at(ConvF, oc) - ic, kh, kw, _nnf, _oof, ii = s.get_loops(ConvF)[-6:] + no, oo, ic, kh, kw = s.get_loops(ConvF)[-8:-3] ko, ki = s.split(ic, [None, chunk]) - s.reorder(ko, kh, ki) + s.reorder(ko, kh, ki, kw, no, oo) # Move intermediate computation into each output compute tile s.compute_at(AF, kw) @@ -301,20 +301,22 @@ def test_tensorcore(): # Schedule for A's share memory s.compute_at(AS, kh) - _, _, nn, ii = s.get_loops(AS)[-4:] + n, _, _, nn, ii = s.get_loops(AS)[-5:] + ty, tz = s.split(n, [block_row_warps, block_col_warps]) t = s.fuse(nn, ii) _, ti = s.split(t, [None, warp_size]) s.bind(ti, "threadIdx.x") + s.bind(ty, "threadIdx.y") + s.bind(tz, "threadIdx.z") # Schedule for W's share memory s.compute_at(WS, kh) - kw, ic, o, ii, oo = s.get_loops(WS)[-5:] - tx, xo = s.split(o, [block_row_warps, None]) - ty, _ = s.split(xo, [block_col_warps, None]) # pylint: disable=redefined-outer-name + _, _, o, ii, oo = s.get_loops(WS)[-5:] + ty, tz = s.split(o, [block_row_warps, block_col_warps]) t = s.fuse(ii, oo) to, ti = s.split(t, [warp_size, None]) - s.bind(tx, "threadIdx.y") - s.bind(ty, "threadIdx.z") + s.bind(ty, "threadIdx.y") + s.bind(tz, "threadIdx.z") s.bind(to, "threadIdx.x") s.vectorize(ti) diff --git a/tests/python/tir/test_schedule_software_pipeline.py b/tests/python/tir/test_schedule_software_pipeline.py index 672f54ee20..81ac1dbcc0 100644 --- a/tests/python/tir/test_schedule_software_pipeline.py +++ b/tests/python/tir/test_schedule_software_pipeline.py @@ -42,8 +42,54 @@ def test_cuda_pipeline(): c = tvm.nd.array(np.zeros((128, 128), dtype="float32"), device=dev) f(a, b, c) c_np = np.matmul(a_np, b_np.T) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + + +def test_cuda_nested_pipeline(): + device = "cuda" + dev = tvm.device(device, 0) + if not dev.exist: + print("Skip because %s is not enabled" % device) + return + + s = tir.Schedule(util.matmul_stmt()) + C = s.get_block("update") + i, j, k = s.get_loops(C) + io, ii = s.split(i, factors=[8, 16]) + jo, ji = s.split(j, factors=[8, 16]) + ko, km, ki = s.split(k, factors=[4, 2, 16]) + s.reorder(io, jo, ko, km, ii, ji, ki) + s.bind(io, "blockIdx.x") + s.bind(jo, "threadIdx.x") + + A_local = s.cache_read(C, 1, "local") + B_local = s.cache_read(C, 2, "local") + A_shared = s.cache_read(A_local, 0, "shared") + B_shared = s.cache_read(B_local, 0, "shared") + + s.compute_at(A_local, km) + s.compute_at(B_local, km) + s.compute_at(A_shared, ko) + s.compute_at(B_shared, ko) + + for load in [A_shared, B_shared]: + _, tt = s.split(s.fuse(s.get_loops(load)[-2]), factors=[None, 8]) + s.bind(tt, "threadIdx.x") + s.software_pipeline(km, 2) + s.software_pipeline(ko, 2) + + f = tvm.build(s.mod["main"], None, target="cuda") + cuda_code = f.imported_modules[0].get_source() + a_np = np.random.uniform(size=(128, 128)).astype("float32") + b_np = np.random.uniform(size=(128, 128)).astype("float32") + a = tvm.nd.array(a_np, device=dev) + b = tvm.nd.array(b_np, device=dev) + c = tvm.nd.array(np.zeros((128, 128), dtype="float32"), device=dev) + f(a, b, c) + c_np = np.matmul(a_np, b_np.T) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) if __name__ == 
"__main__": - test_cuda_pipeline() \ No newline at end of file + test_cuda_pipeline() + test_cuda_nested_pipeline() diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index c7cf2f6edf..8c2b2710f1 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -77,12 +77,11 @@ def match_buffer_func(a: ty.handle, b: ty.handle) -> None: with tir.block([8, 8], "block") as [vi, vj]: tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) with tir.block([16, 16], "AAA") as [i, j]: - AAA = tir.match_buffer(AA[i, j], ()) - AAA[()] = 1.0 + AA = tir.match_buffer(A[i, j], ()) + AA[()] = 1.0 tir.evaluate(B0.data) tir.evaluate(B1.data) diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 70de437280..bc421aa4d1 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import tir, script from tvm.ir import Range @@ -62,6 +63,39 @@ def match_buffer_func() -> None: tir.evaluate(B1.data) +@tvm.script.tir +def opaque_block_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((16, 16), "float32") + B = tir.alloc_buffer((16, 16), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes([B[i, 0:16]]) + for j in range(0, 16): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + + +@tvm.script.tir +def opaque_access_func() -> None: + A = tir.alloc_buffer([1024]) + B = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [v]: + tir.bind(v, i) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([B[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32") + ) + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -76,25 +110,64 @@ def test_block_access_region_detector(): ) +def test_opaque_block(): + alloc_buffers = opaque_block_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + block0 = opaque_block_func.body.block.body.body.block + ret = tir.analysis.get_block_access_region(block0, buffer_var_map) + tvm.ir.assert_structural_equal(block0.reads, ret[0]) + tvm.ir.assert_structural_equal(block0.writes, ret[1]) + + block1 = block0.body.body.block + ret = tir.analysis.get_block_access_region(block1, buffer_var_map) + tvm.ir.assert_structural_equal(block1.reads, ret[0]) + tvm.ir.assert_structural_equal(block1.writes, ret[1]) + + +def test_opaque_access(): + block = 
opaque_access_func.body.block.body.body.block + alloc_buffers = opaque_access_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block block_inner = block.body[0].body.body.block - alloc_buffers = func.body.block.alloc_buffers + alloc_buffers = match_buffer_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - # Check inner block AAA - ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) - tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) - tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) - # Check block ret = tir.analysis.get_block_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.writes, ret[1]) # B is opaque access tvm.ir.assert_structural_equal(block.reads, ret[2]) + # Check inner block AAA without updating buffer_var_map + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + # Since AA is not in the buffer_var_map, region of AA will not be collected. + tvm.ir.assert_structural_equal([], ret[1]) + + # Check inner block AAA + for match_buffer in block.match_buffers: + target_buffer = match_buffer.buffer + buffer_var_map[target_buffer.data] = target_buffer + + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + if __name__ == "__main__": test_block_access_region_detector() + test_opaque_block() + test_opaque_access() test_match_buffer() diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 0c297820a8..78a8c51178 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -65,8 +65,8 @@ def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None: A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] -@tvm.ir.register_op_attr("tir.test_intrin", "") -def test_intrin(data, elem_offset, stride_0, stride_1, shape_0, shape_1): +@tvm.ir.register_op_attr("tir.intrin_test", "") +def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): return 0 @@ -85,7 +85,7 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: offset_factor=1, ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -108,7 +108,7 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: offset_factor=1, ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -129,7 +129,7 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: tir.reads([]) tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) tir.evaluate( - tir.test_intrin( + tir.intrin_test( A.data, i * 131072 + j * 128 + k * 16, 8192, @@ -144,7 +144,7 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: tir.reads([]) tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) tir.evaluate( - tir.test_intrin( + tir.intrin_test( B.data, i * 4096 + j * 2048 + k * 8, 
64, @@ -205,7 +205,7 @@ def recursive_match(a: ty.handle, b: ty.handle) -> None: offset_factor=1, ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_sub_A.data, sub_sub_A.elem_offset, sub_sub_A.strides[0], @@ -250,7 +250,7 @@ def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: ] ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( A.data, i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4, 64, @@ -282,7 +282,7 @@ def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None sub_A[ii, jj] = 1 for j in range(0, 4): tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -306,7 +306,7 @@ def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.in A[i * m + ii, jj] = 1 for j in range(0, 4): tir.evaluate( - tir.test_intrin( + tir.intrin_test( B.data, i * n * (m * 4), m * 4, @@ -330,7 +330,7 @@ def rank0_buffer(a: ty.handle, b: ty.handle) -> None: sub_B = tir.match_buffer(B[i, j], (), offset_factor=1) sub_A[()] = 1 tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_B.data, sub_B.elem_offset, 0, @@ -352,7 +352,7 @@ def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: tir.writes([A[i, j], B[i, j]]) A[i, j] = 1 tir.evaluate( - tir.test_intrin( + tir.intrin_test( B.data, i * 8 + j, 0, @@ -365,13 +365,33 @@ def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: @tvm.script.tir -def fail_match_buffer(a: ty.handle) -> None: +def fail_match_load(a: ty.handle) -> None: A = tir.match_buffer(a, (8, 8)) - for i, j in tir.grid(8, 2): + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes([]) + sub_A = tir.match_buffer(A[i, j], ()) + tir.evaluate(tir.load("float32", sub_A.data, 0)) + + +@tvm.script.tir +def fail_match_store(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 8): with tir.block([]): tir.reads([]) + tir.writes(A[i, j]) + sub_A = tir.match_buffer(A[i, j], ()) + sub_A.data[0] = 1 + + +@tvm.script.tir +def fail_buffer_bind(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): stride = tir.var("int32") - tir.writes(A[i, j * 4 : j * 4 + 4]) sub_A = tir.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 ) @@ -384,8 +404,6 @@ def fail_match_func_param(a: ty.handle, m: ty.handle, n: ty.handle) -> None: A = tir.match_buffer(a, (8, 8)) for i, j in tir.grid(8, 2): with tir.block([]): - tir.reads([]) - tir.writes(A[i, j * 4 : j * 4 + 4]) sub_A = tir.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1 ) @@ -413,8 +431,13 @@ def test_rank0_buffer(): _check(rank0_buffer, transformed_rank0_buffer) +def test_fail_load_store(): + _check_fail(fail_match_load) + _check_fail(fail_match_store) + + def test_fail_buffer_bind(): - _check_fail(fail_match_buffer) + _check_fail(fail_buffer_bind) def test_fail_match_func_param(): @@ -427,5 +450,6 @@ def test_fail_match_func_param(): test_recursive_match() test_symbolic_match() test_rank0_buffer() + test_fail_load_store() test_fail_buffer_bind() test_fail_match_func_param() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 07a82ba993..dbae0b6fa5 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -398,7 +398,7 @@ def test_block_blockrealize(): ) ] writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])] - match_buffer_region = 
tvm.tir.MatchBufferRegion( + block_match_buffer = tvm.tir.MatchBufferRegion( match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)]) ) @@ -410,7 +410,7 @@ def test_block_blockrealize(): body, init=init_body, alloc_buffers=[alloc_buffer], - match_buffers=[match_buffer_region], + match_buffers=[block_match_buffer], annotations={"attr_key": "attr_value"}, ) @@ -462,7 +462,7 @@ def test_block_blockrealize(): assert output.find("reads") != -1 assert output.find("writes") != -1 assert output.find("alloc_buffer") != -1 - assert output.find("match_buffer_region") != -1 + assert output.find("match_buffer") != -1 assert output.find("attr") != -1 assert output.find("with init()") != -1 @@ -471,7 +471,6 @@ def test_block_blockrealize(): test_intimm_cond() test_buffer_load_store() test_vars() - test_scoped_storage_var() test_prim_func() test_cast() test_attr() diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 99f2b476b0..f7d8fad527 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -17,7 +17,6 @@ # pylint: disable=missing-function-docstring,missing-module-docstring import sys -import numpy as np import pytest import tvm import tvm.testing diff --git a/tests/python/unittest/test_tir_sparse.py b/tests/python/unittest/test_tir_sparse.py new file mode 100644 index 0000000000..74845b03e5 --- /dev/null +++ b/tests/python/unittest/test_tir_sparse.py @@ -0,0 +1,84 @@ +import tvm +from tvm import tir +from tvm.script import ty + + +@tvm.script.tir +def spmm_tir(a_indptr: ty.handle, a_indices: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None: + m = tir.var('int32') + n = tir.var('int32') + k = tir.var('int32') + nnz = tir.var('int32') + A_indptr = tir.match_buffer(a_indptr, [m + 1], 'int32') + A_indices = tir.match_buffer(a_indices, [nnz], 'int32') + A = tir.match_buffer(a_data, [nnz], 'float32') + B = tir.match_buffer(b, [k, n], 'float32') + C = tir.match_buffer(c, [m, n], 'float32') + with tir.block([m, n], 'spmm_outer') as [vi, vj]: + with tir.init(): + C[vi, vj] = 0. + with tir.block([tir.reduce_axis(A_indptr[vi], A_indptr[vi + 1])], 'spmm_inner') as [vk]: + C[vi, vj] = C[vi, vj] + A[vk] * B[A_indices[vk], vj] + + +""" +@tvm.script.tir +def spmm_sparse_tir(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + m = tir.var('int32') + n = tir.var('int32') + k = tir.var('int32') + A = tir.sp.match_buffer(a, [m, k], ['dense', ('sparse', None)], 'int32', 'float32') + B = tir.match_buffer(b, [k, n], 'float32') + C = tir.match_buffer(c, [m, n], 'float32') + for i, j in tir.grid(m, n): + for k in tir.sp.iter(A[i]): + with tir.block([m, n, tir.reduce_axis(0, k)], 'spmm') as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + with tir.init(): + C[vi, vj] = 0. 
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] +""" + + +@tvm.script.tir +def embedding_update(a: ty.handle, grad_out: ty.handle, grad_in: ty.handle) -> None: + m = tir.var('int32') # number of tokens + n = tir.var('int32') # feature size + k = tir.var('int32') # dictionary size + A = tir.sp.match_buffer(a, [m, k], ['dense', ('sparse', None)], 'int32', 'float32') + B = tir.match_buffer(grad_out, [m, n], 'float32') + C = tir.match_buffer(grad_in, [k, n], 'float32') + for i, j in tir.grid(m, n): + for k in tir.sp.iter(A[i]): + with tir.block([m, n, k]) as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + C[vk, vj] = C[vk, vj] + A[vi, vk] * B[vi, vj] + + +""" +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-f16 +@tvm.script.tir +def sparse_tensor_core_desr(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.sp.match_buffer(a, [16, 4, 4], ['dense', 'dense', ('sparse', 2)], 'int8', 'float16', align=128, offset_factor=256, scope='wmma.matrix_a') + B = tir.match_buffer(b, [16, 16], 'float16', align=128, offset_factor=256, scope='wmma.matrix_b') + C = tir.match_buffer(c, [16, 16], 'float32', align=128, offset_factor=256, scope='wmma.matrix_accumulator') + + for i, j, k in tir.grid(16, 16, 4): + for l in tir.sp.iter(A[i, j]): + with tir.block([16, 16, tir.reduce_axis(0, 4), tir.reduce_axis(0, 4)], "root") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + with tir.init(): + C[vi, vj] = 0. + C[vi, vj] = C[vi, vj] + A[vi, vk, vl] * B[vk * 4 + vl, vj] +""" + + +if __name__ == "__main__": + pass diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 47cf06e567..d834fbeaa7 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -337,6 +337,52 @@ def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None: C[i, j] = B[0, j] * 2.0 +@tvm.script.tir +def match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((16, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[i, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[i, j], ()) + C1[()] = B2[()] * 2.0 + + +@tvm.script.tir +def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((1, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[0, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[0, j], ()) + C1[()] = B2[()] * 2.0 + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -369,6 +415,10 @@ def 
test_storage_align(): _check(storage_align_func, compacted_storage_align_func) +def test_match_buffer(): + _check(match_buffer_func, compacted_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -378,3 +428,4 @@ def test_storage_align(): test_symbolic() test_complex() test_storage_align() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index bb40c39ad8..9cb69a6faa 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -16,128 +16,132 @@ # under the License. import tvm from tvm import te, tir - - -def test_inject_software_pipeline(): - n = 100 - m = 4 - num_stages = 3 - - def original(): - tx = te.thread_axis("threadIdx.x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - ib.scope_attr(tx, "thread_extent", 1) - ib.scope_attr(None, "pipeline_scope", num_stages) - with ib.for_range(0, n) as i: - B = ib.allocate("float32", m, name="B", scope="shared") - with ib.for_range(0, m) as j: - B[j] = A[i * m + j] - with ib.for_range(0, m) as k: - C[k] = B[k] + 1 - stmt = ib.get() - mod = tvm.IRModule({"main": tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)}) - return mod - - def non_native_transformed(): - tx = te.thread_axis("threadIdx.x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - ib.scope_attr(tx, "thread_extent", 1) - B = ib.allocate("float32", num_stages * m, name="B", scope="shared") - - with ib.for_range(0, num_stages - 1) as i: - with ib.for_range(0, m) as j: - B[i * m + j] = A[i * 4 + j] - - ib.emit(tir.call_intrin("int32", "tir.tvm_storage_sync", "shared")) - - with ib.new_scope(): - ib.scope_attr(None, "pipeline_scope", 1) - with ib.for_range(0, n - (num_stages - 1)) as i: - with ib.for_range(0, m) as j: - B[tir.indexmod(i + (num_stages - 1), num_stages) * m + j] = A[ - i * m + j + (num_stages - 1) * m - ] - with ib.for_range(0, m) as k: - C[k] = B[tir.indexmod(i, num_stages) * m + k] + 1 - ib.emit(tir.call_intrin("int32", "tir.tvm_storage_sync", "shared")) - - with ib.for_range(0, num_stages - 1) as i: - with ib.for_range(0, m) as k: - C[k] = B[tir.indexmod(i + n - (num_stages - 1), num_stages) * m + k] + 1 - - stmt = ib.get() - mod = tvm.IRModule({"main": tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)}) - return tvm.tir.transform.Simplify()(mod) - - def native_transformed(): - tx = te.thread_axis("threadIdx.x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - ib.scope_attr(tx, "thread_extent", 1) - B = ib.allocate("float32", num_stages * m, name="B", scope="shared") - - pipeline = tir.Var("pipeline", "handle") - - ib.emit(lambda body:tir.LetStmt(pipeline, tir.call_intrin("handle", "tir.tvm_create_pipeline"), body)) - # ib.new_scope() - with ib.for_range(0, num_stages - 1) as i: - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_producer_acquire", pipeline)) - with ib.for_range(0, m) as j: - B[i * m + j] = A[i * 4 + j] - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_producer_commit", pipeline)) - - with ib.new_scope(): - ib.scope_attr(None, "pipeline_scope", 1) - with ib.for_range(0, n - (num_stages - 1)) as i: - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_producer_acquire", pipeline)) - with 
ib.for_range(0, m) as j: - B[tir.indexmod(i + (num_stages - 1), num_stages) * m + j] = A[ - i * m + j + (num_stages - 1) * m - ] - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_producer_commit", pipeline)) - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_consumer_wait", pipeline)) - ib.emit(tir.call_intrin("int32", "tir.tvm_storage_sync", "shared")) - with ib.for_range(0, m) as k: - C[k] = B[tir.indexmod(i, num_stages) * m + k] + 1 - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_consumer_release", pipeline)) - - with ib.for_range(0, num_stages - 1) as i: - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_consumer_wait", pipeline)) - ib.emit(tir.call_intrin("int32", "tir.tvm_storage_sync", "shared")) - with ib.for_range(0, m) as k: - C[k] = B[tir.indexmod(i + n - (num_stages - 1), num_stages) * m + k] + 1 - ib.emit(tir.call_intrin("handle", "tir.tvm_pipeline_consumer_release", pipeline)) - - stmt = ib.get() - mod = tvm.IRModule({"main": tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)}) - return tvm.tir.transform.Simplify()(mod) - - mod = original() - - opt = tvm.transform.Sequential( - [tvm.tir.transform.InjectSoftwarePipeline(), tvm.tir.transform.Simplify()] - ) - with tvm.transform.PassContext( - config={"tir.InjectSoftwarePipeline": {"use_native_pipeline": False}} - ): - transformed_mod = opt(mod) - - tvm.ir.assert_structural_equal(transformed_mod['main'].body, - non_native_transformed()['main'].body, True) - - with tvm.transform.PassContext( - config={"tir.InjectSoftwarePipeline": {"use_native_pipeline": True}} - ): - with tvm.target.Target("cuda --arch=sm_86"): - transformed_mod = opt(mod) - tvm.ir.assert_structural_equal(transformed_mod['main'].body, native_transformed()['main'].body, - True) +from tvm.script import ty + +def _check(original, transformed, use_native_pipeline): + func = original + mod = tvm.IRModule.from_expr(func) + if use_native_pipeline: + with tvm.transform.PassContext( + config={"tir.InjectSoftwarePipeline": {"use_native_pipeline": True}} + ): + with tvm.target.Target("cuda --arch=sm_86"): + mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + else: + mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(tvm.script.asscript(mod)) + tvm.ir.assert_structural_equal(mod["main"], transformed) + + +@tvm.script.tir +def software_pipeline(a : ty.handle, c : ty.handle) -> None: + A = tir.match_buffer(a, [100, 4], dtype="float32") + C = tir.match_buffer(c, [100, 4], dtype="float32") + for tx in tir.thread_binding(0, 1, 'threadIdx.x'): + for i in range(0, 100, annotations={"pipeline_scope": 3}): + with tir.block([], ""): + B = tir.alloc_buffer([1, 4], dtype="float32", scope="shared") + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([A[i, j]]) + tir.writes(B[0, j]) + B[0, j] = A[i, j] + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([B[0, j]]) + tir.writes([C[i, j]]) + C[i, j] = B[0, j] + tir.float32(1) + + +@tvm.script.tir +def transformed_non_native_software_pipeline(a : ty.handle, c : ty.handle) -> None: + C = tir.match_buffer(c, [100, 4]) + A = tir.match_buffer(a, [100, 4]) + for tx in tir.thread_binding(0, 1, thread = "threadIdx.x"): + with tir.block([], ""): + tir.reads([A[0:100, 0:4]]) + tir.writes([C[0:100, 0:4]]) + B = tir.alloc_buffer([3, 1, 4], scope="shared") + for i, j in tir.grid(2, 4): + with tir.block([], ""): + tir.reads([A[i, j]]) + tir.writes([B[0:3, 0, j]]) + B[i, 0, j] = A[i, j] + tir.evaluate(tir.tvm_storage_sync("shared", dtype="int32")) + 
for i in tir.serial(0, 98, annotations = {"pipeline_scope":1}): + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([A[(i + 2), j]]) + tir.writes([B[0:3, 0, j]]) + B[tir.floormod((i + 2), 3), 0, j] = A[(i + 2), j] + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([B[0:3, 0, j]]) + tir.writes([C[i, j]]) + C[i, j] = (B[tir.floormod(i, 3), 0, j] + tir.float32(1)) + tir.evaluate(tir.tvm_storage_sync("shared", dtype="int32")) + for i, j in tir.grid(2, 4): + with tir.block([], ""): + tir.reads([B[0:3, 0, j]]) + tir.writes([C[(i + 98), j]]) + C[(i + 98), j] = (B[tir.floormod((i + 2), 3), 0, j] + tir.float32(1)) + + +@tvm.script.tir +def transformed_native_software_pipeline(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [100, 4]) + A = tir.match_buffer(a, [100, 4]) + + for tx in tir.thread_binding(0, 1, thread = "threadIdx.x"): + with tir.block([], ""): + tir.reads([A[0:100, 0:4]]) + tir.writes([C[0:100, 0:4]]) + B = tir.alloc_buffer([3, 1, 4], scope="shared") + pipeline: ty.handle = tir.tvm_create_pipeline(dtype="handle") + for i in tir.serial(0, 2): + tir.evaluate(tir.tvm_pipeline_producer_acquire(pipeline, dtype="handle")) + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([A[i, j]]) + tir.writes([B[0:3, 0, j]]) + B[i, 0, j] = A[i, j] + tir.evaluate(tir.tvm_pipeline_producer_commit(pipeline, dtype="handle")) + for i in tir.serial(0, 98, annotations = {"pipeline_scope":1}): + tir.evaluate(tir.tvm_pipeline_producer_acquire(pipeline, dtype="handle")) + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([A[(i + 2), j]]) + tir.writes([B[0:3, 0, j]]) + B[tir.floormod((i + 2), 3), 0, j] = A[(i + 2), j] + tir.evaluate(tir.tvm_pipeline_producer_commit(pipeline, dtype="handle")) + tir.evaluate(tir.tvm_pipeline_consumer_wait(pipeline, dtype="handle")) + tir.evaluate(tir.tvm_storage_sync("shared", dtype="int32")) + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([B[0:3, 0, j]]) + tir.writes([C[i, j]]) + C[i, j] = (B[tir.floormod(i, 3), 0, j] + tir.float32(1)) + tir.evaluate(tir.tvm_pipeline_consumer_release(pipeline, dtype="handle")) + for i in tir.serial(0, 2): + tir.evaluate(tir.tvm_pipeline_consumer_wait(pipeline, dtype="handle")) + tir.evaluate(tir.tvm_storage_sync("shared", dtype="int32")) + for j in tir.serial(0, 4): + with tir.block([], ""): + tir.reads([B[0:3, 0, j]]) + tir.writes([C[(i + 98), j]]) + C[(i + 98), j] = (B[tir.floormod((i + 2), 3), 0, j] + tir.float32(1)) + tir.evaluate(tir.tvm_pipeline_consumer_release(pipeline, dtype="handle")) + + +def test_inject_non_native_software_pipeline(): + _check(software_pipeline, transformed_non_native_software_pipeline, False) + + +def test_inject_native_software_pipeline(): + _check(software_pipeline, transformed_native_software_pipeline, True) if __name__ == "__main__": - test_inject_software_pipeline() + test_inject_non_native_software_pipeline() + test_inject_native_software_pipeline() diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 3fb8331d39..badf5e0e4d 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -18,6 +18,8 @@ from tvm import tir from tvm.script import ty +# pylint: disable=no-self-argument + @tvm.script.tir class WithInit: @@ -43,11 +45,46 @@ def main(a: ty.handle, b: ty.handle) -> None: B[i] += A[i, j, k] +@tvm.script.tir +class InitWithMatchBuffer: + def 
main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with tir.init(): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + +@tvm.script.tir +class BranchWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + def test_lower_reduction(): origin_mod = WithInit() mod = tvm.tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, WithBranch(), True) +def test_lower_match_buffer(): + origin_mod = InitWithMatchBuffer() + mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) + + if __name__ == "__main__": test_lower_reduction() + test_lower_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_lower_logical_layout.py b/tests/python/unittest/test_tir_transform_lower_logical_layout.py index 5cfeda3b37..73d140a2e2 100644 --- a/tests/python/unittest/test_tir_transform_lower_logical_layout.py +++ b/tests/python/unittest/test_tir_transform_lower_logical_layout.py @@ -32,8 +32,8 @@ def _check(original, transformed): def _check_fail(original): mod = tvm.IRModule.from_expr(original) - with pytest.raises(tvm.TVMError): - mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) + with pytest.raises(ValueError): mod = tvm.tir.transform.LowerLogicalLayout()(mod) diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index d42c5e1f86..792577cb11 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -115,6 +115,85 @@ def transformed_func() -> None: ) +@tvm.script.tir +def match_buffer_func() -> None: + C = tir.alloc_buffer((128, 128)) + with tir.block([128]) as [vi]: + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def transformed_match_buffer_func() -> None: + for i in range(0, 128): + with tir.block([128]) as [vi]: + tir.bind(vi, i) + C = tir.alloc_buffer((128, 128)) + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + A_cache = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[(v * 128) : ((v * 128) + 128)]]) + tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) + tir.evaluate( + tir.call_extern( + "test", + A_cache.data, + (v * 128), + 128, + A.data, + (v * 128), + 128, + dtype="float32", + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + 
tir.writes([B[v]]) + B[v] = A_cache[v] + + +@tvm.script.tir +def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + tir.reads(A[vi * 128 : vi * 128 + 128]) + tir.writes(B[vi * 128 : vi * 128 + 128]) + A_cache = tir.alloc_buffer([1024]) + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([A_cache[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern( + "test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -123,6 +202,16 @@ def test_locate_buffer_allocation(): _check(original_func, transformed_func) +def test_match_buffer_allocation(): + _check(match_buffer_func, transformed_match_buffer_func) + + +def test_opaque_access(): + _check(opaque_access, transformed_opaque_access) + + if __name__ == "__main__": test_elementwise() test_locate_buffer_allocation() + test_match_buffer_allocation() + test_opaque_access() diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index a4d2dec0cc..4798e9e098 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -177,19 +177,6 @@ def test_complete_part_region(): _check_elementwise(func_with_part_access_region) -def test_complete_opaque_block_error(): - def render(e): - pass - - override_renderer(render) - - try: - from_source(func_with_opaque_block) - except tvm.error.DiagnosticError: - return - assert False - - @tvm.script.tir def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: data_buf = tir.match_buffer(data, (16, 16), "float32") @@ -255,10 +242,46 @@ def test_complete_buffer_indices(): tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) +@tvm.script.tir +def match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +@tvm.script.tir +def expected_match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, 0:16]) + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + tir.reads([]) + tir.writes(A0[0:16]) + for j in range(0, 16): + with tir.block([]) as []: + tir.reads([]) + tir.writes(A0[j]) + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +def test_complete_match_buffer(): + tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() test_complete_with_root() - test_complete_opaque_block_error() test_complete_part_region() test_complete_buffer_indices() + test_complete_match_buffer() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 71e0d1ba52..7aeceeccfa 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -202,7 +202,7 @@ 
def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: with tir.block([16, 16]) as [vi, vj]: - A = tir.match_buffer_region(vi) # error + A = tir.match_buffer(vi) # error tir.evaluate(1.0) @@ -431,4 +431,4 @@ def render(e): test_error_index_with_stop_slice() test_mismatch_args() test_tvm_exception_catch() - test_match_buffer_shape_mismatch() \ No newline at end of file + test_match_buffer_shape_mismatch()
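
Two short editor sketches follow; they are illustrative only and are not part of the patch above.

First, the contrast that test_opaque_access exercises between the two access-region analyses, restated directly against the opaque_access_func fixture from test_tir_analysis_get_block_access_region.py (a sketch; it assumes that test module's imports):

    block = opaque_access_func.body.block.body.body.block
    buffer_var_map = {buf.data: buf for buf in opaque_access_func.body.block.alloc_buffers}
    # get_block_access_region keeps opaque accesses in a separate third list ...
    reads, writes, opaque = tir.analysis.get_block_access_region(block, buffer_var_map)
    # ... while get_block_read_write_region folds each opaque access into both the
    # read and the write regions, which is why the test expects the two results
    # to differ structurally for this block.
    rw_reads, rw_writes = tir.analysis.get_block_read_write_region(block, buffer_var_map)

Second, the rolling-buffer arithmetic encoded by the InjectSoftwarePipeline expectations: with num_stages = 3 the shared buffer B gains a leading dimension of 3, the producer writes slot floormod(i + 2, 3) while the consumer reads slot floormod(i, 3), and the 100 iterations split into a 2-iteration prologue, a 98-iteration steady state, and a 2-iteration epilogue. A standalone plain-Python check of that invariant (names chosen here for illustration):

    num_stages = 3  # matches the {"pipeline_scope": 3} annotation
    n = 100         # trip count of the pipelined loop
    writes, reads = [], []

    # Prologue: prefetch iterations 0 .. num_stages - 2 into slots 0 .. num_stages - 2.
    for i in range(num_stages - 1):
        writes.append((i, i % num_stages))

    # Steady state: producer fills slot (i + 2) % 3 while consumer drains slot i % 3.
    for i in range(n - (num_stages - 1)):
        w_slot = (i + num_stages - 1) % num_stages
        r_slot = i % num_stages
        assert w_slot != r_slot  # the write never clobbers the slot being read
        writes.append((i + num_stages - 1, w_slot))
        reads.append((i, r_slot))

    # Epilogue: drain the last num_stages - 1 iterations.
    for i in range(num_stages - 1):
        it = i + n - (num_stages - 1)
        reads.append((it, it % num_stages))

    # Every iteration is produced exactly once and consumed exactly once, in order.
    assert [w[0] for w in writes] == list(range(n))
    assert [r[0] for r in reads] == list(range(n))

Because the producer always runs num_stages - 1 iterations ahead, the floormod(i + 2, 3) / floormod(i, 3) pairing in transformed_non_native_software_pipeline (and the producer_acquire / consumer_release pairing in the native variant) is enough to keep any buffer slot from being read and overwritten in the same iteration.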