Skip to content

Commit

Permalink
[SparseTIR] Introduce SpIterVar (#6)
Browse files Browse the repository at this point in the history
* [SparseTIR] Introduce SpIterVar

* Add conversion to PrimExpr
  • Loading branch information
MasterJH5574 committed Nov 10, 2021
1 parent cb131f5 commit af0a7d4
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 1 deletion.
58 changes: 58 additions & 0 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,64 @@ class SparseBuffer : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};

enum class SpIterKind : int {
kDenseFixed = 0,
kDenseVariable = 1,
kSparseFixed = 2,
kSparseVariable = 3
};

/*!
* \brief Iterator variables in SparseTIR
*/
class SpIterVarNode : public Object {
public:
Var var;
PrimExpr max_extent;
SpIterKind kind;
Optional<Axis> axis;

void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("max_extent", &max_extent);
v->Visit("axis", &axis);
v->Visit("kind", &kind);
}

bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
equal(axis, other->axis) && equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(var);
hash_reduce(max_extent);
hash_reduce(axis);
hash_reduce(kind);
}

static constexpr const char* _type_key = "tir.sparse.SpIterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(SpIterVarNode, Object);
};

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
Optional<Axis> axis = NullOpt);

/*!
* \return the corresponding var in the IterVar.
*/
inline operator PrimExpr() const;

TVM_DEFINE_OBJECT_REF_METHODS(SpIterVar, ObjectRef, SpIterVarNode);
};

// inline implementations
inline SpIterVar::operator PrimExpr() const { return (*this)->var; }

} // namespace tir
} // namespace tvm

Expand Down
39 changes: 38 additions & 1 deletion python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm._ffi
from tvm.ir import PrimExpr
from tvm.runtime import Object, const
from tvm.tir import Var

from . import _ffi_api
from .buffer import Buffer
Expand Down Expand Up @@ -166,7 +167,7 @@ def __init__(self, axis_parent_map) -> None:


@tvm._ffi.register_object("tir.sparse.SparseBuffer")
class SparseBuffer:
class SparseBuffer(Object):
"""SparseBuffer node
Parameters
Expand Down Expand Up @@ -197,3 +198,39 @@ def __init__(self, tree, axes, data, name, dtype=None):
self.__init_handle_by_constructor__(
_ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore
)


@tvm._ffi.register_object("tir.sparse.SpIterVar")
class SpIterVar(Object):
"""IterVar in SparseTIR
Parameters
----------
var : Var
The var of the SpIterVar
max_extent : PrimExpr
The maximum extent of the SpIterVar
kind : int
The kind of the SpIterVar
axis : Optional[Axis]
The axis over which the SpIterVar iterates. Required to be defined
when `kind` is not `DenseFixed`
"""
var: Var
max_extent: PrimExpr
kind: int
axis: Optional[Axis]

DenseFixed = 0
DenseVariable = 1
SparseFixed = 2
SparseVariable = 3

def __init__(self, var, max_extent, kind, axis=None):
self.__init_handle_by_constructor__(
_ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore
)

23 changes: 23 additions & 0 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,28 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
return SparseBuffer(tree, axes, data, name, dtype);
});

// SpIterVar
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

if (kind != SpIterKind::kDenseFixed) {
CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
"specify the axis over which the SpIterVar iterates";
}

node->var = Var(std::move(name));
node->max_extent = std::move(max_extent);
node->kind = kind;
node->axis = std::move(axis);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
return SpIterVar(name, max_extent, kind, axis);
});

} // namespace tir
} // namespace tvm

0 comments on commit af0a7d4

Please sign in to comment.