[TIR] Add software pipelining #10066

Merged: 18 commits, Feb 18, 2022
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1361,6 +1361,12 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*! \brief Mark the stage of a statement in the software pipeline */
constexpr const char* software_pipeline_stage = "software_pipeline_stage";

/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

33 changes: 33 additions & 0 deletions include/tvm/tir/transform.h
@@ -492,6 +492,39 @@ TVM_DLL Pass ConvertForLoopsToSerial();
*/
TVM_DLL Pass UnifiedStaticMemoryPlanner();

/*!
* \brief Transform annotated loops into pipelined ones that overlap producers and consumers.
*
* This pass detects loops with the software pipeline annotations and rewrites them into pipelined
* ones. The behavior of the rewriting depends on two annotations on the loop,
* attr::software_pipeline_stage and attr::software_pipeline_order, which define the stage and the
* order, respectively, of the components of the software pipeline. The components of the software
* pipeline are the direct children (ignoring BlockRealize / Block / SeqStmt) of the annotated loop.
* The value of each annotation should be an array of integers whose size equals the number of
* components.
*
* The result of the rewriting is a block that has three blocks as its direct children, which
* represent the prologue, the body, and the epilogue of the software pipeline. In the prologue,
* only components whose stage is less than max_stage are executed. In the epilogue, only
* components whose stage is greater than 0 are executed. In the body, all the components are
* executed. Such rewriting enables behavior like prefetching, so the components are not
* necessarily executed in the original order. attr::software_pipeline_order defines the order of
* each component. Components belonging to different stages can be reordered.
*
* Buffers allocated inside the software pipeline may be resized to accommodate multiple versions
* of the original buffer. The block annotation attr::double_buffer_scope can be used to indicate
* that the block needs to write in the double-buffering style.
*
* Annotations:
* attr::software_pipeline_stage: Array of non-negative integers, each element should be in range
* [0, max_stage], where max_stage is the maximum (inclusive) stage.
* attr::software_pipeline_order: Array of non-negative integers, should be a permutation of
* [0, 1, ..., num_components - 1].
*
* \return The IR transform pass.
*/
TVM_DLL Pass InjectSoftwarePipeline();

} // namespace transform
} // namespace tir
} // namespace tvm
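The prologue/body/epilogue description in the doc comment above can be sketched as a small pure-Python model. This is only an illustration of the documented semantics, not the pass's implementation: the names `validate_annotations` and `pipeline_schedule` are hypothetical, the model assumes a loop of `extent` iterations with `extent >= max_stage`, and it assumes that `order[c]` gives the position of component `c` within one pipelined iteration.

```python
from typing import Dict, List, Tuple


def validate_annotations(stages: List[int], order: List[int]) -> None:
    """Check the constraints listed under 'Annotations' in the doc comment."""
    if not stages or len(stages) != len(order):
        raise ValueError("stage and order need one entry per component")
    if any(s < 0 for s in stages):
        raise ValueError("stages must be non-negative integers")
    if sorted(order) != list(range(len(order))):
        raise ValueError("order must be a permutation of [0, num_components - 1]")


def pipeline_schedule(
    extent: int, stages: List[int], order: List[int]
) -> Dict[str, List[Tuple[int, int]]]:
    """Return, per phase, the (component, original_iteration) pairs executed.

    Hypothetical model: at pipelined iteration i, a component with stage s
    executes the work of original iteration i - s, if that iteration exists.
    """
    validate_annotations(stages, order)
    max_stage = max(stages)
    # Components sorted by their position in the order annotation (assumed meaning).
    by_order = sorted(range(len(stages)), key=lambda c: order[c])
    schedule: Dict[str, List[Tuple[int, int]]] = {
        "prologue": [],
        "body": [],
        "epilogue": [],
    }
    for i in range(extent + max_stage):
        if i < max_stage:
            phase = "prologue"  # only components with stage < max_stage can run
        elif i < extent:
            phase = "body"  # every component runs
        else:
            phase = "epilogue"  # only components with stage > 0 can run
        for c in by_order:
            original_iter = i - stages[c]
            if 0 <= original_iter < extent:
                schedule[phase].append((c, original_iter))
    return schedule
```

Under these assumptions, `pipeline_schedule(4, [0, 1], [0, 1])` puts `(0, 0)` in the prologue (the stage-0 producer runs one iteration ahead), interleaves both components in the body, and leaves `(1, 3)` for the epilogue, which matches the prefetching behavior the comment describes.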
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -749,3 +749,14 @@ def ConvertForLoopsToSerial():
The result pass
"""
return _ffi_api.ConvertForLoopsToSerial() # type: ignore


def InjectSoftwarePipeline():
"""Transform annotated loops into pipelined ones that overlap producers and consumers.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectSoftwarePipeline() # type: ignore
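The note in the C++ documentation of `InjectSoftwarePipeline` about resizing buffers to hold multiple versions can be illustrated with a toy model. This sketch does not use TVM and does not reflect how the pass computes buffer sizes; `run_pipelined` is a hypothetical helper that assumes a stage-0 producer, a stage-1 consumer, and a body in which the producer of iteration `i` runs before the consumer of iteration `i - 1`.

```python
from typing import List


def run_pipelined(data: List[int], num_versions: int) -> List[int]:
    """Toy two-stage pipeline over a buffer with `num_versions` slots.

    Producer (stage 0) writes data[i] into slot i % num_versions.
    Consumer (stage 1) reads the value for iteration i - 1 one step later.
    """
    n = len(data)
    buf = [0] * num_versions
    out: List[int] = []
    for i in range(n + 1):  # one extra step drains the consumer (epilogue)
        if i < n:
            buf[i % num_versions] = data[i]  # producer, iteration i
        if i >= 1:
            out.append(buf[(i - 1) % num_versions])  # consumer, iteration i - 1
    return out
```

With a single slot the producer overwrites a value before the consumer reads it, so `run_pipelined([10, 20, 30], 1)` returns `[20, 30, 30]`. With two slots the original values survive and the call returns `[10, 20, 30]`, which is the situation that double buffering is meant to handle.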
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -247,6 +247,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));