Skip to content

Commit

Permalink
[SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (
Browse files Browse the repository at this point in the history
apache#13)

* Enhance SparseBlock to have enough PrimFunc info

* Remove `func_sparse_buffer_map_`

* Don't print the map uh-huh
  • Loading branch information
MasterJH5574 committed Dec 22, 2021
1 parent 0194186 commit db83eea
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 30 deletions.
19 changes: 10 additions & 9 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1285,8 +1285,8 @@ class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse buffers defined in the block. */
Array<SparseBuffer> sp_buffers;
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
Map<ObjectRef, Array<Var>> sp_struct2param_map;
/*! \brief The name of the block */
String name;
/*! \brief The body of the block */
Expand All @@ -1296,20 +1296,21 @@ class SparseBlockNode : public StmtNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("sp_iter_vars", &sp_iter_vars);
v->Visit("sp_buffers", &sp_buffers);
v->Visit("sp_struct2param_map", &sp_struct2param_map);
v->Visit("name", &name);
v->Visit("body", &body);
v->Visit("init", &init);
}

bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) &&
equal(name, other->name) && equal(body, other->body) && equal(init, other->init);
return equal(sp_iter_vars, other->sp_iter_vars) &&
equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) &&
equal(body, other->body) && equal(init, other->init);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(sp_iter_vars);
hash_reduce(sp_buffers);
hash_reduce(sp_struct2param_map);
hash_reduce(name);
hash_reduce(body);
hash_reduce(init);
Expand All @@ -1325,9 +1326,9 @@ class SparseBlockNode : public StmtNode {
*/
class SparseBlock : public Stmt {
public:
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
String name, Stmt body, Optional<Stmt> init = NullOpt,
Span span = Span());
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name,
Stmt body, Optional<Stmt> init = NullOpt, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,15 @@ class ContextMaintainer:
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {}
"""Mapping[Var, SparseBuffer]: The function sparse buffer map"""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
"""Mapping[Var, str]: The map from var to env thread"""

# sparse block context
sp_struct2param_map: Mapping[Object, List[Var]] = {}
"""Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters"""

# parser and analyzer
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
"""tvm.arith.Analyzer: The analyzer for simplifying"""
Expand All @@ -154,9 +156,10 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# function context
self.func_params = []
self.func_buffer_map = {}
self.func_sparse_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# sparse block context
self.sp_struct2param_map = {}
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()
Expand Down
11 changes: 5 additions & 6 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ def __init__(self):
def dense_fixed(name: str, length: PrimExpr, span: Optional[Span] = None):
var_name = self.node.lhs[0].id.name
axis = DenseFixedAxis(name, length)
self.context.sp_struct2param_map[axis] = []
self.context.update_symbol(var_name, axis, self.node)

super().__init__(dense_fixed, def_symbol=True)
Expand All @@ -926,7 +927,7 @@ def dense_variable(
(indptr_len,), dtype=idtype, name=name + "_indptr", span=span
)
axis = DenseVariableAxis(name, length, indptr_buf)
self.context.func_buffer_map[indptr_var] = indptr_buf
self.context.sp_struct2param_map[axis] = indptr_var
self.context.update_symbol(var_name, axis, self.node)
self.context.update_symbol(name + "_indptr", indptr_buf, self.node)

Expand All @@ -951,7 +952,7 @@ def sparse_fixed(
(nnz,), dtype=idtype, name=name + "_indices", span=span
)
axis = SparseFixedAxis(name, length, indices_buf, nnz_cols)
self.context.func_buffer_map[indices_var] = indices_buf
self.context.sp_struct2param_map[axis] = [indices_var]
self.context.update_symbol(var_name, axis, self.node)
self.context.update_symbol(name + "_indices", indices_buf, self.node)

Expand Down Expand Up @@ -980,8 +981,7 @@ def sparse_variable(
(nnz,), dtype=idtype, name=name + "_indices", span=span
)
axis = SparseVariableAxis(name, length, indptr_buf, indices_buf)
self.context.func_buffer_map[indices_var] = indices_buf
self.context.func_buffer_map[indptr_var] = indptr_buf
self.context.sp_struct2param_map[axis] = [indptr_var, indices_var]
self.context.update_symbol(var_name, axis, self.node)
self.context.update_symbol(name + "_indptr", indptr_buf, self.node)
self.context.update_symbol(name + "_indices", indices_buf, self.node)
Expand Down Expand Up @@ -1017,8 +1017,7 @@ def match_sparse_buffer(
if param in self.context.func_params:
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
self.context.func_buffer_map[param] = data
self.context.func_sparse_buffer_map[param] = buffer
self.context.sp_struct2param_map[buffer] = [param]
self.context.update_symbol(buffer_name + "_data", data, self.node)
self.context.update_symbol(buffer_name, buffer, self.node)
else:
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from . import _ffi_api
from .buffer import Buffer
from .expr import IterVar
from .expr import Var, IterVar
from .sparse import SpIterVar, SparseBuffer


Expand Down Expand Up @@ -624,8 +624,8 @@ class SparseBlock(Stmt):
sp_iter_vars : List[SpIterVar]
The sparse iteration variables of the block.
sp_buffers : List[SparseBuffer]
The sparse buffers defined in the block.
sp_struct2param_map : Mapping[Object, List[Var]]
The mapping from sparse data structures to the PrimFunc parameters.
name : str
The name of the block.
Expand All @@ -641,7 +641,7 @@ class SparseBlock(Stmt):
"""

sp_iter_vars: List[SpIterVar]
sp_buffers: List[SparseBuffer]
sp_struct2param_map: Mapping[Object, List[Var]]
name: str
body: Stmt
init: Optional[Stmt]
Expand All @@ -650,7 +650,7 @@ class SparseBlock(Stmt):
def __init__(
self,
sp_iter_vars: List[SpIterVar],
sp_buffers: List[SparseBuffer],
sp_struct2param_map: Mapping[Object, List[Var]],
name: str,
body: Stmt,
init: Optional[Stmt] = None,
Expand All @@ -659,7 +659,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SparseBlock, # type: ignore
sp_iter_vars,
sp_buffers,
sp_struct2param_map,
name,
body,
init,
Expand Down
14 changes: 8 additions & 6 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -975,11 +975,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
Stmt body, Optional<Stmt> init, Span span) {
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
Optional<Stmt> init, Span span) {
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
node->sp_iter_vars = std::move(sp_iter_vars);
node->sp_buffers = std::move(sp_buffers);
node->sp_struct2param_map = std::move(sp_struct2param_map);
node->name = std::move(name);
node->body = std::move(body);
node->init = std::move(init);
Expand All @@ -988,9 +989,10 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_b
}

TVM_REGISTER_GLOBAL("tir.SparseBlock")
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
Stmt body, Optional<Stmt> init, Span span) {
return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span);
.set_body_typed([](Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
Optional<Stmt> init, Span span) {
return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span);
});

TVM_REGISTER_NODE_TYPE(SparseBlockNode);
Expand Down

0 comments on commit db83eea

Please sign in to comment.