diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 6a6b368a70ad..18d2c729ed06 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -44,7 +44,10 @@ class AxisNode : public Object { /* length of current axis. For sparse axis, length refers to the upperbound of * the current axis. */ PrimExpr length; + static constexpr const char* _type_key = "tir.sparse.Axis"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(AxisNode, Object); }; @@ -98,6 +101,20 @@ class DenseAxis : public Axis { */ class DenseFixedAxisNode : public DenseAxisNode { public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("length", &length); + } + + bool SEqualReduce(const DenseAxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(length); + } + static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); }; @@ -108,12 +125,31 @@ class DenseFixedAxisNode : public DenseAxisNode { */ class DenseFixedAxis : public DenseAxis { public: + TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length); + TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode); }; class DenseVariableAxisNode : public DenseAxisNode { public: Buffer indptr; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("length", &length); + v->Visit("indptr", &indptr); + } + + bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(length); + hash_reduce(indptr); + } + static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); }; @@ -124,8 +160,9 @@ class DenseVariableAxisNode : public DenseAxisNode { */ class DenseVariableAxis : public DenseAxis { public: - TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, - DenseVariableAxisNode); + TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr); + + TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); }; /*! @@ -154,6 +191,26 @@ class SparseFixedAxisNode : public SparseAxisNode { Buffer indices; /* fixed number of columns of current sparse axis. */ PrimExpr num_cols; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("length", &length); + v->Visit("indptr", &indices); + v->Visit("num_cols", &num_cols); + } + + bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length) && + equal(indices, other->indices) && equal(num_cols, other->num_cols); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(length); + hash_reduce(indices); + hash_reduce(num_cols); + } + static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode); }; @@ -164,8 +221,9 @@ class SparseFixedAxisNode : public SparseAxisNode { */ class SparseFixedAxis : public SparseAxis { public: - TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, - SparseFixedAxisNode); + TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode); }; /*! @@ -173,8 +231,29 @@ class SparseFixedAxis : public SparseAxis { */ class SparseVariableAxisNode : public SparseAxisNode { public: - Buffer indptr, indices; - static constexpr const char* _type_key = "tir.sparse.SparseVariabledAxis"; + Buffer indptr; + Buffer indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("length", &length); + v->Visit("indptr", &indptr); + v->Visit("indices", &indices); + } + + bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length) && + equal(indptr, other->indptr) && equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(length); + hash_reduce(indptr); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode); }; @@ -184,8 +263,9 @@ class SparseVariableAxisNode : public SparseAxisNode { */ class SparseVariableAxis : public SparseAxis { public: - TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, - SparseVariableAxisNode); + TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode); }; /*! @@ -223,6 +303,26 @@ class SparseBufferNode : public Object { int ndim; /* Buffer corresponding to flattened value */ Buffer data; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &root); + v->Visit("length", &axes); + v->Visit("indptr", &ndim); + v->Visit("num_cols", &data); + } + + bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { + return equal(root, other->root) && equal(axes, other->axes) && equal(ndim, other->ndim) && + equal(data, other->data); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(root); + hash_reduce(axes); + hash_reduce(ndim); + hash_reduce(data); + } + static constexpr const char* _type_key = "tir.sparse.SparseBufferNode"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object); }; @@ -233,6 +333,8 @@ class SparseBufferNode : public Object { */ class SparseBuffer : public ObjectRef { public: + TVM_DLL explicit SparseBuffer(AxisTree root, Array axes, int ndim, Buffer data); + TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode); }; @@ -240,4 +342,4 @@ class SparseBuffer : public ObjectRef { } // namespace tir } // namespace tvm -#endif // TVM_TIR_BUFFER_H_ +#endif // TVM_TIR_SPARSE_H_ diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py new file mode 100644 index 000000000000..cc79ea628b87 --- /dev/null +++ b/python/tvm/tir/sparse.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""SparseTIR axes and SparseBuffer +""" +from typing import List +import tvm._ffi +from tvm.ir import PrimExpr +from tvm.runtime import Object, const + +from . import _ffi_api +from .buffer import Buffer + + +class Axis(Object): + """Base class of all the sparse axes.""" + + +class DenseAxis(Axis): + pass + + +class SparseAxis(Axis): + pass + + +@tvm._ffi.register_object("tir.sparse.DenseFixedAxis") +class DenseFixedAxis(DenseAxis): + """DenseFixedAxis node + + Parameters + ---------- + name : str + The name of the axis + + length : PrimExpr + The length of the axis + """ + + name: str + length: PrimExpr + + def __init__(self, name, length): + self.__init_handle_by_constructor__( + _ffi_api.DenseFixedAxis, name, length # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.DenseVariableAxis") +class DenseVariableAxis(DenseAxis): + """DenseVariableAxis node + + Parameters + ---------- + name : str + The name of the axis + + length : PrimExpr + The length of the axis + + indptr : Buffer + The indptr buffer of the axis + """ + + name: str + length: PrimExpr + indptr: Buffer + + def __init__(self, name, length, indptr): + self.__init_handle_by_constructor__( + _ffi_api.DenseVariableAxis, name, length, indptr # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.SparseFixedAxis") +class SparseFixedAxis(DenseAxis): + """SparseFixedAxis node + + Parameters + ---------- + name : str + The name of the axis + + length : PrimExpr + The length of the axis + + indices : Buffer + The indices buffer of the axis + + num_cols : PrimExpr + The number of non-zero elements along the axis + """ + + name: str + length: PrimExpr + indices: Buffer + num_cols: PrimExpr + + def __init__(self, name, length, indices, num_cols): + self.__init_handle_by_constructor__( + _ffi_api.SparseFixedAxis, name, length, indices, num_cols # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.SparseVariableAxis") +class SparseVariableAxis(DenseAxis): + """SparseVariableAxis node + + Parameters + ---------- + name : str + The name of the axis + + length : PrimExpr + The length of the axis + + indptr : Buffer + The indptr buffer of the axis + + indices : Buffer + The indices buffer of the axis + """ + + name: str + length: PrimExpr + indptr: Buffer + indices: Buffer + + def __init__(self, name, length, indptr, indices): + self.__init_handle_by_constructor__( + _ffi_api.SparseVariableAxis, name, length, indptr, indices # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.AxisTree") +class AxisTree: + # Todo(@ruihang): to do later + pass + + +@tvm._ffi.register_object("tir.sparse.SparseBuffer") +class SparseBuffer: + """SparseBuffer node + + Parameters + ---------- + root : AxisTree + The root of the axis dependency tree of the sparse buffer + + axes : List[Axis] + The axes of the sparse buffer + + ndim : int + The number of dimensions of the sparse buffer + + data : Buffer + The data of the sparse buffer + """ + + root: AxisTree + axes: List[Axis] + ndim: int + data: Buffer + + def __init__(self, root, axes, ndim, data): + self.__init_handle_by_constructor__( + _ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore + ) diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 984833597d1e..7529d1e1ac5b 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -21,12 +21,99 @@ * \file sparse.cc * \brief buffers and formats in sparse tir. */ +#include #include +#include namespace tvm { namespace tir { -// TODO(zihao/ruihang) +namespace sparse { + +// DenseFixedAxis +DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->length = std::move(length); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { + return DenseFixedAxis(name, length); +}); + +// DenseVariableAxis +DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->length = std::move(length); + node->indptr = std::move(indptr); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") + .set_body_typed([](String name, PrimExpr length, Buffer indptr) { + return DenseVariableAxis(name, length, indptr); + }); + +// SparseFixedAxis +SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->length = std::move(length); + node->indices = std::move(indices); + node->num_cols = std::move(num_cols); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") + .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { + return SparseFixedAxis(name, length, indices, num_cols); + }); + +// SparseVariableAxis +SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr, + Buffer indices) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->length = std::move(length); + node->indptr = std::move(indptr); + node->indices = std::move(indices); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") + .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) { + return SparseVariableAxis(name, length, indptr, indices); + }); + +// SparseBuffer +SparseBuffer::SparseBuffer(AxisTree root, Array axes, int ndim, Buffer data) { + ObjectPtr node = make_object(); + node->root = std::move(root); + node->axes = std::move(axes); + node->ndim = ndim; + node->data = std::move(data); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(SparseBufferNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") + .set_body_typed([](AxisTree root, Array axes, int ndim, Buffer data) { + // Todo(@ruihang): to be revised later + return SparseBuffer(root, axes, ndim, data); + }); + +} // namespace sparse } // namespace tir } // namespace tvm