Skip to content

Commit

Permalink
Cleanup mma expansion
Browse files Browse the repository at this point in the history
This PR removes the need for maintaining
state through function local attributes during
MMA expansion.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Dec 20, 2024
1 parent 3d668fe commit 0de2735
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions iree/turbine/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def _expand_node(
context,
get_node_dim_scaling,
res_idx,
node.acc,
)
elif isinstance(node, Reduction):
return _expand_reduction(
Expand Down Expand Up @@ -432,6 +433,7 @@ def _expand_mma_reduction(
context: ExpandedNodeMap,
get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
res_idx: int,
accumulator: CustomOp,
) -> CustomOp:
"""
This function expands an MMA node along its reduction dimension. It is called
Expand Down Expand Up @@ -468,25 +470,13 @@ def _expand_mma_reduction(
# For M = 0, K2 = 1, K1 = 0, we use the original mma node so that the last cloned node's
# accumulator value is not modified.

dim_query_dims = tuple(dim_query.keys())
if not hasattr(_expand_mma_reduction, "acc"):
_expand_mma_reduction.acc = {}
if not hasattr(_expand_mma_reduction, "mma"):
_expand_mma_reduction.mma = {}
if (
dim_query_dims not in _expand_mma_reduction.mma
or _expand_mma_reduction.mma[dim_query_dims].graph != mma.graph
):
_expand_mma_reduction.mma[dim_query_dims] = mma
_expand_mma_reduction.acc[dim_query_dims] = mma.acc

context_key = (
_expand_mma_reduction.mma[dim_query_dims],
mma,
get_indexed_dims(dim_query, expand_dims),
res_idx,
)

user = _expand_mma_reduction.mma[dim_query_dims]
user = mma
for scale_idx in range(dim_scaling[mma.reduction_dim]):
if isinstance(user, Output):
continue
Expand Down Expand Up @@ -521,7 +511,7 @@ def _expand_mma_reduction(
if scale_idx > 0:
new_node.update_arg(index, user)
else:
new_node.update_arg(index, _expand_mma_reduction.acc[dim_query_dims])
new_node.update_arg(index, accumulator)
user.update_arg(index, saved_arg)
user.graph.erase_node(dummy)
user = new_node
Expand Down

0 comments on commit 0de2735

Please sign in to comment.