Skip to content

Commit

Permalink
[SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (#23)
Browse files Browse the repository at this point in the history
* Test initialization

* Fix a stupid bug of ReprPrinter

* Add SparseBlockRV

* Schedule: GetSparseBlock

* Schedule: Reorder
  • Loading branch information
MasterJH5574 authored and yzh119 committed Jan 21, 2022
1 parent ed1c0d8 commit ee7b1a8
Show file tree
Hide file tree
Showing 14 changed files with 849 additions and 174 deletions.
50 changes: 50 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>
#include <tvm/tir/sparse.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -85,6 +86,27 @@ using ExprRV = PrimExpr;

using ExprRVNode = PrimExprNode;

/**************** Random variable: SparseBlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR sparse block */
class SparseBlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.SparseBlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to SparseBlockRVNode
* \sa SparseBlockRVNode
*/
class SparseBlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL SparseBlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode);
};

/**************** The Schedule class ****************/

class Schedule;
Expand Down Expand Up @@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding expr
*/
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
/*!
* \brief Get the sparse block corresponding to the specific random variable
* \param sp_block_rv The random variable to be looked up
* \return SparseBlock The corresponding sparse block
*/
virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
Expand Down Expand Up @@ -188,6 +216,11 @@ class ScheduleNode : public runtime::Object {
* \param expr_rv The random variable to be removed
*/
virtual void RemoveRV(const ExprRV& expr_rv) = 0;
/*!
* \brief Remove an sparse block random variable from the symbol table
* \param sp_block_rv The random variable to be removed
*/
virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0;

public:
/******** Schedule: Sampling ********/
Expand Down Expand Up @@ -505,6 +538,23 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
/******** Schedule: SparseTIR schedules ********/
/*!
* \brief Retrieve a sparse block in a specific function with its name
* \param name The name of the sparse block to be retrieved
* \param func_name The name of the function
* \return The sparse block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*/
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator
* dependency.
* \param block The block to be transformed
* \param new_order The new order of the sparse iterators, whose length should equal to the number
* of the input block's sparse iterators
*/
virtual void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) = 0;
};

/*!
Expand Down
74 changes: 67 additions & 7 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, SparseBlock
from tvm.tir.sparse import SpIterVar

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -56,12 +57,23 @@ def __init__(self) -> None:
)


@_register_object("tir.SparseBlockRV")
class SparseBlockRV(Object):
"""A random variable that refers to a sparse block"""

def __init__(self) -> None:
"""Construct a new SparseBlockRV."""
self.__init_handle_by_constructor__(
_ffi_api.SparseBlockRV # type: ignore # pylint: disable=no-member
)


# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
# This feature is not supported until python 3.10:
# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV, SparseBlockRV] # pylint: disable=invalid-name

# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
_ERROR_RENDER_LEVEL: Dict[str, int] = {
Expand Down Expand Up @@ -227,7 +239,7 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
Parameters
----------
rand_var : Union[ExprRV, BlockRV, LoopRV]
rand_var : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV]
The random variable to be evaluated
Returns
Expand All @@ -243,22 +255,23 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
def get(
self,
rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef],
) -> Optional[Union[int, Block, For]]:
) -> Optional[Union[int, Block, For, SparseBlock]]:
"""Returns:
- the corresponding Block that a BlockRV evaluates to;
- the corresponding For that a LoopRV evaluates to;
- the corresponding integer that a ExprRV evaluates to;
- the corresponding SparseBlock that a SparseBlockRV evaluates to;
- the corresponding Block that a block sref points to;
- the corresponding For that a loop sref points to;
Parameters
----------
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef]
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV, StmtSRef]
The random variable / sref to be evaluated
Returns
-------
result : Optional[Union[int, Block, For]]
result : Optional[Union[int, Block, For, SparseBlock]]
The corresponding result
"""
if isinstance(rand_var_or_sref, StmtSRef):
Expand Down Expand Up @@ -296,7 +309,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
Parameters
----------
rand_var : Union[BlockRV, LoopRV, ExprRV]
rand_var : Union[BlockRV, LoopRV, ExprRV, SparseBlockRV]
The random variable to be removed
"""
return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member
Expand Down Expand Up @@ -1888,3 +1901,50 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
def enter_postproc(self) -> None:
"""A no-op that marks the start of postprocessing phase of scheduling"""
_ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member

########## Schedule: SparseTIR schedules ##########

def get_sparse_block(
self,
name: str,
func_name: str = "main",
) -> SparseBlock:
"""Retrieve a sparse block in a specific function with its name
Parameters
----------
name : str
The name of the sparse block
func_name : str = "main"
The name of the function
Returns
-------
block : SparseBlockRV
The sparse block retrieved
IndexError is raised if 0 or multiple blocks exist with the specific name.
"""
return _ffi_api.ScheduleGetSparseBlock( # type: ignore # pylint: disable=no-member
self,
name,
func_name,
)

def sparse_reorder(self, block: SparseBlockRV, new_order: List[SpIterVar]) -> None:
"""Reorder a list of sparse iterators. It requires the new order to not break the iterator
dependency.
Parameters
----------
block : SparseBlockRV
The queried sparse block
new_order : List[SpIterVar]
The The new order of the sparse iterators, whose length should equal to the number
of the input block's sparse iterators
"""
return _ffi_api.ScheduleSparseReorder( # type: ignore # pylint: disable=no-member
self,
block,
new_order,
)
2 changes: 1 addition & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
auto* op = static_cast<const SparseBufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ void VerifyCachedFlags(const ScheduleState& self);
const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
GlobalVar* result_g_var);

/*!
* \brief Get PrimFunc and GlobalVar that the sparse block belongs to
* \param mod The IRModule
* \param sp_block The sparse block inside the PrimFunc to be queried
* \param result_g_var The result GlobalVar
* \return The result PrimFunc where the sparse block belongs to
* \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
*/
const PrimFuncNode* GetPrimFuncFromSparseBlock(const IRModule& mod, const SparseBlockNode* sp_block,
GlobalVar* result_g_var);

/*!
* \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
* \param sref The given sref.
Expand Down
20 changes: 20 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl
throw;
}

const PrimFuncNode* GetPrimFuncFromSparseBlock(const IRModule& mod, const SparseBlockNode* sp_block,
GlobalVar* result_g_var) {
for (const auto& kv : mod->functions) {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* func = base_func.as<PrimFuncNode>()) {
if (func->body.get() == sp_block) {
if (result_g_var != nullptr) {
*result_g_var = g_var;
}
return func;
}
}
}
LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the "
"sparse block:\n"
<< GetRef<SparseBlock>(sp_block);
throw;
}

/******** Scope ********/

StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, //
Expand Down
46 changes: 46 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -664,5 +664,51 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_

/******** Schedule: Misc ********/

/******** Schedule: SparseTIR schedules ********/

SparseBlockRV ConcreteScheduleNode::GetSparseBlock(const String& name, const String& func_name) {
class NotFoundResult : public ScheduleError {
public:
explicit NotFoundResult(String name, IRModule mod) : name_(name), mod_(mod) {}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

String DetailRenderTemplate() const final {
return "Cannot find a sparse block with the name: " + name_;
}

String FastErrorString() const final {
return "ScheduleError: Cannot find a sparse block with the specified name";
}

String name_;
IRModule mod_;
};

BaseFunc func = this->state_->mod->Lookup(func_name);
const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode);

// Currently we only handle cases with single sparse block.
const auto* block = prim_func->body.as<SparseBlockNode>();
if (block == nullptr) {
TVM_TIR_SCHEDULE_BEGIN();
throw NotFoundResult(name, this->state_->mod);
TVM_TIR_SCHEDULE_END("get-sparse-block", this->error_render_level_);
}

return CreateRV(GetRef<SparseBlock>(block));
}

void ConcreteScheduleNode::SparseReorder(const SparseBlockRV& block_rv,
const Array<SpIterVar>& new_order) {
SparseBlock old_block = this->Get(block_rv);
SparseBlock new_block{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
new_block = tir::SparseReorder(state_, old_block, new_order);
TVM_TIR_SCHEDULE_END("sparse-reorder", this->error_render_level_);
this->UpdateRV(block_rv, new_block);
}

} // namespace tir
} // namespace tvm
34 changes: 34 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class ConcreteScheduleNode : public ScheduleNode {
inline Block Get(const BlockRV& block_rv) const final;
inline For Get(const LoopRV& loop_rv) const final;
inline PrimExpr Get(const ExprRV& expr_rv) const final;
inline SparseBlock Get(const SparseBlockRV& sp_block_rv) const final;
inline StmtSRef GetSRef(const BlockRV& block_rv) const final;
inline StmtSRef GetSRef(const LoopRV& loop_rv) const final;
inline bool HasBlock(const BlockRV& block_rv) const final;
Expand All @@ -78,6 +79,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); }
void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); }
void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); }
void RemoveRV(const SparseBlockRV& sp_block_rv) final { RemoveFromSymbolTable(sp_block_rv); }
using ScheduleNode::GetSRef;

public:
Expand Down Expand Up @@ -131,6 +133,9 @@ class ConcreteScheduleNode : public ScheduleNode {

/******** Schedule: Misc ********/
void EnterPostproc() override {}
/******** Schedule: SparseTIR schedules ********/
SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") override;
void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) override;

protected:
/******** Utility functions ********/
Expand Down Expand Up @@ -168,6 +173,18 @@ class ConcreteScheduleNode : public ScheduleNode {
* \return The new random variables created
*/
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
/*!
* \brief Add a sparse block as a random variable into the symbol table
* \param sp_block
* \return SparseBlockRV
*/
inline SparseBlockRV CreateRV(const SparseBlock& sp_block);
/*!
* \brief Update the value of the input SparseBlockRV to the input block.
* \param sp_block_rv The random variable to be updated
* \param block The new value of the random variable
*/
inline void UpdateRV(const SparseBlockRV& sp_block_rv, const SparseBlock& block);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
/*!
Expand Down Expand Up @@ -208,6 +225,13 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
return this->analyzer_->Simplify(transformed);
}

inline SparseBlock ConcreteScheduleNode::Get(const SparseBlockRV& sp_block_rv) const {
auto it = this->symbol_table_.find(sp_block_rv);
CHECK(it != this->symbol_table_.end())
<< "IndexError: Cannot find corresponding SparseBlockRV: " << sp_block_rv;
return Downcast<SparseBlock>((*it).second);
}

inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const {
auto it = this->symbol_table_.find(block_rv);
if (it == this->symbol_table_.end()) {
Expand Down Expand Up @@ -317,6 +341,16 @@ inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>&
return results;
}

inline SparseBlockRV ConcreteScheduleNode::CreateRV(const SparseBlock& block) {
SparseBlockRV rv;
this->symbol_table_.Set(rv, block);
return rv;
}

inline void ConcreteScheduleNode::UpdateRV(const SparseBlockRV& rv, const SparseBlock& block) {
this->symbol_table_.Set(rv, block);
}

inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) {
auto it = this->symbol_table_.find(obj);
if (it != this->symbol_table_.end()) {
Expand Down
Loading

0 comments on commit ee7b1a8

Please sign in to comment.