From cb131f5ee3ff23e49cc582a4241b689e77b27bef Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 23 Oct 2021 13:30:33 +0800 Subject: [PATCH] [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` --- include/tvm/tir/expr.h | 53 +++++++++++++++++++++++++++++++++++++++ include/tvm/tir/sparse.h | 25 ++++++++++++------- include/tvm/tir/stmt.h | 54 ++++++++++++++++++++++++++++++++++++++++ python/tvm/tir/sparse.py | 16 +++++++----- src/tir/ir/expr.cc | 30 ++++++++++++++++++++++ src/tir/ir/sparse.cc | 17 +++++-------- src/tir/ir/stmt.cc | 33 ++++++++++++++++++++++++ 7 files changed, 202 insertions(+), 26 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index f6741112f269..93fb56a4e7c0 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; +/*! + * \brief Load value from the high dimension sparse buffer. + * + * \code + * + * value = buffer[i, j]; + * + * \endcode + * \sa SparseBufferStore + */ +class SparseBufferLoadNode : public PrimExprNode { + public: + /*! \brief The buffer variable. */ + SparseBuffer buffer; + /*! \brief The indices location to be loaded. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &(this->dtype)); + v->Visit("buffer", &buffer); + v->Visit("indices", &indices); + v->Visit("span", &span); + } + + bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(buffer, other->buffer) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(buffer); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "tir.SparseBufferLoad"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode); +}; + +/*! + * \brief Managed reference to SparseBufferLoadNode. + * \sa SparseBufferLoadNode + */ +class SparseBufferLoad : public PrimExpr { + public: + TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array indices, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode); +}; + /*! * \brief Load value from the result produced by the producer. * diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index afc6f7723a18..e184dc050856 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -32,7 +32,6 @@ namespace tvm { namespace tir { -namespace sparse { /*! * \brief Base type for axis in sparse formats. @@ -308,28 +307,36 @@ class SparseBufferNode : public Object { AxisTree tree; /* Axes */ Array axes; - /* Number of dimensions */ - int ndim; /* Buffer corresponding to flattened value */ Buffer data; + /* Buffer Name */ + String name; + /* Data type */ + runtime::DataType dtype; + + inline int ndim() const { + return static_cast(axes.size()); + } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &tree); v->Visit("length", &axes); - v->Visit("indptr", &ndim); v->Visit("num_cols", &data); + v->Visit("name", &name); + v->Visit("dtype", &dtype); } bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { - return equal(tree, other->tree) && equal(axes, other->axes) && equal(ndim, other->ndim) && - equal(data, other->data); + return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) && + equal(name, other->name) && equal(dtype, other->dtype); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(tree); hash_reduce(axes); - hash_reduce(ndim); hash_reduce(data); + hash_reduce(name); + hash_reduce(dtype); } static constexpr const char* _type_key = "tir.sparse.SparseBuffer"; @@ -342,12 +349,12 @@ class SparseBufferNode : public Object { */ class SparseBuffer : public ObjectRef { public: - TVM_DLL explicit SparseBuffer(AxisTree tree, Array axes, int ndim, Buffer data); + TVM_DLL explicit SparseBuffer(AxisTree tree, Array axes, Buffer data, String name, + DataType dtype); TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode); }; -} // namespace sparse } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4f5772822d9e..c776dbb28ef5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -327,6 +327,60 @@ class BufferStore : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); }; +/*! + * \brief Store value to the high dimension sparse buffer. + * + * \code + * + * buffer[i, j] = value; + * + * \endcode + * \sa SparseBufferLoad + */ +class SparseBufferStoreNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + SparseBuffer buffer; + /*! \brief The value to be stored. */ + PrimExpr value; + /*! \brief The indices location to be stored. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("value", &value); + v->Visit("indices", &indices); + v->Visit("span", &span); + } + + bool SEqualReduce(const SparseBufferStoreNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(value, other->value) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(value); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "tir.SparseBufferStore"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferStoreNode, StmtNode); +}; + +/*! + * \brief Managed reference to SparseBufferStoreNode. + * \sa SparseBufferStoreNode + */ +class SparseBufferStore : public Stmt { + public: + TVM_DLL explicit SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array indices, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferStore, Stmt, SparseBufferStoreNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferStoreNode); +}; + /*! * \brief Annotate the region where the buffer need to * be read and write in the body. diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index ed5d695da9c5..4ec289aa70ed 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -177,19 +177,23 @@ class SparseBuffer: axes : List[Axis] The axes of the sparse buffer - ndim : int - The number of dimensions of the sparse buffer - data : Buffer The data of the sparse buffer + + name : str + The name of the sparse buffer + + dtype : Optional[str] + The data type of the sparse buffer """ tree: AxisTree axes: List[Axis] - ndim: int data: Buffer + name: str - def __init__(self, tree, axes, ndim, data): + def __init__(self, tree, axes, data, name, dtype=None): + dtype = "float32" if dtype is None else dtype self.__init_handle_by_constructor__( - _ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore + _ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore ) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1d7c959d990d..c3cfc4aff326 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1084,6 +1084,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "]"; }); +// SparseBufferLoad +SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array indices, Span span) { + ObjectPtr node = make_object(); + node->dtype = buffer->dtype; + node->buffer = std::move(buffer); + node->indices = std::move(indices); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SparseBufferLoad") + .set_body_typed([](SparseBuffer buffer, Array indices, Span span) { + return SparseBufferLoad(buffer, indices, span); + }); + +TVM_REGISTER_NODE_TYPE(SparseBufferLoadNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); + // ProducerLoad ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span span) { ObjectPtr node = make_object(); diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 108295a0b13b..f8519865666c 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -28,9 +28,6 @@ namespace tvm { namespace tir { -namespace sparse { - - // DenseFixedAxis DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { ObjectPtr node = make_object(); @@ -148,25 +145,23 @@ TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") }); // SparseBuffer -SparseBuffer::SparseBuffer(AxisTree tree, Array axes, int ndim, - Buffer data) { +SparseBuffer::SparseBuffer(AxisTree tree, Array axes, Buffer data, String name, + DataType dtype) { ObjectPtr node = make_object(); node->tree = std::move(tree); node->axes = std::move(axes); - node->ndim = ndim; node->data = std::move(data); + node->name = std::move(name); + node->dtype = dtype; data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(SparseBufferNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") - .set_body_typed([](AxisTree root, Array axes, int ndim, Buffer data) { - // Todo(@ruihang): to be revised later - return SparseBuffer(root, axes, ndim, data); + .set_body_typed([](AxisTree tree, Array axes, Buffer data, String name, DataType dtype) { + return SparseBuffer(tree, axes, data, name, dtype); }); -} // namespace sparse - } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 0d42c20c2822..7bd135a72aa9 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -618,6 +618,39 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); +// SparseBufferStore +SparseBufferStore::SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array indices, + Span span) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->value = std::move(value); + node->indices = std::move(indices); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SparseBufferStore") + .set_body_typed([](SparseBuffer buffer, PrimExpr value, Array indices, Span span) { + return SparseBufferStore(buffer, value, indices, span); + }); + +TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; + }); + // BufferRealize BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span) {