From d168566bd10b7496bfcd53a06e57245c21b6891d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 10 Nov 2021 13:23:16 +0800 Subject: [PATCH] [SparseTIR] Parser, Printer, Roundtrip (#14) * SparseBlock scope handler (part 1) * SparseBlock scope handler (part 2) * SparseBlock scope handler (part 3) * SparseBlock scope handler (fix 1) * Add SparseBufferLoad/Store on Python side * Parser for SparseBufferLoad/Store * Add SparseBlock to Python __init__ * StmtFunctor for SparseBlock * Ensure at least one dimension for SparseBuffer * Make `axis` field of SpIterVar mandatory * SparseBlock scope handler (fix 2) * Update Axis syntax by removing `name` parameter * Move to intrin.py * Add filed `from_sparse` to DenseFixedAxis * SparseTIR script printer * Roundtrip test * `update_symbol` bug fix * Fix attr visit in SparseBuffer * Define then compare in SparseBlock * Fix printer bug for SparseBuffer * Enable graph match for Axis and SparseBuffer * Complete HashReduce and EqualReduce for AxisTree and SparseBuffer * Fix typo * Rename test * Bug fix 1 * Bug fix 2 * Add more tests --- include/tvm/tir/expr.h | 2 +- include/tvm/tir/sparse.h | 88 +++++--- include/tvm/tir/stmt.h | 12 +- include/tvm/tir/stmt_functor.h | 4 + python/tvm/script/context_maintainer.py | 2 +- python/tvm/script/parser.py | 10 + python/tvm/script/tir/intrin.py | 43 +++- python/tvm/script/tir/scope_handler.py | 73 ++++++ python/tvm/script/tir/special_stmt.py | 102 ++++----- python/tvm/tir/__init__.py | 7 +- python/tvm/tir/expr.py | 22 ++ python/tvm/tir/sparse.py | 17 +- python/tvm/tir/stmt.py | 25 +++ src/printer/tvmscript_printer.cc | 195 ++++++++++++++++ src/tir/ir/expr.cc | 1 + src/tir/ir/sparse.cc | 75 ++++--- src/tir/ir/stmt_functor.cc | 24 ++ src/tir/transforms/lower_sparse_tir.cc | 8 +- .../test_tir_sparse_script_roundtrip.py | 211 ++++++++++++++++++ .../unittest/test_tir_sparse_scripts.py | 96 -------- 20 files changed, 783 insertions(+), 234 deletions(-) create mode 100644 tests/python/unittest/test_tir_sparse_script_roundtrip.py delete mode 100644 tests/python/unittest/test_tir_sparse_scripts.py diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 93fb56a4e7c0..b17db12d714d 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -656,7 +656,7 @@ class BufferLoad : public PrimExpr { */ class SparseBufferLoadNode : public PrimExprNode { public: - /*! \brief The buffer variable. */ + /*! \brief The buffer to be loaded. */ SparseBuffer buffer; /*! \brief The indices location to be loaded. */ Array indices; diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 6be1062c3c86..a0dd5db19107 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -99,23 +99,48 @@ class DenseAxis : public Axis { TVM_DEFINE_OBJECT_REF_METHODS(DenseAxis, Axis, DenseAxisNode); }; +/*! + * \brief Sparse axis whose column indices is not consecutive. + */ +class SparseAxisNode : public AxisNode { + public: + static constexpr const char* _type_key = "tir.sparse.SparseAxis"; + TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode); +}; + +/*! + * \brief Managed reference to SparseAxisNode. + * \sa SparseAxisNode + */ +class SparseAxis : public Axis { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode); +}; + /*! * \brief Dense axis with fixed length per row. */ class DenseFixedAxisNode : public DenseAxisNode { public: + Optional from_sparse; + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("length", &length); + v->Visit("from_sparse", &from_sparse); } - bool SEqualReduce(const DenseAxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length); + bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(name, other->name) && equal(length, other->length) && + equal(from_sparse, other->from_sparse); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); + hash_reduce(from_sparse); } static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; @@ -128,7 +153,8 @@ class DenseFixedAxisNode : public DenseAxisNode { */ class DenseFixedAxis : public DenseAxis { public: - TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length); + TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length, + Optional from_sparse = NullOpt); TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode); }; @@ -144,10 +170,12 @@ class DenseVariableAxisNode : public DenseAxisNode { } bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(indptr); @@ -168,24 +196,6 @@ class DenseVariableAxis : public DenseAxis { TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); }; -/*! - * \brief Sparse axis whose column indices is not consecutive. - */ -class SparseAxisNode : public AxisNode { - public: - static constexpr const char* _type_key = "tir.sparse.SparseAxis"; - TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode); -}; - -/*! - * \brief Managed reference to SparseAxisNode. - * \sa SparseAxisNode - */ -class SparseAxis : public Axis { - public: - TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode); -}; - /*! * \brief Sparse axis with fixed number of non-zero columns per row. */ @@ -203,11 +213,13 @@ class SparseFixedAxisNode : public SparseAxisNode { } bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(indices, other->indices) && equal(num_cols, other->num_cols); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(indices); @@ -245,11 +257,13 @@ class SparseVariableAxisNode : public SparseAxisNode { } bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr) && equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(indptr); @@ -277,13 +291,27 @@ class SparseVariableAxis : public SparseAxis { class AxisTreeNode : public Object { public: // unordered map that stores the parent relationship between axes. - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> parent; + Map> parent; // unordered map that stores the children relationship between axes. - std::unordered_map, Array, ObjectPtrHash, ObjectPtrEqual> children; + Map, Array> children; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("parent", &parent); + v->Visit("children", &children); + } + + bool SEqualReduce(const AxisTreeNode* other, SEqualReducer equal) const { + return equal(parent, other->parent) && equal(children, other->children); + } - void VisitAttrs(AttrVisitor* v) {} + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(parent); + hash_reduce(children); + } static constexpr const char* _type_key = "tir.sparse.AxisTree"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object); }; @@ -313,22 +341,26 @@ class SparseBufferNode : public Object { inline int ndim() const { return static_cast(axes.size()); } void VisitAttrs(AttrVisitor* v) { - v->Visit("length", &axes); - v->Visit("num_cols", &data); + v->Visit("axes", &axes); + v->Visit("data", &data); v->Visit("name", &name); } bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); hash_reduce(axes); hash_reduce(data); hash_reduce(name); } static constexpr const char* _type_key = "tir.sparse.SparseBuffer"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object); }; @@ -359,7 +391,7 @@ class SpIterVarNode : public Object { PrimExpr max_extent; SpIterKind kind; bool is_reduction; - Optional axis; + Axis axis; void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); @@ -392,7 +424,7 @@ class SpIterVarNode : public Object { class SpIterVar : public ObjectRef { public: TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, - Optional axis = NullOpt); + Axis axis); /*! * \return the corresponding var in the IterVar. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index b43904f463da..fc25f6885ac2 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -335,11 +335,11 @@ class BufferStore : public Stmt { * buffer[i, j] = value; * * \endcode - * \sa SparseBufferLoad + * \sa SparseBufferStore */ class SparseBufferStoreNode : public StmtNode { public: - /*! \brief The buffer variable. */ + /*! \brief The sparse buffer to be accessed. */ SparseBuffer buffer; /*! \brief The value to be stored. */ PrimExpr value; @@ -1303,17 +1303,17 @@ class SparseBlockNode : public StmtNode { } bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const { - return equal(sp_iter_vars, other->sp_iter_vars) && - equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) && - equal(body, other->body) && equal(init, other->init); + return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) && + equal(body, other->body) && equal(init, other->init) && + equal(sp_struct2param_map, other->sp_struct2param_map); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(sp_iter_vars); - hash_reduce(sp_struct2param_map); hash_reduce(name); hash_reduce(body); hash_reduce(init); + hash_reduce(sp_struct2param_map); } static constexpr const char* _type_key = "tir.SparseBlock"; diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 7185829f2b70..e8adb75a496a 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -99,6 +99,7 @@ class StmtFunctor { virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SparseBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -126,6 +127,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); IR_STMT_FUNCTOR_DISPATCH(BlockNode); IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode); + IR_STMT_FUNCTOR_DISPATCH(SparseBlockNode); return vtable; } }; @@ -169,6 +171,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const BlockNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; + void VisitStmt_(const SparseBlockNode* op) override; }; /*! @@ -270,6 +273,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const EvaluateNode* op) override; Stmt VisitStmt_(const BlockNode* op) override; Stmt VisitStmt_(const BlockRealizeNode* op) override; + Stmt VisitStmt_(const SparseBlockNode* op) override; /*! * \brief Alternative advance method for SeqStmtNode. * diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 724f9a27078b..b76d73261091 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -219,7 +219,7 @@ def update_symbol( self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node ): """Append a symbol into current scope""" - if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)): + if isinstance(symbol, (Buffer, SparseBuffer, Axis)): if name in self.symbols[0]: self.report_error("Duplicate Buffer name: " + symbol.name, node.span) self.symbols[0][name] = symbol diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 02582e29e323..06c3bb997a29 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -582,6 +582,14 @@ def transform_SubscriptAssign(self, node): indexes, span=tvm_span_from_synr(node.span), ) + elif isinstance(symbol, tvm.tir.sparse.SparseBuffer): + # SparseBufferStore + return tvm.tir.SparseBufferStore( + symbol, + tvm.runtime.convert(rhs, span=rhs_span), + indexes, + span=tvm_span_from_synr(node.span), + ) else: if len(indexes) != 1: self.report_error( @@ -876,6 +884,8 @@ def transform_Subscript(self, node): return BufferSlice( symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) ) + elif isinstance(symbol, tvm.tir.sparse.SparseBuffer): + return tvm.tir.SparseBufferLoad(symbol, indexes, span=tvm_span_from_synr(node.span)) elif isinstance(symbol, tvm.container.Array): if len(indexes) > 1: self.report_error( diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index d577d8fc3fb5..157cc7a722b9 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -17,9 +17,18 @@ """TVM Script Parser Intrinsic Classes""" # pylint: disable=redefined-builtin, relative-beyond-top-level import builtins -from typing import List, Any +from typing import List, Optional, Any import tvm.tir +from tvm.ir import Span +from tvm.tir.sparse import ( + Axis, + DenseFixedAxis, + DenseVariableAxis, + SpIterVar, + SparseFixedAxis, + SparseVariableAxis, +) from ..registry import register from ..utils import get_param_list, tvm_span_from_synr @@ -244,3 +253,35 @@ def comm_reducer(lambda_io, identities, span): lambda_output = (lambda_output,) return tvm.tir.CommReducer(x, y, lambda_output, identities, span) + + +@register +def to_dense(axis: Axis, span: Optional[Span] = None): + if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): + return DenseFixedAxis(axis.name + "_dense", axis.length, axis) + else: + return axis + + +@register +def cord(axis: Axis, span: Optional[Span] = None): + # The field `var` and `is_reduction` will be updated in SparseBlock scope handler + var_temp = tvm.te.var() + if isinstance(axis, DenseVariableAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) + else: + return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False, axis) + + +@register +def pos(axis: Axis, span: Optional[Span] = None): + # The field `var` and `is_reduction` will be updated in SparseBlock scope handler + var_temp = tvm.te.var() + if isinstance(axis, DenseFixedAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False, axis) + elif isinstance(axis, DenseVariableAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) + elif isinstance(axis, SparseFixedAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis) + else: + return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 4750ad7626e2..44610c983306 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -19,10 +19,12 @@ from typing import Tuple, Any, Callable, Optional, List, Union, Mapping import synr +from synr.ast import With import tvm.tir from tvm.runtime import Object from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind +from tvm.tir.sparse import SpIterVar from .node import BufferSlice from .utils import buffer_slice_to_region @@ -321,6 +323,77 @@ def enter_scope( ) +@register +class SparseBlock(WithScopeHandler): + """With scope handler of SparseBlock""" + + def __init__(self): + def iter(iters: List, iter_types: str, name: str = "", span: Optional[Span] = None): + assert ( + self.node and self.context and self.body + ), "call 'exit_scope' before 'enter_scope'" + block_info = self.context.block_info_stack[-1] + + if len(iters) != len(self.sp_iters): + self.context.report_error( + "Inconsistent number of sparse iteration variable names, " + + f"there are {len(iters)} iterators but {len(self.sp_iters)} names. " + + "The number of sparse iteration variable names should match the number of iterators.", + self.node.span, + ) + if len(iters) != len(iter_types): + self.context.report_error( + "Inconsistent number of sparse iteration variable types, " + + f"there are {len(iters)} iterators but {len(iter_types)} types. " + + "The number of sparse iteration variable types should match the number of iterators.", + self.node.span, + ) + + sp_iters: List[SpIterVar] = [] + for i, sp_iter in enumerate(iters): + assert isinstance(sp_iter, SpIterVar) + is_reduction = True if iter_types[i] == "R" else False + sp_iters.append( + SpIterVar( + self.sp_iters[i], + sp_iter.max_extent, + sp_iter.kind, + is_reduction, + sp_iter.axis, + ) + ) + + block = tvm.tir.SparseBlock( + sp_iters, + self.context.sp_struct2param_map, + name, + self.body, + block_info.init, + span, + ) + return block + + super().__init__(func=iter, concise_scope=False, def_symbol=True) + self.sp_iters = None + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define sparse iteration variables + assert isinstance( + node, synr.ast.With + ), f"SparseBlockScopeHandler expected to work on synr.ast.With but got {type(node)}" + + vars = WithScopeHandler.get_optional_vars(node, context) + self.sp_iters = [tvm.te.var(var.id.name, "int32") for var in vars] + for sp_iter in self.sp_iters: + context.update_symbol(sp_iter.name, sp_iter, node) + + @register class InitBlock(WithScopeHandler): """With scope handler T.init()""" diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 0d821cd55443..68baf9f8cb4f 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -32,7 +32,6 @@ Axis, DenseFixedAxis, DenseVariableAxis, - SpIterVar, SparseFixedAxis, SparseVariableAxis, ) @@ -854,11 +853,16 @@ class DenseFixed(SpecialStmt): """Special Stmt for creating dense fixed axis.""" def __init__(self): - def dense_fixed(name: str, length: PrimExpr, span: Optional[Span] = None): - var_name = self.node.lhs[0].id.name - axis = DenseFixedAxis(name, length) + def dense_fixed(length: PrimExpr, span: Optional[Span] = None): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`dense_fixed` expected assign to only one var, but got {names}", span + ) + + axis = DenseFixedAxis(names[0], length) self.context.sp_struct2param_map[axis] = [] - self.context.update_symbol(var_name, axis, self.node) + self.context.update_symbol(names[0], axis, self.node) super().__init__(dense_fixed, def_symbol=True) @@ -869,21 +873,25 @@ class DenseVariable(SpecialStmt): def __init__(self): def dense_variable( - name: str, shape: Tuple[PrimExpr, PrimExpr], indptr_var: tvm.tir.Var, idtype: str = "int32", span: Optional[Span] = None, ): - indptr_len, length = shape - var_name = self.node.lhs[0].id.name + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`dense_variable` expected assign to only one var, but got {names}", span + ) + + length, indptr_len = shape indptr_buf = tvm.tir.decl_buffer( - (indptr_len,), dtype=idtype, name=name + "_indptr", span=span + (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span ) - axis = DenseVariableAxis(name, length, indptr_buf) - self.context.sp_struct2param_map[axis] = indptr_var - self.context.update_symbol(var_name, axis, self.node) - self.context.update_symbol(name + "_indptr", indptr_buf, self.node) + axis = DenseVariableAxis(names[0], length, indptr_buf) + self.context.sp_struct2param_map[axis] = [indptr_var] + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) super().__init__(dense_variable, def_symbol=True) @@ -894,21 +902,25 @@ class SparseFixed(SpecialStmt): def __init__(self): def sparse_fixed( - name: str, shape: Tuple[PrimExpr, PrimExpr, PrimExpr], indices_var: tvm.tir.Var, idtype: str = "int32", span: Optional[Span] = None, ): - var_name = self.node.lhs[0].id.name + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`sparse_fixed` expected assign to only one var, but got {names}", span + ) + length, nnz, nnz_cols = shape indices_buf = tvm.tir.decl_buffer( - (nnz,), dtype=idtype, name=name + "_indices", span=span + (nnz,), dtype=idtype, name=names[0] + "_indices", span=span ) - axis = SparseFixedAxis(name, length, indices_buf, nnz_cols) + axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols) self.context.sp_struct2param_map[axis] = [indices_var] - self.context.update_symbol(var_name, axis, self.node) - self.context.update_symbol(name + "_indices", indices_buf, self.node) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indices", indices_buf, self.node) super().__init__(sparse_fixed, def_symbol=True) @@ -919,26 +931,30 @@ class SparseVariable(SpecialStmt): def __init__(self): def sparse_variable( - name: str, shape: Tuple[PrimExpr, PrimExpr, PrimExpr], data: Tuple[tvm.tir.Var, tvm.tir.Var], idtype: str = "int32", span: Optional[Span] = None, ): - var_name = self.node.lhs[0].id.name + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`sparse_variable` expected assign to only one var, but got {names}", span + ) + length, indptr_len, nnz = shape indptr_var, indices_var = data indptr_buf = tvm.tir.decl_buffer( - (indptr_len,), dtype=idtype, name=name + "_indptr", span=span + (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span ) indices_buf = tvm.tir.decl_buffer( - (nnz,), dtype=idtype, name=name + "_indices", span=span + (nnz,), dtype=idtype, name=names[0] + "_indices", span=span ) - axis = SparseVariableAxis(name, length, indptr_buf, indices_buf) + axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf) self.context.sp_struct2param_map[axis] = [indptr_var, indices_var] - self.context.update_symbol(var_name, axis, self.node) - self.context.update_symbol(name + "_indptr", indptr_buf, self.node) - self.context.update_symbol(name + "_indices", indices_buf, self.node) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) + self.context.update_symbol(names[0] + "_indices", indices_buf, self.node) super().__init__(sparse_variable, def_symbol=True) @@ -980,35 +996,3 @@ def match_sparse_buffer( ) super().__init__(match_sparse_buffer, def_symbol=True) - - -@register -def to_dense(axis: Axis, span: Optional[Span] = None): - if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): - return DenseFixedAxis(axis.name + "_dense", axis.length) - else: - return axis - - -@register -def cord(axis: Axis, span: Optional[Span] = None): - # The field `var` and `is_reduction` will be updated in SparseBlock scope handler - var_temp = tvm.te.var() - if isinstance(axis, DenseVariableAxis): - return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) - else: - return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False) - - -@register -def pos(axis: Axis, span: Optional[Span] = None): - # The field `var` and `is_reduction` will be updated in SparseBlock scope handler - var_temp = tvm.te.var() - if isinstance(axis, DenseFixedAxis): - return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False) - elif isinstance(axis, DenseVariableAxis): - return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) - elif isinstance(axis, SparseFixedAxis): - return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis) - else: - return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 252f147ab1e0..e242b770e9d2 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -24,14 +24,15 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not -from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle +from .expr import Select, BufferLoad, SparseBufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While -from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt +from .stmt import BufferStore, SparseBufferStore, BufferRealize, Store, ProducerStore +from .stmt import Allocate, AttrStmt from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list -from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize +from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize, SparseBlock from .function import PrimFunc diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 27cf5351a077..4ee6d9505ee9 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1058,6 +1058,28 @@ def __init__(self, buffer, indices, span=None): ) +@tvm._ffi.register_object("tir.SparseBufferLoad") +class SparseBufferLoad(PrimExprWithOp): + """SparseBufferLoad node. + + Parameters + ---------- + buffer : SparseBuffer + The buffer to be loaded. + + indices : List[PrimExpr] + The indices location to be loaded. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, buffer, indices, span=None): + self.__init_handle_by_constructor__( + _ffi_api.SparseBufferLoad, buffer, indices, span # type: ignore + ) + + @tvm._ffi.register_object("tir.ProducerLoad") class ProducerLoad(PrimExprWithOp): """Producer load node. diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 574ccc2352a6..07fd48208d1f 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -61,13 +61,17 @@ class DenseFixedAxis(DenseAxis): length : PrimExpr The length of the axis + + from_sparse : Optional[SparseAxis] + The SparseAxis that this axis is created from """ name: str length: PrimExpr + from_sparse: Optional[SparseAxis] - def __init__(self, name, length): - self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length) # type: ignore + def __init__(self, name, length, from_sparse=None): + self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length, from_sparse) # type: ignore @tvm._ffi.register_object("tir.sparse.DenseVariableAxis") @@ -218,23 +222,22 @@ class SpIterVar(Object): is_reduction : bool Whether the SpIterVar is a reduction iterator - axis : Optional[Axis] - The axis over which the SpIterVar iterates. Required to be defined - when `kind` is not `DenseFixed` + axis : Axis + The axis over which the SpIterVar iterates """ var: Var max_extent: PrimExpr kind: int is_reduction: bool - axis: Optional[Axis] + axis: Axis DenseFixed = 0 DenseVariable = 1 SparseFixed = 2 SparseVariable = 3 - def __init__(self, var, max_extent, kind, is_reduction, axis=None): + def __init__(self, var, max_extent, kind, is_reduction, axis): self.__init_handle_by_constructor__( _ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 68b5eca8ecda..7a3677cf8f3a 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -245,6 +245,31 @@ def __init__(self, buffer, value, indices, span=None): ) +@tvm._ffi.register_object("tir.SparseBufferStore") +class SparseBufferStore(Stmt): + """SparseBufferStore node. + + Parameters + ---------- + buffer : SparseBuffer + The sparse buffer to be accessed. + + value : PrimExpr + The value to be stored. + + indices : List[PrimExpr] + The indices location to be stored. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, buffer, value, indices, span=None): + self.__init_handle_by_constructor__( + _ffi_api.SparseBufferStore, buffer, value, indices, span # type: ignore + ) + + @tvm._ffi.register_object("tir.BufferRealize") class BufferRealize(Stmt): """Buffer realize node. diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index f1c47e78bc45..c83399db90fc 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -119,6 +119,10 @@ class TVMScriptPrinter : public StmtFunctor, std::unordered_map memo_buf_; /*! \brief Map from Buffer to Declaration Doc */ std::unordered_map memo_buf_decl_; + /*! \brief Map from SparseBuffer to Doc */ + std::unordered_map memo_sp_buf_; + /*! \brief Map from Axis in SparseTIR to Doc */ + std::unordered_map memo_sp_axis_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief number of children of current node's parent */ @@ -164,6 +168,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override; + Doc VisitExpr_(const SparseBufferLoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override; @@ -178,6 +183,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const AssertStmtNode* op) override; Doc VisitStmt_(const StoreNode* op) override; Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const SparseBufferStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; @@ -187,6 +193,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; + Doc VisitStmt_(const SparseBlockNode* op) override; Doc VisitStmtDefault_(const Object* op) override; Doc VisitType_(const PrimTypeNode* node) override; @@ -200,6 +207,8 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintRange(const RangeNode* op); Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintSparseBuffer(const SparseBufferNode* op); + Doc PrintSpAxis(const AxisNode* op); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); Doc PrintBlockVarRemaps(); @@ -207,6 +216,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); virtual Doc PrintBlockName(const BlockNode* block_op); + Doc PrintSparseBlockName(const SparseBlockNode* op); + Doc PrintSparseStructDefinitions(const SparseBlockNode* sp_block); + Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); Doc PrintCommReducer(const CommReducerNode* op); @@ -216,6 +228,8 @@ class TVMScriptPrinter : public StmtFunctor, Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); + Doc AllocSparseBuf(const SparseBuffer& buffer); + Doc AllocAxis(const Axis& axis); void TryDeallocVar(const Var& var); bool ContainsOptionalInfo(const Stmt& stmt); @@ -423,6 +437,42 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +Doc TVMScriptPrinter::AllocSparseBuf(const SparseBuffer& buffer) { + const auto& it = memo_sp_buf_.find(buffer); + if (it != memo_sp_buf_.end()) { + return it->second; + } + std::string name = buffer->name; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "buf_" + name; + } + Doc val = GetUniqueName(name); + memo_sp_buf_[buffer] = val; + return val; +} + +Doc TVMScriptPrinter::AllocAxis(const Axis& axis) { + const auto& it = memo_sp_axis_.find(axis); + if (it != memo_sp_axis_.end()) { + return it->second; + } + Doc val; + const auto* df_axis = axis.as(); + + if (df_axis != nullptr && df_axis->from_sparse.defined()) { + val << tir_prefix_ << ".to_dense(" << Print(df_axis->from_sparse.value()) << ")"; + } else { + std::string name = axis->name; + if (name.length() == 0 || !std::isalnum(name[0])) { + name = "axis_" + name; + } + val = GetUniqueName(name); + } + + memo_sp_axis_[axis] = val; + return val; +} + /*! * \brief Check if any optional information exists in annotate_ for * a given Stmt. @@ -519,6 +569,10 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintArray(node.as()); } else if (node->IsInstance()) { return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintSparseBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintSpAxis(node.as()); } else if (node->IsInstance()) { return PrintString(node.as()); } else if (node->IsInstance()) { @@ -669,6 +723,13 @@ Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_p return doc; } +Doc TVMScriptPrinter::VisitExpr_(const SparseBufferLoadNode* op, ExprPrecedence* out_precedence) { + *out_precedence = ExprPrecedence::kIdentity; + Doc doc; + doc << Print(op->buffer) << Print(op->indices); + return doc; +} + Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; @@ -987,6 +1048,12 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const SparseBufferStoreNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + /*! Helper functions for block printing. */ Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; @@ -1152,6 +1219,119 @@ Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { return doc; } +Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) { + Doc doc; + doc << "with " << tir_prefix_ << ".iter(["; + + int n_iter = static_cast(op->sp_iter_vars.size()); + + std::string iter_types = ""; + std::vector sp_iter_docs; + std::vector sp_iter_name_docs; + iter_types.reserve(n_iter); + sp_iter_docs.reserve(n_iter); + sp_iter_name_docs.reserve(n_iter); + + for (int i = 0; i < n_iter; ++i) { + const SpIterVar& sp_iter = op->sp_iter_vars[i]; + Doc iter_doc; + if (sp_iter->kind == SpIterKind::kDenseFixed || sp_iter->kind == SpIterKind::kDenseVariable) { + iter_doc << tir_prefix_ << ".cord(" << sp_iter->axis->name << ")"; + } else { + iter_doc << tir_prefix_ << ".pos(" << sp_iter->axis->name << ")"; + } + var_not_in_headers_.insert(sp_iter->var.get()); + sp_iter_docs.push_back(iter_doc); + sp_iter_name_docs.push_back(Print(sp_iter->var)); + iter_types += sp_iter->is_reduction ? "R" : "S"; + } + + doc << PrintSep(sp_iter_docs, Doc::Text(", ")) << "], " << Doc::StrLiteral(iter_types) << ", " + << Doc::StrLiteral(op->name) << ") as [" << PrintSep(sp_iter_name_docs, Doc::Text(", ")) + << "]:"; + + return doc; +} + +Doc TVMScriptPrinter::VisitStmt_(const SparseBlockNode* op) { + Doc doc = PrintOptionalInfo(GetRef(op)); + doc << PrintSparseBlockName(op); + + Doc body; + if (op->init.defined()) { + Doc init; + init << "with " << tir_prefix_ << ".init():"; + init << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value())); + body << init << Doc::NewLine(); + } + body << PrintBody(op->body); + doc << Doc::Indent(4, Doc::NewLine() << body); + + for (const SpIterVar& sp_iter : op->sp_iter_vars) { + TryDeallocVar(sp_iter->var); + } + return doc; +} + +Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_block) { + std::vector axis_docs; + std::vector sp_buf_docs; + + for (auto it : sp_block->sp_struct2param_map) { + Doc doc; + doc << Print(it.first) << " = " << tir_prefix_ << "."; + + if (const auto* sp_buffer = it.first.as()) { + ICHECK_EQ(it.second.size(), 1); + Doc axes_doc; + if (sp_buffer->axes.size() != 1) { + std::vector axes_docs; + axes_docs.reserve(sp_buffer->axes.size()); + for (const Axis& axis : sp_buffer->axes) { + axes_docs.push_back(Print(axis)); + } + axes_doc << PrintSep(axes_docs, Doc::Text(", ")); + } else { + axes_doc << Print(sp_buffer->axes[0]) << ","; + } + + doc << "match_sparse_buffer(" << Print(it.second[0]) << ", (" << axes_doc << "), " + << Print(sp_buffer->data->shape[0]) << ", " << PrintDType(sp_buffer->data->dtype) << ")"; + sp_buf_docs.push_back(doc); + continue; + } + + if (const auto* df_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 0); + doc << "dense_fixed(" << Print(df_axis->length) << ")"; + } else if (const auto* dv_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 1); + doc << "dense_variable((" << Print(dv_axis->length) << ", " + << Print(dv_axis->indptr->shape[0]) << "), " << Print(it.second[0]) << ", " + << PrintDType(dv_axis->indptr->dtype) << ")"; + } else if (const auto* sf_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 1); + doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0]) + << ", " << Print(sf_axis->num_cols) << "), " << Print(it.second[0]) << ", " + << PrintDType(sf_axis->indices->dtype) << ")"; + } else if (const auto* sv_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 2); + doc << "sparse_variable((" << Print(sv_axis->length) << ", " + << Print(sv_axis->indptr->shape[0]) << ", " << Print(sv_axis->indices->shape[0]) << "), (" + << Print(it.second[0]) << ", " << Print(it.second[1]) << "), " + << PrintDType(sv_axis->indptr->dtype) << ")"; + } else { + ICHECK(false) << "Cannot reach here"; + } + axis_docs.push_back(doc); + } + + Doc res; + res << PrintSep(axis_docs, Doc::NewLine()) << Doc::NewLine() + << PrintSep(sp_buf_docs, Doc::NewLine()) << Doc::NewLine(); + return res; +} + Doc TVMScriptPrinter::PrintBody(const Stmt& body) { int memo_num_child, memo_current_num; std::swap(memo_num_child, num_child_); @@ -1206,6 +1386,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { memo_var_.clear(); memo_buf_.clear(); memo_buf_decl_.clear(); + memo_sp_buf_.clear(); var_not_in_headers_.clear(); buf_not_in_headers_.clear(); // print signature @@ -1230,6 +1411,10 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second]; body << ")" << Doc::NewLine(); } + // print sparse data structure definitions + if (const auto* sp_block = op->body.as()) { + body << PrintSparseStructDefinitions(sp_block); + } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && @@ -1344,6 +1529,16 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } +Doc TVMScriptPrinter::PrintSparseBuffer(const SparseBufferNode* op) { + const SparseBuffer& buffer = GetRef(op); + return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocSparseBuf(buffer); +} + +Doc TVMScriptPrinter::PrintSpAxis(const AxisNode* op) { + const Axis& axis = GetRef(op); + return meta_.InMeta(axis) ? meta_.GetMetaNode(axis) : AllocAxis(axis); +} + Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; if (op->region.size() == 0) { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 6a0c26ed828d..3f0a94c7d141 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1087,6 +1087,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // SparseBufferLoad SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); + node->dtype = buffer->data->dtype; node->buffer = std::move(buffer); node->indices = std::move(indices); node->span = std::move(span); diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index c3e118611b22..6a59dd0a5e5b 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -43,18 +43,20 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis) }); // DenseFixedAxis -DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { +DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length, Optional from_sparse) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); + node->from_sparse = std::move(from_sparse); data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); -TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { - return DenseFixedAxis(name, length); -}); +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis") + .set_body_typed([](String name, PrimExpr length, Optional from_sparse) { + return DenseFixedAxis(name, length, from_sparse); + }); // DenseVariableAxis DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { @@ -114,19 +116,26 @@ AxisTree::AxisTree(Array axis_names, Array> axis_parent "axis_parent_names " "array."; ObjectPtr node = make_object(); + Map> parent; + Map, Array> children; for (size_t i = 0; i < axis_names.size(); i++) { // update parent map & children map String axis_name = axis_names[i]; Optional parent_name = axis_parent_names[i]; - node->parent[axis_name] = parent_name; - if (node->children.find(parent_name) != node->children.end()) { - node->children[parent_name].push_back(axis_name); + parent.Set(axis_name, parent_name); + + auto it = children.find(parent_name); + if (it != children.end()) { + Array value = (*it).second; + value.push_back(axis_name); + children.Set(parent_name, std::move(value)); } else { - Array children; - children.push_back(axis_name); - node->children[parent_name] = std::move(children); + Array value{axis_name}; + children.Set(parent_name, std::move(value)); } } + node->parent = std::move(parent); + node->children = std::move(children); data_ = std::move(node); } @@ -140,6 +149,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") // SparseBuffer SparseBuffer::SparseBuffer(Array axes, Buffer data, String name) { ObjectPtr node = make_object(); + CHECK_GT(static_cast(axes.size()), 0) + << "ValueError: A SparseBuffer should have at least one dimension"; node->axes = std::move(axes); node->data = std::move(data); node->name = std::move(name); @@ -153,26 +164,35 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") return SparseBuffer(axes, data, name); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sparse_buffer(" << op->name << ", ["; + for (int i = 0, n = static_cast(op->axes.size()); i < n; ++i) { + const Axis& axis = op->axes[i]; + p->stream << axis; + if (i < n - 1) { + p->stream << ", "; + } + } + p->stream << "], " << op->data << ")"; + }); + // SpIterVar -SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, - Optional axis) { +SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, Axis axis) { ObjectPtr node = make_object(); arith::Analyzer ana; - if (axis.defined()) { - CHECK(ana.CanProveEqual(axis.value()->length, max_extent)); - } - if (kind != SpIterKind::kDenseFixed) { - CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must " - "specify the axis over which the SpIterVar iterates"; - const char* err_str = "ValueError: The given kind doesn't match the type of the given axis"; - if (kind == SpIterKind::kDenseVariable) { - CHECK(axis.value()->IsInstance()) << err_str; - } else if (kind == SpIterKind::kSparseFixed) { - CHECK(axis.value()->IsInstance()) << err_str; - } else if (kind == SpIterKind::kSparseVariable) { - CHECK(axis.value()->IsInstance()) << err_str; - } + CHECK(ana.CanProveEqual(axis->length, max_extent)); + const char* err_str = "ValueError: The given kind doesn't match the type of the given axis"; + if (kind == SpIterKind::kDenseFixed) { + CHECK(!axis->IsInstance()) << err_str; + } else if (kind == SpIterKind::kDenseVariable) { + CHECK(axis->IsInstance()) << err_str; + } else if (kind == SpIterKind::kSparseFixed) { + CHECK(axis->IsInstance()) << err_str; + } else if (kind == SpIterKind::kSparseVariable) { + CHECK(axis->IsInstance()) << err_str; } node->var = Var(std::move(var)); @@ -186,8 +206,7 @@ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_redu TVM_REGISTER_NODE_TYPE(SpIterVarNode); TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") - .set_body_typed([](Var var, PrimExpr max_extent, int kind, bool is_reduction, - Optional axis) { + .set_body_typed([](Var var, PrimExpr max_extent, int kind, bool is_reduction, Axis axis) { return SpIterVar(var, max_extent, SpIterKind(kind), is_reduction, axis); }); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 2a0c43904c70..723561aa5325 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -153,6 +153,13 @@ void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { this->VisitStmt(op->block); } +void StmtVisitor::VisitStmt_(const SparseBlockNode* op) { + if (op->init.defined()) { + this->VisitStmt(op->init.value()); + } + this->VisitStmt(op->body); +} + class StmtMutator::Internal { public: /*! @@ -563,6 +570,23 @@ Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { } } +Stmt StmtMutator::VisitStmt_(const SparseBlockNode* op) { + Optional init = NullOpt; + if (op->init.defined()) { + init = VisitStmt(op->init.value()); + } + Stmt body = VisitStmt(op->body); + + if (init.same_as(op->init) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->init = std::move(init); + n->body = std::move(body); + return Stmt(n); + } +} + // Implementations of IRTransform, PostOrderVisit and Substitute class IRApplyVisit : public StmtExprVisitor { public: diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 34d0b3d05e3b..ee7337d60e7d 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -68,7 +68,7 @@ class AccessAndDependencyCollector : public StmtExprVisitor { const SpIterVar& sp_iter = kv_pair.second[k]; if (sp_iter->kind == SpIterKind::kDenseFixed || sp_iter->kind == SpIterKind::kDenseVariable || - !BufferContainsAxis(buffer, sp_iter->axis.value())) { + !BufferContainsAxis(buffer, sp_iter->axis)) { continue; } @@ -169,7 +169,7 @@ class IndexTransformer : public StmtExprMutator { } else if (kind == SpIterKind::kSparseFixed) { CHECK(!axis->IsInstance()); CHECK(sp_iter->axis.defined()); - const Axis& iterated_axis = sp_iter->axis.value(); + const Axis& iterated_axis = sp_iter->axis; if (const auto* df_axis = axis.as()) { CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); sp_index = GetDenseValue(sp_iter); @@ -192,7 +192,7 @@ class IndexTransformer : public StmtExprMutator { CHECK(kind == SpIterKind::kSparseVariable); CHECK(!axis->IsInstance()); CHECK(sp_iter->axis.defined()); - const Axis& iterated_axis = sp_iter->axis.value(); + const Axis& iterated_axis = sp_iter->axis; if (const auto* df_axis = axis.as()) { CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); sp_index = GetDenseValue(sp_iter); @@ -240,7 +240,7 @@ class IndexTransformer : public StmtExprMutator { PrimExpr GetDenseValue(const SpIterVarNode* sp_iter) { SpIterKind kind = sp_iter->kind; CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable); - Axis iterated_axis = sp_iter->axis.value(); + Axis iterated_axis = sp_iter->axis; std::pair dependent_pair = dependency_map_[GetRef(sp_iter)]; Array buffer_access_iters = buffer_access_map_[dependent_pair.first]; diff --git a/tests/python/unittest/test_tir_sparse_script_roundtrip.py b/tests/python/unittest/test_tir_sparse_script_roundtrip.py new file mode 100644 index 000000000000..17e2d8f04c9b --- /dev/null +++ b/tests/python/unittest/test_tir_sparse_script_roundtrip.py @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.tir as tir +import tvm.te as te +from tvm.script import tir as T + + +@T.prim_func +def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + k = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), K), m * k, "float32") + C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (I,), n, "float32") + with T.iter([T.cord(I), T.pos(J)], "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + nb = T.var("int32") + mb = T.var("int32") + nnzb = T.var("int32") + blk = T.var("int32") + feat_size = T.var("int32") + I = T.dense_fixed(nb) + J = T.sparse_variable((mb, nb + 1, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnzb * blk * blk, "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), mb * blk * feat_size, "float32") + C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + + with T.iter([T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)], "SRSSS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle) -> None: + nb = T.var("int32") + mb = T.var("int32") + feat_size = T.var("int32") + nnz = T.var("int32") + col = T.var("int32") + blk = T.var("int32") + I = T.dense_fixed(nb) + J = T.sparse_fixed((mb, nnz, col), indices, "int32") + F = T.dense_fixed(feat_size) + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnz * blk * blk, "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), mb * blk * feat_size, "float32") + C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + + with T.iter([T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)], "SRSSS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def batch_mm( + a: T.handle, + b: T.handle, + c: T.handle, + i_indptr: T.handle, + j_a_indptr: T.handle, + j_b_indptr: T.handle, + k_b_indptr: T.handle, + k_c_indptr: T.handle, +): + batch = T.var("int32") + n_max = T.var("int32") + m_max = T.var("int32") + k_max = T.var("int32") + nnz_ac1 = T.var("int32") + nnz_b1 = T.var("int32") + nnz_a2 = T.var("int32") + nnz_b2 = T.var("int32") + nnz_c2 = T.var("int32") + + Batch = T.dense_fixed(batch) + I = T.dense_variable((n_max, batch + 1), i_indptr, "int32") + J_a = T.dense_variable((m_max, nnz_ac1 + 1), j_a_indptr, "int32") + J_b = T.dense_variable((m_max, batch + 1), j_b_indptr, "int32") + K_b = T.dense_variable((k_max, nnz_b1 + 1), k_b_indptr, "int32") + K_c = T.dense_variable((k_max, nnz_ac1 + 1), k_c_indptr, "int32") + A = T.match_sparse_buffer(a, (Batch, I, J_a), nnz_a2, "float32") + B = T.match_sparse_buffer(b, (Batch, J_b, K_b), nnz_b2, "float32") + C = T.match_sparse_buffer(c, (Batch, I, K_c), nnz_c2, "float32") + + with T.iter([T.cord(Batch), T.cord(I), T.cord(K_b), T.cord(J_a)], "SSSR", "batch_mm") as [ + vb, + vi, + vk, + vj, + ]: + with T.init(): + C[vb, vi, vk] = 0.0 + C[vb, vi, vk] = C[vb, vi, vk] + A[vb, vi, vj] * B[vb, vj, vk] + + +@T.prim_func +def csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle): + m = T.var("int32") + n = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(m) + J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (I, J), nnz, "float32") + + with T.iter([T.cord(I), T.pos(J)], "SS", "csr_element_wise") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +def test_csrmm(): + func = csrmm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_csr_reduce(): + func = csr_reduce + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_bsrmm(): + func = bsrmm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_ellpack_mm(): + func = ellpack_mm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_batch_mm(): + func = batch_mm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_csr_element_wise(): + func = csr_element_wise + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +if __name__ == "__main__": + test_csrmm() + test_csr_reduce() + test_bsrmm() + test_ellpack_mm() + test_batch_mm() + test_csr_element_wise() diff --git a/tests/python/unittest/test_tir_sparse_scripts.py b/tests/python/unittest/test_tir_sparse_scripts.py deleted file mode 100644 index 4a80f21164a0..000000000000 --- a/tests/python/unittest/test_tir_sparse_scripts.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -import tvm.tir as tir -import tvm.te as te -from tvm.script import tir as T - - -@T.prim_func -def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: - n = T.var("int32") - m = T.var("int32") - k = T.var("int32") - nnz = T.var("int32") - I = T.dense_fixed("I", n, "int32") - J = T.sparse_variable("J", (m, nnz), (indptr, indices), "int32") - K = T.dense_fixed("K", k, "int32") - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (T.to_dense(J), K), "float32") - C = T.match_sparse_buffer(c, (I, K), "float32") - with T.iter((T.cord(I), T.cord(J), T.cord(K)), "SRS", "csrmm") as [vi, vj, vk]: - with T.init(): - C[vi, vk] = 0. - C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] - - -@T.prim_func -def csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle) -> None: - n = T.var("int32") - m = T.var("int32") - nnz = T.var("int32") - I = T.dense_fixed("I", n, "int32") - J = T.sparse_variable("J", (m, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (I,), "float32") - with T.iter((tir.cord(I), tir.pos(J)), "SR", "csr_reduce") as [vi, vj]: - with T.init(): - B[vi] = 0. - B[vi] = B[vi] + A[vi, vj] - - -@T.prim_func -def bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: - nb = T.var("int32") - mb = T.var("int32") - nnzb = T.var("int32") - blk = T.var("int32") - feat_size = T.var("int32") - I = T.dense_fixed("I", nb, "int32") - J = T.sparse_variable("J", (mb, nnzb), (indptr, indices), "int32") - BI = T.dense_fixed("BI", blk, "int32") - BJ = T.dense_fixed("BJ", blk, "int32") - F = T.dense_fixed("F", feat_size, "int32") - A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") - B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), "float32") - C = T.match_sparse_buffer(c, (I, BI, F), "float32") - - with T.iter((T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)), "SRSSS", "bsrmm") as [vi, vj, vbi, vbj, vf]: - with T.init(): - C[vi, vbi, vf] = 0. - C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] - - -def test_csrmm(): - pass - - -def test_csr_reduce(): - pass - - -def test_bsrmm(): - pass - - -if __name__ == "__main__": - test_csrmm() - test_csr_reduce() - test_bsrmm() - - - \ No newline at end of file