[TIR] Asynchronous stage in software pipeline #12171
@@ -384,6 +426,9 @@ class ThreadSyncInserter : public StmtExprMutator {

Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
  StorageScope sync_scope = StorageScope::Create(storage_scope);
  if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
Do we need to check `sync_scope.tag`? I assume it also works for dynamic shared memory.
This is only to make sure that this code path is hit once. `ThreadSyncAfterWaitQueueInserter` just looks for `async_wait_queue_scope` and inserts `syncthreads` after it. So, assuming that all shared memory, including dynamic shared memory, is protected by `async_wait_queue_scope` (which should be the case after `InjectSoftwarePipeline`), all necessary `syncthreads` will be inserted.
Since `ThreadSync` is called twice, for `shared` and `shared.dyn`,
Lines 530 to 531 in 7ef6811
mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
the same `syncthreads` would be inserted twice without this check.
Thinking about it more now, this assumes that `async_wait_queue_scope` on GPU is always associated with shared memory. This should be fine as long as the only async operation is copying into shared memory. I have to admit this is a bit hacky, but something like this is needed for correctness.
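To illustrate the point about the pass being registered twice, here is a toy model of the guard. It is only a sketch: `Scope`, `ParseScope`, `Program`, `ThreadSyncModel`, and `RunBothPasses` are hypothetical names made up for this example and are not TVM APIs, and the scope parser is a simplification that assumes the tag is whatever follows the `shared` prefix.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a parsed storage scope: a rank plus an optional tag.
enum class Rank { kGlobal, kShared, kLocal };

struct Scope {
  Rank rank;
  std::string tag;
};

// Simplified parser (assumption): "shared" -> {kShared, ""},
// "shared.dyn" -> {kShared, ".dyn"}.
Scope ParseScope(const std::string& s) {
  const std::string shared = "shared";
  if (s.compare(0, shared.size(), shared) == 0) {
    return {Rank::kShared, s.substr(shared.size())};
  }
  if (s == "global") return {Rank::kGlobal, ""};
  return {Rank::kLocal, ""};
}

// Toy program state: how many async wait scopes exist and how many barriers
// have been inserted right after them.
struct Program {
  int wait_queue_scopes = 0;
  int barriers_after_wait = 0;
};

// Models one invocation of the pass. With the guard enabled, the
// "barrier after wait scope" step runs only for the untagged shared scope.
void ThreadSyncModel(Program* prog, const std::string& storage_scope, bool guard_on_tag) {
  Scope scope = ParseScope(storage_scope);
  bool is_shared = scope.rank == Rank::kShared;
  if (is_shared && (!guard_on_tag || scope.tag.empty())) {
    prog->barriers_after_wait += prog->wait_queue_scopes;
  }
}

// Runs the pass once per scope, mirroring the two entries in the pass list.
Program RunBothPasses(int wait_queue_scopes, bool guard_on_tag) {
  Program prog;
  prog.wait_queue_scopes = wait_queue_scopes;
  const std::vector<std::string> scopes = {"shared", "shared.dyn"};
  for (const std::string& s : scopes) {
    ThreadSyncModel(&prog, s, guard_on_tag);
  }
  return prog;
}
```

In this model, running both passes with the guard inserts one barrier per wait scope, while running them without it inserts two per wait scope, which is the duplication the tag check avoids.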
new_block = Downcast<Block>(
    Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

if (pipeline_info_[block].async) {
Can we refactor the async-pipeline-related logic into some functions to make the original `EmitImpl` logic more concise?
OK, I moved the bulk of the logic into two functions. Now `EmitImpl` itself is kept short.
// Given pipelined blocks and async-related information, generate final loop statements with async
// scopes (if any).
Array<Stmt> CompletePipelineLoopStatements(
I'm not entirely happy with the choice of this name; suggestions for a better one are welcome.
new_block = Downcast<Block>(
    Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

if (pipeline_info_[block].async) {
It would be great to also refactor this if statement into some functions.
It's possible but since this code block touches a lot of stuff defined in this loop, the extracted function would look rather messy like this:
void UpdateForAsync(Block block, Block new_block, int stage, size_t new_blocks_size,
PrimExpr normalized_access_index, PrimExpr inbound,
arith::Analyzer* ana_normalized,
std::map<int, AsyncStateLocal>* async_states_local,
std::unordered_map<const BufferNode*, int>* buffer_to_commit_group) {
...
And a reader would need to go back and forth between this function and `EmitImpl` anyway to understand the meanings of these variables and how they are used.
So I think making this change would rather hurt readability.
* [TIR] Support asynchronous stages in software pipeline transform
* Support interleaved async producers separated by a consumer
* clean up
* adding doc
* adding doc
* simplifying
* make wait count computation a two pass process
* commit_stage -> commit_queue, wait_stage -> wait_queue
* make async_commit_queue special scope stmt
* codegen async_commit_queue in cuda
* clean up
* clean up
* Move block predicate outside of commit_queue
* updating test
* test updated
* changed async_wait to an annotation
* update doc
* update meaning of software_pipeline_async_stages
* update test
* fixing codegen
* more fix
* remove one of tests that have async and sync ops in the same stage
* format
* lint and other fix
* Define attr::software_pipeline_async_stages
* populate wait count in a separate function
* fold variable consumed into AsyncStateLocal
* introduce CompletePipelineLoopStatements function for further refactor
This PR implements the asynchronous pipeline feature proposed in apache/tvm-rfcs#80 and lowering for CUDA async global to shared memory copy.
The main change is in `inject_software_pipeline`, where necessary synchronization annotations are inserted according to the user-provided list of async stages, `software_pipeline_async_stages`.
@vinx13 @junrushao1994 @csullivan @JosephTheOctonaut @wrongtest-intellif @kparzysz-quic
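For readers unfamiliar with the commit/wait style of synchronization that the commit log refers to (`commit_queue`, `wait_queue`), the following is a rough host-side model. It assumes semantics similar to CUDA's `cp.async` commit and wait groups: copies issued since the last commit form one group, and a wait retires the oldest groups until at most a given number remain in flight. The class `AsyncQueueModel` and its methods are hypothetical and are not part of TVM; the real transform emits annotations rather than calling anything like this.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Toy model of an async queue (assumption, not TVM code).
class AsyncQueueModel {
 public:
  // Record one asynchronous copy, identified by a hypothetical buffer id.
  void IssueCopy(int buffer_id) { pending_.push_back(buffer_id); }

  // Close the current group of copies and enqueue it as in flight.
  void Commit() {
    in_flight_.push_back(pending_);
    pending_.clear();
  }

  // Retire the oldest groups until at most max_in_flight groups remain.
  // A real device would block here; the model just marks copies completed.
  void Wait(std::size_t max_in_flight) {
    while (in_flight_.size() > max_in_flight) {
      for (int id : in_flight_.front()) completed_.push_back(id);
      in_flight_.pop_front();
    }
  }

  std::size_t InFlight() const { return in_flight_.size(); }
  const std::vector<int>& Completed() const { return completed_; }

 private:
  std::vector<int> pending_;                // copies issued since last commit
  std::deque<std::vector<int>> in_flight_;  // committed, not yet waited on
  std::vector<int> completed_;              // copies known to be finished
};
```

Under this model, a consumer that needs the data of the oldest committed group while one newer group is still being prefetched would wait with an in-flight count of 1, and a wait with a count of 0 drains everything.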