Skip to content

Commit

Permalink
[SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)
Browse files Browse the repository at this point in the history
* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`
  • Loading branch information
MasterJH5574 authored and yzh119 committed Nov 3, 2021
1 parent ef197bc commit 8e70d89
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 26 deletions.
53 changes: 53 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/sparse.h>
#include <tvm/tir/var.h>

#include <algorithm>
Expand Down Expand Up @@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
* \brief Load value from the high dimension sparse buffer.
*
* \code
*
* value = buffer[i, j];
*
* \endcode
* \sa SparseBufferStore
*/
class SparseBufferLoadNode : public PrimExprNode {
public:
/*! \brief The buffer variable. */
SparseBuffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.SparseBufferLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to SparseBufferLoadNode.
* \sa SparseBufferLoadNode
*/
class SparseBufferLoad : public PrimExpr {
public:
TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode);
};

/*!
* \brief Load value from the result produced by the producer.
*
Expand Down
25 changes: 16 additions & 9 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

namespace tvm {
namespace tir {
namespace sparse {

/*!
* \brief Base type for axis in sparse formats.
Expand Down Expand Up @@ -308,28 +307,36 @@ class SparseBufferNode : public Object {
AxisTree tree;
/* Axes */
Array<Axis> axes;
/* Number of dimensions */
int ndim;
/* Buffer corresponding to flattened value */
Buffer data;
/* Buffer Name */
String name;
/* Data type */
runtime::DataType dtype;

inline int ndim() const {
return static_cast<int>(axes.size());
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &tree);
v->Visit("length", &axes);
v->Visit("indptr", &ndim);
v->Visit("num_cols", &data);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(tree, other->tree) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
equal(data, other->data);
return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name) && equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(tree);
hash_reduce(axes);
hash_reduce(ndim);
hash_reduce(data);
hash_reduce(name);
hash_reduce(dtype);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
Expand All @@ -342,12 +349,12 @@ class SparseBufferNode : public Object {
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim, Buffer data);
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};

} // namespace sparse
} // namespace tir
} // namespace tvm

Expand Down
54 changes: 54 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,60 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
* \brief Store value to the high dimension sparse buffer.
*
* \code
*
* buffer[i, j] = value;
*
* \endcode
* \sa SparseBufferLoad
*/
class SparseBufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
SparseBuffer buffer;
/*! \brief The value to be stored. */
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const SparseBufferStoreNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(value, other->value) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.SparseBufferStore";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferStoreNode, StmtNode);
};

/*!
* \brief Managed reference to SparseBufferStoreNode.
* \sa SparseBufferStoreNode
*/
class SparseBufferStore : public Stmt {
public:
TVM_DLL explicit SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferStore, Stmt, SparseBufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferStoreNode);
};

/*!
* \brief Annotate the region where the buffer need to
* be read and write in the body.
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,23 @@ class SparseBuffer:
axes : List[Axis]
The axes of the sparse buffer
ndim : int
The number of dimensions of the sparse buffer
data : Buffer
The data of the sparse buffer
name : str
The name of the sparse buffer
dtype : Optional[str]
The data type of the sparse buffer
"""

tree: AxisTree
axes: List[Axis]
ndim: int
data: Buffer
name: str

def __init__(self, tree, axes, ndim, data):
def __init__(self, tree, axes, data, name, dtype=None):
dtype = "float32" if dtype is None else dtype
self.__init_handle_by_constructor__(
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
_ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore
)
30 changes: 30 additions & 0 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "]";
});

// SparseBufferLoad
SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices, Span span) {
ObjectPtr<SparseBufferLoadNode> node = make_object<SparseBufferLoadNode>();
node->dtype = buffer->dtype;
node->buffer = std::move(buffer);
node->indices = std::move(indices);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.SparseBufferLoad")
.set_body_typed([](SparseBuffer buffer, Array<PrimExpr> indices, Span span) {
return SparseBufferLoad(buffer, indices, span);
});

TVM_REGISTER_NODE_TYPE(SparseBufferLoadNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SparseBufferLoadNode*>(node.get());
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) {
p->stream << ", ";
}
}
p->stream << "]";
});

// ProducerLoad
ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span) {
ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();
Expand Down
17 changes: 6 additions & 11 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
namespace tvm {
namespace tir {

namespace sparse {


// DenseFixedAxis
DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
ObjectPtr<DenseFixedAxisNode> node = make_object<DenseFixedAxisNode>();
Expand Down Expand Up @@ -148,25 +145,23 @@ TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
});

// SparseBuffer
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim,
Buffer data) {
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
DataType dtype) {
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
node->tree = std::move(tree);
node->axes = std::move(axes);
node->ndim = ndim;
node->data = std::move(data);
node->name = std::move(name);
node->dtype = dtype;
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(SparseBufferNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
.set_body_typed([](AxisTree root, Array<Axis> axes, int ndim, Buffer data) {
// Todo(@ruihang): to be revised later
return SparseBuffer(root, axes, ndim, data);
.set_body_typed([](AxisTree tree, Array<Axis> axes, Buffer data, String name, DataType dtype) {
return SparseBuffer(tree, axes, data, name, dtype);
});

} // namespace sparse

} // namespace tir
} // namespace tvm
33 changes: 33 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,39 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << '\n';
});

// SparseBufferStore
SparseBufferStore::SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span) {
ObjectPtr<SparseBufferStoreNode> node = make_object<SparseBufferStoreNode>();
node->buffer = std::move(buffer);
node->value = std::move(value);
node->indices = std::move(indices);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.SparseBufferStore")
.set_body_typed([](SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
return SparseBufferStore(buffer, value, indices, span);
});

TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) p->stream << ", ";
}
p->stream << "]";
p->stream << " = ";
p->Print(op->value);
p->stream << '\n';
});

// BufferRealize
BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
Span span) {
Expand Down

0 comments on commit 8e70d89

Please sign in to comment.