
[TKW] Thread Shape analysis #186

Merged
merged 4 commits into iree-org:main on Oct 3, 2024

Conversation

raikonenfnu
Contributor

@raikonenfnu raikonenfnu commented Oct 2, 2024

The motivation of this pass is to generalize the register analysis pass, which is used to determine the thread shape of TKW.Register, so that it applies to all other operations.

One main use case is to allow reduction, and later on "broadcast", to use thread shape information from the kernel rather than relying on vector_shape, which may not always be valid.

We generalize the register analysis method by finding a few anchor ops whose thread shape information is already determined, and then propagating that information to their successors and ancestors (a short illustrative sketch follows the list below).

In addition, we also implemented a couple of helper functions/attributes:

  1. A control_fn on BFS, ForwardSlice, and BackwardSlice. This makes it easier to control/stop the search when we hit ops we do not want to explore; in this case, we do not want to explore/propagate onto other anchor ops and their children.

  2. Introduced parent_op on IterArg and on the region of Reduction, for developer ergonomics.

  3. Moved the handling of IterArg and GetUser in BackwardSlice/BFS's get_input exploration phase so that they are handled individually, as opposed to only when their consumer is being explored. Previously, to explore/propagate an IterArg/GetUser we needed to explore its consumer; exploring just the IterArg/GetUser alone would not be handled correctly. This is useful for the case where we want to propagate/explore mma.acc (usually an IterArg) directly.
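To make the flow above concrete, here is a small self-contained Python sketch of the idea: thread sizes known at anchor ops are pushed along the def-use graph, and a control_fn stops the walk at other anchors. Everything below (the toy graph, forward_slice, the size dicts) is an illustrative stand-in, not the actual TKW API.

# Illustrative stand-in, not the actual TKW API.
from collections import deque

graph = {  # op -> users (forward direction)
    "read_a": ["mma"], "read_b": ["mma"],
    "mma": ["add"], "add": ["write"], "write": [],
}
anchors = {  # ops whose thread sizes are already determined
    "read_a": {"M": 1, "K": 4},
    "mma": {"M": 4, "N": 1},
}

def forward_slice(node, control_fn):
    # BFS over users; control_fn decides whether a node may be explored.
    seen, found, queue = {node}, [], deque(graph[node])
    while queue:
        n = queue.popleft()
        if n in seen or not control_fn(n):
            continue
        seen.add(n)
        found.append(n)
        queue.extend(graph[n])
    return found

thread_sizes = {op: dict(sizes) for op, sizes in anchors.items()}
for anchor, sizes in anchors.items():
    # control_fn: never explore/propagate onto other anchor ops.
    for op in forward_slice(anchor, control_fn=lambda n: n not in anchors):
        thread_sizes.setdefault(op, {}).update(sizes)

print(thread_sizes)  # read_a/mma keep their sizes; add and write inherit mma's

A backward_slice over operands (and, per item 3, through IterArg) would fill in the ancestors the same way.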

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Contributor

@harsh-nod harsh-nod left a comment


Overall looks great! Thanks, I think this will be a good foundation to build on. Just some minor comments, but otherwise looks good!

shark_turbine/kernel/wave/thread_shape_analysis.py (outdated review thread, resolved)
return not isinstance(get_custom(node), nonPropagatableTypes)

anchor_ops = trace.walk(is_anchor_op)
thread_size_to_ops: dict[IndexSymbol, set[CustomOp]] = {}
Contributor

This is a pretty important variable. Can you add some documentation describing the key and value, with some samples? In general it could look like:

{
  frozenset({IndexSize(index=M, size=1), IndexSize(index=K, size=4)}): {read_shared_0_0_0},
  frozenset({IndexSize(index=N, size=1), IndexSize(index=K, size=4)}): {read_shared_0_0_0},
  frozenset({IndexSize(index=N, size=1), IndexSize(index=M, size=4)}): {acc_0_0_0},
  frozenset({IndexSize(index=M, size=1), IndexSize(index=N, size=1)}): {},
}
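For a self-contained illustration of that keying scheme (the IndexSize below is a stand-in NamedTuple and the op names are plain strings, not the real TKW objects):

from typing import NamedTuple

class IndexSize(NamedTuple):  # stand-in for the real TKW type
    index: str                # dimension symbol, e.g. "M", "N", "K"
    size: int                 # elements per thread along that dimension

# Key: a hashable (frozen) set of per-dimension thread sizes.
# Value: the ops that share exactly those sizes.
thread_size_to_ops: dict[frozenset[IndexSize], set[str]] = {}

def bucket(op_name: str, sizes: list[IndexSize]) -> None:
    thread_size_to_ops.setdefault(frozenset(sizes), set()).add(op_name)

bucket("read_shared_0_0_0", [IndexSize("M", 1), IndexSize("K", 4)])
bucket("acc_0_0_0", [IndexSize("N", 1), IndexSize("M", 4)])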

Contributor Author

Added more docs, lmk what you think! :)

Contributor

looks great, thanks!

shark_turbine/kernel/wave/thread_shape_analysis.py (outdated review thread, resolved)


# Function called on op post propagation for extra processing/handling.
def post_propagation(custom: CustomOp, target_index_sizes: list[IndexSize]):
Contributor

Why is this function required? Don't the forward and backward slices go beyond the boundaries of the reduction and propagate from iter args to init args?

Contributor Author

The backward slice does not go from IterArg to init args. Shall we modify getInputs to do that instead (it's quite doable)? Since I am not the only user of bfs/backwardSlice, I'd like to be mindful of modifying too much haha

Contributor Author

@harsh-nod I modified it, but I'm not 100% sure this would work for others' use cases; would probably want your eyes on this one :)

Contributor

Yes I think that makes sense. Looks good to me!
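For readers following along, a rough self-contained sketch of the behavior settled on here: when the backward walk (getInputs) reaches an IterArg, it hops to the matching init arg of its parent Reduction through the newly added parent_op attribute. The classes below are simplified stand-ins modeled on this discussion, not the actual TKW definitions.

from dataclasses import dataclass, field

@dataclass
class Reduction:  # stand-in for the real Reduction op
    init_args: list = field(default_factory=list)

@dataclass
class IterArg:  # stand-in for the real IterArg op
    position: int
    parent_op: Reduction | None = None  # attribute added for ergonomics

def get_inputs(node):
    # Backward step: an IterArg routes to the matching init arg of its
    # parent Reduction instead of terminating the slice there.
    if isinstance(node, IterArg) and node.parent_op is not None:
        return [node.parent_op.init_args[node.position]]
    return list(getattr(node, "operands", []))

red = Reduction(init_args=["acc_init"])
acc = IterArg(position=0, parent_op=red)
assert get_inputs(acc) == ["acc_init"]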

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@@ -57,6 +50,38 @@ def determine_thread_shapes(trace: CapturedTrace):
3. We bucket these ops to Variadic(Index->elem_per_thread) mapping.
4. At every bucket of (index -> elem_per_thread), we apply these information
by updating their indexSequence size.

We stored the buckets above in a variable/dict called `thread_size_to_ops`.
Contributor

very nice! thanks!
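As an aside, a minimal self-contained sketch of steps 3 and 4 from the docstring above: bucket ops that share the same per-dimension thread sizes, then write each bucket's sizes back onto its ops. The names are stand-ins, and a plain dict stands in for updating each op's index sequence.

from collections import defaultdict

def bucket_and_apply(op_sizes):
    # op_sizes: op name -> {dimension: elements per thread}
    # Step 3: bucket ops by their exact set of per-dimension sizes.
    buckets = defaultdict(set)
    for op, sizes in op_sizes.items():
        buckets[frozenset(sizes.items())].add(op)
    # Step 4: apply each bucket's sizes back onto every op in that bucket.
    applied = {op: dict(key) for key, ops in buckets.items() for op in ops}
    return dict(buckets), applied

buckets, applied = bucket_and_apply({
    "read_0_0_0": {"M": 1, "K": 4},
    "read_1_0_0": {"M": 1, "K": 4},
    "acc_0_0_0": {"N": 1, "M": 4},
})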

@harsh-nod harsh-nod self-requested a review October 3, 2024 17:11
Contributor

@harsh-nod harsh-nod left a comment


lgtm! once you get the tests passing :)

@raikonenfnu raikonenfnu merged commit e0a8fdf into iree-org:main Oct 3, 2024
6 of 8 checks passed
stellaraccident pushed a commit that referenced this pull request Oct 13, 2024