From 0de2735b1c947ff283a6e31c38697a221d888a0e Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Thu, 19 Dec 2024 21:12:17 -0800 Subject: [PATCH] Cleanup mma expansion This PR removes the need for maintaining state through function local attributes during MMA expansion. Signed-off-by: Harsh Menon --- iree/turbine/kernel/wave/expansion.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 346d7529..212e133b 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -221,6 +221,7 @@ def _expand_node( context, get_node_dim_scaling, res_idx, + node.acc, ) elif isinstance(node, Reduction): return _expand_reduction( @@ -432,6 +433,7 @@ def _expand_mma_reduction( context: ExpandedNodeMap, get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int, + accumulator: CustomOp, ) -> CustomOp: """ This function expands an MMA node along its reduction dimension. It is called @@ -468,25 +470,13 @@ def _expand_mma_reduction( # For M = 0, K2 = 1, K1 = 0, we use the original mma node so that the last cloned node's # accumulator value is not modified. - dim_query_dims = tuple(dim_query.keys()) - if not hasattr(_expand_mma_reduction, "acc"): - _expand_mma_reduction.acc = {} - if not hasattr(_expand_mma_reduction, "mma"): - _expand_mma_reduction.mma = {} - if ( - dim_query_dims not in _expand_mma_reduction.mma - or _expand_mma_reduction.mma[dim_query_dims].graph != mma.graph - ): - _expand_mma_reduction.mma[dim_query_dims] = mma - _expand_mma_reduction.acc[dim_query_dims] = mma.acc - context_key = ( - _expand_mma_reduction.mma[dim_query_dims], + mma, get_indexed_dims(dim_query, expand_dims), res_idx, ) - user = _expand_mma_reduction.mma[dim_query_dims] + user = mma for scale_idx in range(dim_scaling[mma.reduction_dim]): if isinstance(user, Output): continue @@ -521,7 +511,7 @@ def _expand_mma_reduction( if scale_idx > 0: new_node.update_arg(index, user) else: - new_node.update_arg(index, _expand_mma_reduction.acc[dim_query_dims]) + new_node.update_arg(index, accumulator) user.update_arg(index, saved_arg) user.graph.erase_node(dummy) user = new_node