Skip to content

Commit

Permalink
Syntax of AttachAxis for BMM (#36)
Browse files Browse the repository at this point in the history
* upd

* upd

* upd
  • Loading branch information
yzh119 committed Feb 15, 2022
1 parent d755536 commit 184d212
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 62 deletions.
34 changes: 27 additions & 7 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,28 @@ class AxisNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("is_derived_axis", &is_derived_axis);
}

bool SEqualReduce(const AxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(is_derived_axis, other->is_derived_axis);
return equal(name, other->name) && equal(length, other->length);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(is_derived_axis);
}

/* name of current axis. */
String name;
/* length of current axis. For sparse axis, length refers to the upperbound of
* the current axis. */
PrimExpr length;
/* indicates whether current axis is derived by dense(axis) or fuse(axis1, axis2, ...) */
bool is_derived_axis = false;

String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }
virtual Optional<Axis> GetParentAxis() const = 0;
Axis GetRootAxis() const;

virtual AxisKind kind() const = 0;
virtual PrimExpr nnz() const = 0;
Expand Down Expand Up @@ -266,7 +262,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
Optional<Axis> GetParentAxis() const final { return parent_; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};

/*!
Expand All @@ -281,6 +277,30 @@ class DenseVariableAxis : public DenseAxis {
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};

/*!
* \brief Dense variable axis attached to another dense variable axis.
*/
class AttachedAxisNode : public DenseVariableAxisNode {
public:
/* The original axis before attaching. */
Axis orig_;

Axis GetOriginalAxis() const { return orig_; }

static constexpr const char* _type_key = "tir.sparse.AttachedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode);
};

/*!
* \brief Managed reference to AttachedAxisNode.
* \sa AttachedAxisNode
*/
class AttachedAxis : public DenseVariableAxis {
public:
TVM_DLL explicit AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr);
TVM_DEFINE_OBJECT_REF_METHODS(AttachedAxis, DenseVariableAxis, AttachedAxisNode);
};

/*!
* \brief Sparse axis with fixed number of non-zero columns per row.
*/
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""TVM Script Parser Special Stmt Classes"""
# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
# pylint: disable=relative-beyond-top-level
from os import name
from typing import Callable, List, Optional, Tuple, Any, Mapping, Union

import synr
Expand All @@ -35,6 +36,7 @@
DenseVariableAxis,
SparseFixedAxis,
SparseVariableAxis,
AttachedAxis,
)

from .node import BufferSlice
Expand Down Expand Up @@ -946,6 +948,38 @@ def dense_variable(
super().__init__(dense_variable, def_symbol=True)


@register
class Attach(SpecialStmt):
"""Special Stmt for attaching axis."""

def __init__(self):
def attach_axis(
parent: Axis,
orig: Axis,
nnz: PrimExpr,
indptr_var: tvm.tir.Var,
idtype: str = "int32",
span: Optional[Span] = None,
):
names = [x.id.name for x in self.node.lhs]
if len(names) != 1:
self.context.report_error(
f"`attach_axis` expected assign to only one var, but got {names}", span
)

indptr_len = orig.nnz + 1
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = AttachedAxis(names[0], parent, orig, nnz, indptr_buf)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)

super().__init__(attach_axis, def_symbol=True)


@register
class SparseFixed(SpecialStmt):
"""Special Stmt for creating sparse fixed axis."""
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,36 @@ def __init__(self, name, parent, length, nnz, indptr):
)


@tvm._ffi.register_object("tir.sparse.AttachedAxis")
class AttachedAxis(DenseVariableAxis):
"""AttachedAxis node
Parameters
----------
name : str
The name of the axis.
parent : Axis
The axis to attach to.
orig : Axis
The axis to be attached.
nnz : PrimExpr
The number of nonzeros of the returned axis.
indptr : PrimExpr
The new indptr array of the the returned axis.
"""

name : str
parent : Axis
orig : Axis
nnz : PrimExpr
indptr : PrimExpr

def __init__(self, name, parent, length, nnz, indptr):
self.__init_handle_by_constructor__(
_ffi_api.AttachedAxis, name, parent, length, nnz, indptr
)


@tvm._ffi.register_object("tir.sparse.SparseFixedAxis")
class SparseFixedAxis(DenseAxis):
"""SparseFixedAxis node
Expand Down
50 changes: 26 additions & 24 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,14 +530,12 @@ Doc TVMScriptPrinter::AllocAxis(const Axis& axis) {
return it->second;
}
Doc val;
const auto* df_axis = axis.as<DenseFixedAxisNode>();

if (df_axis != nullptr && df_axis->is_derived_axis) {
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")");
} else {
CHECK(false) << "Cannot allocate fused axis";
}
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
// DenseFromSparseAxis is a temporally defined axis.
val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")");
} else if (axis.as<FusedAxisNode>()) {
// FusedAxis is also a temporally defined axis.
CHECK(false) << "Cannot allocate fused axis";
} else {
std::string name = axis->name;
if (name.length() == 0 || !std::isalnum(name[0])) {
Expand Down Expand Up @@ -1396,19 +1394,16 @@ Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) {
Doc iter_doc;

std::string axis_repr = sp_iter->axis->name;
if (axis->is_derived_axis) {
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")";
if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")";
} else if (const FusedAxisNode* fused_axis = axis.as<FusedAxisNode>()) {
std::string orig_axis_name = fused_axis->group[fused_axis->index]->name;
if (fused_axis->index == 0) {
iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name;
} else if (fused_axis->index == int(fused_axis->group.size() - 1)) {
iter_doc << orig_axis_name << ")";
} else {
const FusedAxisNode* fused_axis = axis.as<FusedAxisNode>();
std::string orig_axis_name = fused_axis->group[fused_axis->index]->name;
if (fused_axis->index == 0) {
iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name;
} else if (fused_axis->index == fused_axis->group.size() - 1) {
iter_doc << orig_axis_name << ")";
} else {
iter_doc << orig_axis_name;
}
iter_doc << orig_axis_name;
}
} else {
iter_doc << axis->name;
Expand Down Expand Up @@ -1481,10 +1476,17 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
ICHECK_EQ(params.size(), 0);
doc << "dense_fixed(" << Print(df_axis->length) << ")";
} else if (const auto* dv_axis = obj.as<DenseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length) << ", "
<< Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
if (const auto* attached_axis = obj.as<AttachedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "attach_axis(" << attached_axis->parent_->name << ", " << attached_axis->orig_->name
<< ", " << Print(attached_axis->nnz()) << ", " << Print(params[0]) << ", "
<< PrintDType(attached_axis->indptr->dtype) << ")";
} else {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length)
<< ", " << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
}
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "sparse_fixed(" << sf_axis->parent_->name << ", (" << Print(sf_axis->length) << ", "
Expand Down
100 changes: 69 additions & 31 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)

TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->nnz(); });

/******** AxisNode ********/

/*! \brief Implementation of get root axis function. */
Axis AxisNode::GetRootAxis() const {
Optional<Axis> parent = GetParentAxis();
if (parent.defined()) {
return parent.value()->GetRootAxis();
} else {
return GetRef<Axis>(this);
}
}

/******** DenseFixedAxis ********/

/*! \brief Default constructor of DenseFixedAxis */
Expand All @@ -67,43 +79,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "dense_fixed(" << op->name << ", " << op->length << ")";
});

/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
node->length = std::move(length);
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
.set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) {
return DenseVariableAxis(std::move(name), std::move(parent), std::move(length),
std::move(nnz), std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DenseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DenseVariableAxisNode*>(node.get());
p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name
<< ")";
});

/******** DenseFromSparseAxis ********/

/*! \brief Default constructor of DenseFromSparseAxis */
DenseFromSparseAxis::DenseFromSparseAxis(SparseAxis base) {
ObjectPtr<DenseFromSparseAxisNode> node = make_object<DenseFromSparseAxisNode>();
node->name = base->name + "_dense";
node->length = base->length;
node->is_derived_axis = true;
node->base = std::move(base);
data_ = std::move(node);
}
Expand Down Expand Up @@ -135,7 +117,6 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
}
node->name = "fused_" + fused_name + "_" + group[index]->name;
node->length = group[index]->nnz();
node->is_derived_axis = true;
node->group = std::move(group);
node->index = index;
data_ = std::move(node);
Expand Down Expand Up @@ -163,6 +144,63 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
node->length = std::move(length);
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
.set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) {
return DenseVariableAxis(std::move(name), std::move(parent), std::move(length),
std::move(nnz), std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DenseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DenseVariableAxisNode*>(node.get());
p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name
<< ")";
});

/******** AttachedAxis ********/
/*! \brief Default constructor of AttachedAxis */
AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) {
ObjectPtr<AttachedAxisNode> node = make_object<AttachedAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
node->orig_ = std::move(orig);
node->length = node->orig_->length;
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(AttachedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis")
.set_body_typed([](String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) {
return AttachedAxis(std::move(name), std::move(parent), std::move(orig), std::move(nnz),
std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttachedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AttachedAxisNode*>(node.get());
p->stream << "attached_axis(" << op->name << ", " << op->length << ", " << op->indptr->name
<< ")";
});

/******** SparseFixedAxis ********/

/*! \brief Default constructor of SparseFixedAxis */
Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ PrimExpr AggregateOffset(PrimExpr prev_offset, Axis axis, PrimExpr index,
break;
}
case AxisKind::kDenseVariable: {
// TODO(zihao): finish the aggregating offset for attached axis.
auto dv_axis = axis.as<DenseVariableAxisNode>();
new_offset = add(BufferLoad(dv_axis->indptr, {std::move(prev_offset)}), std::move(index));
break;
Expand Down
Loading

0 comments on commit 184d212

Please sign in to comment.