-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TKW] Thread Shape analysis #186
Conversation
c3c58b3
to
1eb7616
Compare
The motivation of this pass is to generalize the register analysis pass which is used to determine the thread shape of TKW.Register, to all other operations. One main use case for such is to allow reduction, and later on "broadcast" to use thread shape information from the kernel as opposed to relying on vector_shape which may not always be valid. We generalize the register analysis metho by finding a few anchor ops who's thread shape information is determined, and then propagate to it's successors and ancestors. Signed-off-by: Stanley Winata <stanley.winata@amd.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks great! Thanks I think this will be a good foundation to build on. Just some minor comments, but otherwise looks good!
return not isinstance(get_custom(node), nonPropagatableTypes) | ||
|
||
anchor_ops = trace.walk(is_anchor_op) | ||
thread_size_to_ops: dict[IndexSymbol, set[CustomOp]] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a pretty important variable. Can you add some documentation describing the key and value with some samples like in general it could look like
{frozenset({IndexSize(index=M, size=1), IndexSize(index=K, size=4)}): {read_shared_0_0_0}, frozenset({IndexSize(index=N, size=1), IndexSize(index=K, size=4)}): {read_shared_0_0_0}, frozenset({IndexSize(index=N, size=1), IndexSize(index=M, size=4)}): {acc_0_0_0}, frozenset({IndexSize(index=M, size=1), IndexSize(index=N, size=1)}): {}}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added more docs, lmk what you think! :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks great, thanks!
|
||
|
||
# Function called on op post propagation for extra processing/handling. | ||
def post_propagation(custom: CustomOp, target_index_sizes: list[IndexSize]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this function required? Don't the forward and backward slices go beyond the boundaries of the reduction and propagate from iter args to init args?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
backward slice does not go through from IterArg to init args. Shall we modify getInputs to do that instead(it's quite doable)? Since I am not the only user of bfs/backwardSlice I'd like to be mindful of modifiying too much haha
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@harsh-nod I modified it, but not 100% sure this would work for other's use cases, would probably want your eyes on this one :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I think that makes sense. Looks good to me!
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@@ -57,6 +50,38 @@ def determine_thread_shapes(trace: CapturedTrace): | |||
3. We bucket these ops to Variadic(Index->elem_per_thread) mapping. | |||
4. At every bucket of (index -> elem_per_thread), we apply these information | |||
by updating their indexSequence size. | |||
|
|||
We stored the buckets above in a variable/dict called `thread_size_to_ops`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very nice! thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! once you get the tests passing :)
The motivation of this pass is to generalize the register analysis pass which is used to determine the thread shape of TKW.Register, to all other operations. One main use case for such is to allow reduction, and later on "broadcast" to use thread shape information from the kernel as opposed to relying on vector_shape which may not always be valid. We generalize the register analysis metho by finding a few anchor ops who's thread shape information is determined, and then propagate to it's successors and ancestors. In addition to that we also implemented a couple helper function/attributes. 1. Control_fn on BFS, ForwardSlice, BackwardSlice. This is to make it easier for us to control/stop the search when we hit ops we do not want to explore. In this case, we do not want to explore/propagate onto other anchor ops and their children. 2. Introducing parent_op to IterArg and region of Reduction, for developer ergonomics. 3. Move handling of IterArg and GetUser in BackwardSlice/BFS's get_input exploration phase to be handled individually as opposed to being handled when its' consumer is being explored. Previously to explore/propagate IterArg/GetUser, we need to explore its' consumer, just exploring IterArg/GetUser will not get handled correctly. This is useful for the case where we want to propagate/explore mma.acc (usually IterArg) directly. --------- Signed-off-by: Stanley Winata <stanley.winata@amd.com>
The motivation of this pass is to generalize the register analysis pass which is used to determine the thread shape of TKW.Register, to all other operations.
One main use case for such is to allow reduction, and later on "broadcast" to use thread shape information from the kernel as opposed to relying on vector_shape which may not always be valid.
We generalize the register analysis metho by finding a few anchor ops who's thread shape information is determined, and then propagate to it's successors and ancestors.
In addition to that we also implemented a couple helper function/attributes.
Control_fn on BFS, ForwardSlice, BackwardSlice. This is to make it easier for us to control/stop the search when we hit ops we do not want to explore. In this case, we do not want to explore/propagate onto other anchor ops and their children.
Introducing parent_op to IterArg and region of Reduction, for developer ergonomics.
Move handling of IterArg and GetUser in BackwardSlice/BFS's get_input exploration phase to be handled individually as opposed to being handled when its' consumer is being explored. Previously to explore/propagate IterArg/GetUser, we need to explore its' consumer, just exploring IterArg/GetUser will not get handled correctly. This is useful for the case where we want to propagate/explore mma.acc (usually IterArg) directly.