Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fatal bugfix and change the signature of DenseVariableAxis. #33

Merged
merged 1 commit into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class AxisNode : public Object {
DataType GetIndexType() const { return length->dtype; }

virtual AxisKind kind() const = 0;
virtual PrimExpr nnz() const = 0;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -134,6 +135,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
public:
AxisKind kind() const final { return AxisKind::kDenseFixed; }

PrimExpr nnz() const final { return length; }

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -234,6 +237,7 @@ class FusedAxis : public DenseFixedAxis {
class DenseVariableAxisNode : public DenseAxisNode {
public:
Buffer indptr;
PrimExpr nnz_;

void VisitAttrs(AttrVisitor* v) {
DenseAxisNode::VisitAttrs(v);
Expand All @@ -249,10 +253,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
hash_reduce(indptr);
}

PrimExpr nnz() const { return indptr->shape[0]; }

AxisKind kind() const final { return AxisKind::kDenseVariable; }

PrimExpr nnz() const final { return nnz_; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand All @@ -263,7 +267,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};
Expand All @@ -289,11 +293,13 @@ class SparseFixedAxisNode : public SparseAxisNode {
}

void SHashReduce(SHashReducer hash_reduce) const {
SparseFixedAxisNode::SHashReduce(hash_reduce);
SparseAxisNode::SHashReduce(hash_reduce);
hash_reduce(indices);
hash_reduce(nnz_cols);
}

PrimExpr nnz() const { return indices->shape[0]; }

AxisKind kind() const final { return AxisKind::kSparseFixed; }

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
Expand Down Expand Up @@ -336,7 +342,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
hash_reduce(indices);
}

PrimExpr nnz() const { return indptr->shape[0]; }
PrimExpr nnz() const { return indices->shape[0]; }

AxisKind kind() const final { return AxisKind::kSparseVariable; }

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,11 +885,11 @@ def dense_variable(
f"`dense_variable` expected assign to only one var, but got {names}", span
)

length, indptr_len = shape
length, indptr_len, nnz = shape
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = DenseVariableAxis(names[0], length, indptr_buf)
axis = DenseVariableAxis(names[0], length, nnz, indptr_buf)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ class DenseVariableAxis(DenseAxis):

name: str
length: PrimExpr
nnz: PrimExpr
indptr: Buffer

def __init__(self, name, length, indptr):
def __init__(self, name, length, nnz, indptr):
self.__init_handle_by_constructor__(
_ffi_api.DenseVariableAxis, name, length, indptr # type: ignore
_ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore
)


Expand Down
29 changes: 10 additions & 19 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr) {
return DenseVariableAxis(name, length, indptr);
.set_body_typed([](String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
return DenseVariableAxis(std::move(name), std::move(length), std::move(nnz), std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -128,17 +129,7 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
fused_name += group[i]->name;
}
node->name = "fused_" + fused_name + "_" + group[index]->name;

if (const auto* df_axis = group[index].as<DenseFixedAxisNode>()) {
node->length = df_axis->length;
} else if (const auto* sf_axis = group[index].as<SparseFixedAxisNode>()) {
// TODO(zihao): accumulate previous dimensions.
} else if (const auto* dv_axis = group[index].as<DenseVariableAxisNode>()) {
node->length = dv_axis->nnz();
} else if (const auto* sv_axis = group[index].as<SparseVariableAxisNode>()) {
node->length = sv_axis->nnz();
}

node->length = group[index]->nnz();
node->is_derived_axis = true;
node->group = std::move(group);
node->index = index;
Expand Down Expand Up @@ -183,7 +174,7 @@ TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
return SparseFixedAxis(name, length, indices, nnz_cols);
return SparseFixedAxis(std::move(name), std::move(length), std::move(indices), std::move(nnz_cols));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -210,7 +201,7 @@ TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
return SparseVariableAxis(name, length, indptr, indices);
return SparseVariableAxis(std::move(name), std::move(length), std::move(indptr), std::move(indices));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -259,7 +250,7 @@ TVM_REGISTER_NODE_TYPE(AxisTreeNode);

TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
.set_body_typed([](Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
return AxisTree(axis_names, axis_parent_names);
return AxisTree(std::move(axis_names), std::move(axis_parent_names));
});

/******** SparseBuffer ********/
Expand All @@ -279,7 +270,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
.set_body_typed([](Array<Axis> axes, Buffer data, String name) {
return SparseBuffer(axes, data, name);
return SparseBuffer(std::move(axes), std::move(data), std::move(name));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -338,7 +329,7 @@ TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) {
return SpIterVar(var, max_extent, is_reduction, axis);
return SpIterVar(std::move(var), std::move(max_extent), is_reduction, std::move(axis));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down