Skip to content

Commit

Permalink
Sparse TIR all
Browse files Browse the repository at this point in the history
Format and Buffer data structure (#1)

[SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

[CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <expye@outlook.com>

Fix AxisTree (#3)

* fix axis tree

* upd

[SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)

* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`

[SparseTIR] Introduce SpIterVar (#6)

* [SparseTIR] Introduce SpIterVar

* Add conversion to PrimExpr

[BugFix] Fix binary search & SpIterVar (#7)

[BugFix] Add field `is_reduction` for SpIterVar (#9)

* [BugFix] Add field `is_reduction` for SpIterVar

* Formatting

[SparseTIR] Index Lowering (#8)

* Add StmtFunctor/ExprFunctor for SparseBufferStore/Load

* Add basic index lowering

* Finish index lowering (maybe)

* Address comments

* Convert CRLF to LF

Frontend update, demo scripts. (#10)

* Format and Buffer data structure (#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <expye@outlook.com>

* Fix AxisTree (#3)

* fix axis tree

* upd

* Format and Buffer data structure (#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* fix axis tree

* upd

* Format and Buffer data structure (#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <expye@outlook.com>

* Fix AxisTree (#3)

* fix axis tree

* upd

* [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)

* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`

* Format and Buffer data structure (#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <expye@outlook.com>

* Fix AxisTree (#3)

* fix axis tree

* upd

* [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)

* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`

* [SparseTIR] Introduce SpIterVar (#6)

* [SparseTIR] Introduce SpIterVar

* Add conversion to PrimExpr

* [BugFix] Fix binary search & SpIterVar (#7)

* [BugFix] Add field `is_reduction` for SpIterVar (#9)

* [BugFix] Add field `is_reduction` for SpIterVar

* Formatting

* upd

* upd

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

[SparseTIR] SparseBlock on C++/Python side (#11)

* Fix a bug in the last commit

* SparseBlock on C++ & Python side

[BugFix][SparseTIR] TVMScript Parser for Axis & SpIterVar (#12)

* Update `cord` and `pos`

* Fix `idtype`

* Formatting..

* Bug fix 1

* Move new special stmts

* Parser for Axis and SpIterVar

* Fix context_maintainer.py

[SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (#13)

* Enhance SparseBlock to have enough PrimFunc info

* Remove `func_sparse_buffer_map_`

* Don't print the map uh-huh

[SparseTIR] Parser, Printer, Roundtrip (#14)

* SparseBlock scope handler (part 1)

* SparseBlock scope handler (part 2)

* SparseBlock scope handler (part 3)

* SparseBlock scope handler (fix 1)

* Add SparseBufferLoad/Store on Python side

* Parser for SparseBufferLoad/Store

* Add SparseBlock to Python __init__

* StmtFunctor for SparseBlock

* Ensure at least one dimension for SparseBuffer

* Make `axis` field of SpIterVar mandatory

* SparseBlock scope handler (fix 2)

* Update Axis syntax by removing `name` parameter

* Move to intrin.py

* Add filed `from_sparse` to DenseFixedAxis

* SparseTIR script printer

* Roundtrip test

* `update_symbol` bug fix

* Fix attr visit in SparseBuffer

* Define then compare in SparseBlock

* Fix printer bug for SparseBuffer

* Enable graph match for Axis and SparseBuffer

* Complete HashReduce and EqualReduce for AxisTree and SparseBuffer

* Fix typo

* Rename test

* Bug fix 1

* Bug fix 2

* Add more tests

Move tests (#15)

[SparseTIR] ReprPrinter for Axis and SpIterVar (#16)

upd (#17)

flatten (#18)

ELL and BSR correctness test scripts (#19)

[SparseTIR] SparseTIR Lowering (#20)

* Fix a previous bug of sparse-fixed SpIterVar creation

* Fix a previous bug in `GetDenseValue`

* Refactor Collector and IndexTransformer

* Construct block and loops

* Fix a previous bug which rejects DV iters in collector

* Update buffer map

* Create root block

* Fix bug of sparse-fixed SpIterVar creation

* Fix bug on SpIterVar conversion (with refactor)

* Fix bug when getting dependent SpIterVars

* Fix bug on dependency map and index lowering

* Full block read/write region

* Test version 1

* Fix bug of loop order

* Fix bug of batch-mm iterator ordering

* Update PrimFunc args to use symbolic params

* Fix bug of test "csr_element_wise"

* Fix bug of index accumulation for sparse-fixed axis

* Update correctness test

* Test structural equality

* Refactor and use Array

fix nnz cols

Add docstring for sparse tir lowering (#21)

* add docstring

* upd

Add more examples part 1 (sddmm) (#22)

* upd

* upd

* upd

[SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (#23)

* Test initialization

* Fix a stupid bug of ReprPrinter

* Add SparseBlockRV

* Schedule: GetSparseBlock

* Schedule: Reorder

[SparseTIR][Schedule] GetSpIters (#24)

remove hybrid script for successful compilation

Add atomic intrinsic for output nonzero inference. (#25)

* upd

* upd

Add "sparse" block attribute. (#26)

Revert "remove hybrid script for successful compilation"

This reverts commit eebd7c1.

[SparseTIR] Hack `IsAffineBinding` check (#27)

* [TensorIR][Schedule] Inherit block anotation upon creating new blocks

* Fix SDDMM test

* Hack IsAffineBinding for sparse blocks

Axis Dependency Tree aware code-gen and bmm example (#28)

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* remove redundancy

* fix

* upd

* upd

Re-design Indices lowering (#29)

* upd

* upd

* upd

* upd

* upd

* init

* format

* fix

* revise coding-style

* format

Complete indices lowering (#30)

* upd

* upd

* upd

* done

* upd

* passed test

* upd

Add more docstrings and depress warnings for new lowering algorithm. (#31)

Refactor derived axis, frontend support of fusion. (#32)

* upd

* upd

* fix

Fatal bugfix and change the signature of DenseVariableAxis.  (#33)

Syntax simplification (#34)

Change the order of generated blocks for block isolation. (#35)

* upd

* upd

* upd

Syntax of AttachAxis for BMM (#36)

* upd

* upd

* upd

[SparseTIR] Add "square sum" lowering test (#37)

* Add square sum test

* Remove pylint comment

[BugFix] Fix offset caching in lowering (#38)

* Hack compact dataflow check in a dirty way

* Add two-K square sum test

* Mark skipped tests

* Fix offset saving in lowering

Fusion syntax fix + SDDMM example.  (#39)

Some structure change on update offsets. (#40)

[Refactor] SparseTIR Lowering (#41)

* Take out methods in Scope

* Refactor

* Refactor "match"

* Tweak scope contents

* Refactor ViewIndexInAxis

* Refactor Scope

* SDDMM tests under implementation

* Refactor block stack

* Use Map for var_map

* Extract NeedCreateNewBlock

* Simplify SpIterVarToIterVar via GetIterExtent

* Refactor NeedCreateNewBlock

* Add docstring

* Use "auto" correctly

* Minor refactor and use some move

Remove redundant analyzers (#42)

Support indices lowering for attach and fuse. (#43)

* upd

* upd

* upd

Fix irregular BMM example. (#44)

* upd

* upd

* upd

* upd

RGCN forward and butterfly pattern example. (#45)

Fused SDDMM example. (#46)

* upd

* wip

* fix

Fix sparse reorder after refactor (#47)

[Refactor] Refactor Unittest (#48)

* upd

* remove redundancy

[Unittest] Correctness test for benchmarking scripts (#49)

Bugfix and more test for axis fusion, new workload (#50)

* upd

* upd

upd

upd

upd

upd

upd

upd
  • Loading branch information
yzh119 committed Mar 11, 2022
1 parent 5b76768 commit 191a2fa
Show file tree
Hide file tree
Showing 57 changed files with 6,188 additions and 28 deletions.
3 changes: 2 additions & 1 deletion include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ class BufferNode : public Object {
static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
// TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
TVM_DECLARE_BASE_OBJECT_INFO(BufferNode, Object);
};

/*!
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,21 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
TVM_DLL const Op& tvm_warp_shuffle_down();
TVM_DLL const Op& tvm_warp_activemask();

/*!
* \brief Lower bound function for binary search.
*/
TVM_DLL const Op& tvm_lower_bound();

/*!
* \brief Upper bound function for binary search.
*/
TVM_DLL const Op& tvm_upper_bound();

/*!
* \brief Atomic add function.
*/
TVM_DLL const Op& tvm_atomic_add();

/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/sparse.h>
#include <tvm/tir/var.h>

#include <algorithm>
Expand Down Expand Up @@ -659,6 +660,7 @@ class BufferLoad : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};


/*!
* \brief Load value from the result produced by the producer.
*
Expand Down
32 changes: 32 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,38 @@ TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
*/
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());

/*!
* \brief Lower bound function for binary search
* \param arr The buffer variable of the array to be looked up in
* \param val The value to be looked up in the array
* \param l The left boundary of the look-up range (inclusive)
* \param r The right boundary of the look-up range (exclusive)
* \param span The location of this operation in the source
* \return The look-up result
*/
TVM_DLL PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
Span span = Span());

/*!
* \brief Upper bound function for binary search
* \param arr The buffer variable of the array to be looked up in
* \param val The value to be looked up in the array
* \param l The left boundary of the look-up range (inclusive)
* \param r The right boundary of the look-up range (exclusive)
* \param span The location of this operation in the source
* \return The look-up result
*/
TVM_DLL PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
Span span = Span());

/*!
* \brief Perform atomic add on ptr by val, and return the old value.
* \param ptr The address to perform atomic add.
* \param val The value to add.
* \return The old result stored in ptr.
*/
TVM_DLL PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span = Span());

/*!
* \brief Calculate trunc(x)
* \param x The input expression.
Expand Down
56 changes: 56 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>
#include <tvm/tir/sparse.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -85,6 +86,27 @@ using ExprRV = PrimExpr;

using ExprRVNode = PrimExprNode;

/**************** Random variable: SparseBlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR sparse block */
class SparseBlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.SparseBlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to SparseBlockRVNode
* \sa SparseBlockRVNode
*/
class SparseBlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL SparseBlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode);
};

/**************** The Schedule class ****************/

class Schedule;
Expand Down Expand Up @@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding expr
*/
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
/*!
* \brief Get the sparse block corresponding to the specific random variable
* \param sp_block_rv The random variable to be looked up
* \return SparseBlock The corresponding sparse block
*/
virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
Expand Down Expand Up @@ -188,6 +216,11 @@ class ScheduleNode : public runtime::Object {
* \param expr_rv The random variable to be removed
*/
virtual void RemoveRV(const ExprRV& expr_rv) = 0;
/*!
* \brief Remove an sparse block random variable from the symbol table
* \param sp_block_rv The random variable to be removed
*/
virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0;

public:
/******** Schedule: Sampling ********/
Expand Down Expand Up @@ -524,6 +557,29 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
/******** Schedule: SparseTIR schedules ********/
/*!
* \brief Retrieve a sparse block in a specific function with its name
* \param name The name of the sparse block to be retrieved
* \param func_name The name of the function
* \return The sparse block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*/
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Retrieve the sparse iterators of a given sparse block
* \param block_rv The block to be queried
* \return The sparse iterators of the input sparse block
*/
virtual Array<SpIterVar> GetSpIters(const SparseBlockRV& block_rv) = 0;
/*!
* \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator
* dependency.
* \param block The block to be transformed
* \param new_order The new order of the sparse iterators, whose length should equal to the number
* of the input block's sparse iterators
*/
virtual void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) = 0;
};

/*!
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ class ScheduleStateNode : public Object {
* \return A boolean flag indicating if the block has quasi-affine bindings
*/
bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
// (SparseTIR Hack) Always return true for sparse blocks.
const auto* block = block_sref->StmtAs<BlockNode>();
Optional<ObjectRef> sparse_attr = block != nullptr ? block->annotations.Get("sparse") : NullOpt;
if (sparse_attr.defined() && sparse_attr.as<IntImmNode>()->value == 1) {
return true;
}

return GetBlockInfo(block_sref).affine_binding;
}
/*!
Expand Down
Loading

0 comments on commit 191a2fa

Please sign in to comment.