Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Syntax of AttachAxis for BMM #36

Merged
merged 3 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,28 @@ class AxisNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("is_derived_axis", &is_derived_axis);
}

bool SEqualReduce(const AxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(is_derived_axis, other->is_derived_axis);
return equal(name, other->name) && equal(length, other->length);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(is_derived_axis);
}

/* name of current axis. */
String name;
/* length of current axis. For sparse axis, length refers to the upperbound of
* the current axis. */
PrimExpr length;
/* indicates whether current axis is derived by dense(axis) or fuse(axis1, axis2, ...) */
bool is_derived_axis = false;

String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }
virtual Optional<Axis> GetParentAxis() const = 0;
Axis GetRootAxis() const;

virtual AxisKind kind() const = 0;
virtual PrimExpr nnz() const = 0;
Expand Down Expand Up @@ -266,7 +262,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
Optional<Axis> GetParentAxis() const final { return parent_; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};

/*!
Expand All @@ -281,6 +277,30 @@ class DenseVariableAxis : public DenseAxis {
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};

/*!
* \brief Dense variable axis attached to another dense variable axis.
*/
class AttachedAxisNode : public DenseVariableAxisNode {
public:
/* The original axis before attaching. */
Axis orig_;

Axis GetOriginalAxis() const { return orig_; }

static constexpr const char* _type_key = "tir.sparse.AttachedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode);
};

/*!
* \brief Managed reference to AttachedAxisNode.
* \sa AttachedAxisNode
*/
class AttachedAxis : public DenseVariableAxis {
public:
TVM_DLL explicit AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr);
TVM_DEFINE_OBJECT_REF_METHODS(AttachedAxis, DenseVariableAxis, AttachedAxisNode);
};

/*!
* \brief Sparse axis with fixed number of non-zero columns per row.
*/
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""TVM Script Parser Special Stmt Classes"""
# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
# pylint: disable=relative-beyond-top-level
from os import name
from typing import Callable, List, Optional, Tuple, Any, Mapping, Union

import synr
Expand All @@ -34,6 +35,7 @@
DenseVariableAxis,
SparseFixedAxis,
SparseVariableAxis,
AttachedAxis,
)

from .node import BufferSlice
Expand Down Expand Up @@ -900,6 +902,38 @@ def dense_variable(
super().__init__(dense_variable, def_symbol=True)


@register
class Attach(SpecialStmt):
"""Special Stmt for attaching axis."""

def __init__(self):
def attach_axis(
parent: Axis,
orig: Axis,
nnz: PrimExpr,
indptr_var: tvm.tir.Var,
idtype: str = "int32",
span: Optional[Span] = None,
):
names = [x.id.name for x in self.node.lhs]
if len(names) != 1:
self.context.report_error(
f"`attach_axis` expected assign to only one var, but got {names}", span
)

indptr_len = orig.nnz + 1
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = AttachedAxis(names[0], parent, orig, nnz, indptr_buf)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)

super().__init__(attach_axis, def_symbol=True)


@register
class SparseFixed(SpecialStmt):
"""Special Stmt for creating sparse fixed axis."""
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,36 @@ def __init__(self, name, parent, length, nnz, indptr):
)


@tvm._ffi.register_object("tir.sparse.AttachedAxis")
class AttachedAxis(DenseVariableAxis):
"""AttachedAxis node

Parameters
----------
name : str
The name of the axis.
parent : Axis
The axis to attach to.
orig : Axis
The axis to be attached.
nnz : PrimExpr
The number of nonzeros of the returned axis.
indptr : PrimExpr
The new indptr array of the the returned axis.
"""

name : str
parent : Axis
orig : Axis
nnz : PrimExpr
indptr : PrimExpr

def __init__(self, name, parent, length, nnz, indptr):
self.__init_handle_by_constructor__(
_ffi_api.AttachedAxis, name, parent, length, nnz, indptr
)


@tvm._ffi.register_object("tir.sparse.SparseFixedAxis")
class SparseFixedAxis(DenseAxis):
"""SparseFixedAxis node
Expand Down
50 changes: 26 additions & 24 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,12 @@ Doc TVMScriptPrinter::AllocAxis(const Axis& axis) {
return it->second;
}
Doc val;
const auto* df_axis = axis.as<DenseFixedAxisNode>();

if (df_axis != nullptr && df_axis->is_derived_axis) {
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")");
} else {
CHECK(false) << "Cannot allocate fused axis";
}
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
// DenseFromSparseAxis is a temporally defined axis.
val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")");
} else if (axis.as<FusedAxisNode>()) {
// FusedAxis is also a temporally defined axis.
CHECK(false) << "Cannot allocate fused axis";
} else {
std::string name = axis->name;
if (name.length() == 0 || !std::isalnum(name[0])) {
Expand Down Expand Up @@ -1242,19 +1240,16 @@ Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) {
Doc iter_doc;

std::string axis_repr = sp_iter->axis->name;
if (axis->is_derived_axis) {
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")";
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")";
} else if (const FusedAxisNode* fused_axis = axis.as<FusedAxisNode>()) {
std::string orig_axis_name = fused_axis->group[fused_axis->index]->name;
if (fused_axis->index == 0) {
iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name;
} else if (fused_axis->index == int(fused_axis->group.size() - 1)) {
iter_doc << orig_axis_name << ")";
} else {
const FusedAxisNode* fused_axis = axis.as<FusedAxisNode>();
std::string orig_axis_name = fused_axis->group[fused_axis->index]->name;
if (fused_axis->index == 0) {
iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name;
} else if (fused_axis->index == fused_axis->group.size() - 1) {
iter_doc << orig_axis_name << ")";
} else {
iter_doc << orig_axis_name;
}
iter_doc << orig_axis_name;
}
} else {
iter_doc << axis->name;
Expand Down Expand Up @@ -1327,10 +1322,17 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
ICHECK_EQ(params.size(), 0);
doc << "dense_fixed(" << Print(df_axis->length) << ")";
} else if (const auto* dv_axis = obj.as<DenseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length) << ", "
<< Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
if (const auto* attached_axis = obj.as<AttachedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "attach_axis(" << attached_axis->parent_->name << ", " << attached_axis->orig_->name
<< ", " << Print(attached_axis->nnz()) << ", " << Print(params[0]) << ", "
<< PrintDType(attached_axis->indptr->dtype) << ")";
} else {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length)
<< ", " << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
}
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "sparse_fixed(" << sf_axis->parent_->name << ", (" << Print(sf_axis->length) << ", "
Expand Down
100 changes: 69 additions & 31 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)

TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->nnz(); });

/******** AxisNode ********/

/*! \brief Implementation of get root axis function. */
Axis AxisNode::GetRootAxis() const {
Optional<Axis> parent = GetParentAxis();
if (parent.defined()) {
return parent.value()->GetRootAxis();
} else {
return GetRef<Axis>(this);
}
}

/******** DenseFixedAxis ********/

/*! \brief Default constructor of DenseFixedAxis */
Expand All @@ -67,43 +79,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "dense_fixed(" << op->name << ", " << op->length << ")";
});

/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
node->length = std::move(length);
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
.set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) {
return DenseVariableAxis(std::move(name), std::move(parent), std::move(length),
std::move(nnz), std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DenseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DenseVariableAxisNode*>(node.get());
p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name
<< ")";
});

/******** DenseFromSparseAxis ********/

/*! \brief Default constructor of DenseFromSparseAxis */
DenseFromSparseAxis::DenseFromSparseAxis(SparseAxis base) {
ObjectPtr<DenseFromSparseAxisNode> node = make_object<DenseFromSparseAxisNode>();
node->name = base->name + "_dense";
node->length = base->length;
node->is_derived_axis = true;
node->base = std::move(base);
data_ = std::move(node);
}
Expand Down Expand Up @@ -135,7 +117,6 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
}
node->name = "fused_" + fused_name + "_" + group[index]->name;
node->length = group[index]->nnz();
node->is_derived_axis = true;
node->group = std::move(group);
node->index = index;
data_ = std::move(node);
Expand Down Expand Up @@ -163,6 +144,63 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
node->length = std::move(length);
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
.set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) {
return DenseVariableAxis(std::move(name), std::move(parent), std::move(length),
std::move(nnz), std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DenseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DenseVariableAxisNode*>(node.get());
p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name
<< ")";
});

/******** AttachedAxis ********/
/*! \brief Default constructor of AttachedAxis */
AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) {
ObjectPtr<AttachedAxisNode> node = make_object<AttachedAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
node->orig_ = std::move(orig);
node->length = node->orig_->length;
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(AttachedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis")
.set_body_typed([](String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) {
return AttachedAxis(std::move(name), std::move(parent), std::move(orig), std::move(nnz),
std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttachedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AttachedAxisNode*>(node.get());
p->stream << "attached_axis(" << op->name << ", " << op->length << ", " << op->indptr->name
<< ")";
});

/******** SparseFixedAxis ********/

/*! \brief Default constructor of SparseFixedAxis */
Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ PrimExpr AggregateOffset(PrimExpr prev_offset, Axis axis, PrimExpr index,
break;
}
case AxisKind::kDenseVariable: {
// TODO(zihao): finish the aggregating offset for attached axis.
auto dv_axis = axis.as<DenseVariableAxisNode>();
new_offset = add(BufferLoad(dv_axis->indptr, {std::move(prev_offset)}), std::move(index));
break;
Expand Down
Loading