From b4b5a4616da6aaea161effc23e3b2bb193db3a4b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 14 Dec 2021 09:36:18 -0800 Subject: [PATCH] fix --- include/tvm/tir/sparse.h | 16 ++++++++++----- python/tvm/script/tir/special_stmt.py | 4 ++-- python/tvm/tir/sparse.py | 5 +++-- src/tir/ir/sparse.cc | 29 +++++++++------------------ 4 files changed, 26 insertions(+), 28 deletions(-) diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index c083bb1e3efb..cd0fca704871 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -75,6 +75,7 @@ class AxisNode : public Object { DataType GetIndexType() const { return length->dtype; } virtual AxisKind kind() const = 0; + virtual PrimExpr nnz() const = 0; static constexpr const char* _type_key = "tir.sparse.Axis"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -134,6 +135,8 @@ class DenseFixedAxisNode : public DenseAxisNode { public: AxisKind kind() const final { return AxisKind::kDenseFixed; } + PrimExpr nnz() const final { return length; } + static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); }; @@ -234,6 +237,7 @@ class FusedAxis : public DenseFixedAxis { class DenseVariableAxisNode : public DenseAxisNode { public: Buffer indptr; + PrimExpr nnz_; void VisitAttrs(AttrVisitor* v) { DenseAxisNode::VisitAttrs(v); @@ -249,10 +253,10 @@ class DenseVariableAxisNode : public DenseAxisNode { hash_reduce(indptr); } - PrimExpr nnz() const { return indptr->shape[0]; } - AxisKind kind() const final { return AxisKind::kDenseVariable; } + PrimExpr nnz() const final { return nnz_; } + static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); }; @@ -263,7 +267,7 @@ class DenseVariableAxisNode : public DenseAxisNode { */ class DenseVariableAxis : public DenseAxis { public: - TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr); + TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr); TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); }; @@ -289,11 +293,13 @@ class SparseFixedAxisNode : public SparseAxisNode { } void SHashReduce(SHashReducer hash_reduce) const { - SparseFixedAxisNode::SHashReduce(hash_reduce); + SparseAxisNode::SHashReduce(hash_reduce); hash_reduce(indices); hash_reduce(nnz_cols); } + PrimExpr nnz() const { return indices->shape[0]; } + AxisKind kind() const final { return AxisKind::kSparseFixed; } static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis"; @@ -336,7 +342,7 @@ class SparseVariableAxisNode : public SparseAxisNode { hash_reduce(indices); } - PrimExpr nnz() const { return indptr->shape[0]; } + PrimExpr nnz() const { return indices->shape[0]; } AxisKind kind() const final { return AxisKind::kSparseVariable; } diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 9a5cb7ef4706..4b02fc734957 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -885,11 +885,11 @@ def dense_variable( f"`dense_variable` expected assign to only one var, but got {names}", span ) - length, indptr_len = shape + length, indptr_len, nnz = shape indptr_buf = tvm.tir.decl_buffer( (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span ) - axis = DenseVariableAxis(names[0], length, indptr_buf) + axis = DenseVariableAxis(names[0], length, nnz, indptr_buf) self.context.sp_struct.append(axis) self.context.sp_struct_params.append([indptr_var]) self.context.update_symbol(names[0], axis, self.node) diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 7c89f1d56672..9b136d037412 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -127,11 +127,12 @@ class DenseVariableAxis(DenseAxis): name: str length: PrimExpr + nnz: PrimExpr indptr: Buffer - def __init__(self, name, length, indptr): + def __init__(self, name, length, nnz, indptr): self.__init_handle_by_constructor__( - _ffi_api.DenseVariableAxis, name, length, indptr # type: ignore + _ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore ) diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 287d57608fc6..18e7cf8b4f2a 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -68,10 +68,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /******** DenseVariableAxis ********/ /*! \brief Default constuctor of DenseVariableAxis */ -DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { +DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); + node->nnz_ = std::move(nnz); node->indptr = std::move(indptr); data_ = std::move(node); } @@ -79,8 +80,8 @@ DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indptr) { - return DenseVariableAxis(name, length, indptr); + .set_body_typed([](String name, PrimExpr length, PrimExpr nnz, Buffer indptr) { + return DenseVariableAxis(std::move(name), std::move(length), std::move(nnz), std::move(indptr)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -128,17 +129,7 @@ FusedAxis::FusedAxis(Array group, int index) { fused_name += group[i]->name; } node->name = "fused_" + fused_name + "_" + group[index]->name; - - if (const auto* df_axis = group[index].as()) { - node->length = df_axis->length; - } else if (const auto* sf_axis = group[index].as()) { - // TODO(zihao): accumulate previous dimensions. - } else if (const auto* dv_axis = group[index].as()) { - node->length = dv_axis->nnz(); - } else if (const auto* sv_axis = group[index].as()) { - node->length = sv_axis->nnz(); - } - + node->length = group[index]->nnz(); node->is_derived_axis = true; node->group = std::move(group); node->index = index; @@ -183,7 +174,7 @@ TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) { - return SparseFixedAxis(name, length, indices, nnz_cols); + return SparseFixedAxis(std::move(name), std::move(length), std::move(indices), std::move(nnz_cols)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -210,7 +201,7 @@ TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) { - return SparseVariableAxis(name, length, indptr, indices); + return SparseVariableAxis(std::move(name), std::move(length), std::move(indptr), std::move(indices)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -259,7 +250,7 @@ TVM_REGISTER_NODE_TYPE(AxisTreeNode); TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") .set_body_typed([](Array axis_names, Array> axis_parent_names) { - return AxisTree(axis_names, axis_parent_names); + return AxisTree(std::move(axis_names), std::move(axis_parent_names)); }); /******** SparseBuffer ********/ @@ -279,7 +270,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") .set_body_typed([](Array axes, Buffer data, String name) { - return SparseBuffer(axes, data, name); + return SparseBuffer(std::move(axes), std::move(data), std::move(name)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -338,7 +329,7 @@ TVM_REGISTER_NODE_TYPE(SpIterVarNode); TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") .set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) { - return SpIterVar(var, max_extent, is_reduction, axis); + return SpIterVar(std::move(var), std::move(max_extent), is_reduction, std::move(axis)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)