Skip to content

Commit

Permalink
Get GEMMs working without minimize_global_loads (iree-org#167)
Browse files Browse the repository at this point in the history
This PR removes the need for propagating indices using
post expansion. The new approach propagates the MMA
indices to the MMA dimensions of all tensors (rather
than just MMA nodes) and then specializes them depending
on whether they lie within the backward slices of the
LHS and RHS or forward slices of the ACC.

---------

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: Ian <ian.nordeng@amd.com>
  • Loading branch information
harsh-nod authored and IanNod committed Sep 30, 2024
1 parent 9ea17c0 commit 3e0e3b8
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 135 deletions.
185 changes: 96 additions & 89 deletions lit_tests/kernel/wave/codegen.py

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions lit_tests/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,23 +243,23 @@ def test_gemm():
# CHECK-NEXT: placeholder(_name=a
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})
# CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b])
# CHECK-NEXT: get_result(value=reduction, res_idx=3)
# CHECK-NEXT: get_result(value=reduction, res_idx=2)
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_1_1_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: write(register_=getresult_1_0_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_0_1_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: output

# Reduction subgraph:
Expand Down Expand Up @@ -389,11 +389,11 @@ def test_gemm_reduction_expansion_only():
# CHECK-NEXT: placeholder(_name=a
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0]
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: output(return_vals=(None,))

# Reduction subgraph:
Expand Down
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ def test_gemm():
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: register
# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: register(
# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: allocate(
# CHECK-NEXT: allocate(
# CHECK-NEXT: reduction(
Expand Down
16 changes: 8 additions & 8 deletions lit_tests/kernel/wave/minimize_global_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def test_gemm():
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: register
# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: register(
# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: allocate(
# CHECK-NEXT: allocate(
# CHECK-NEXT: reduction(
Expand All @@ -146,13 +146,13 @@ def test_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_1_1_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: write(register_=getresult_1_0_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_0_1_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})

# Reduction subgraph:
# CHECK: %acc_0_0_0
Expand Down
10 changes: 0 additions & 10 deletions shark_turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,16 +780,6 @@ def custom_string(self, value_map: dict[str, str]) -> str:
custom_str += f"acc={self.acc} (index = {self.acc_index}))"
return custom_str

def post_expansion(self, constraints: list["Constraint"]) -> None:
"""
Once the arguments have been expanded, we set their indices,
ensuring that the LHS and RHS indices are consistent with their
corresponding address spaces.
"""
self.lhs.index = self.lhs_index
self.rhs.index = self.rhs_index
self.acc.index = self.acc_index


@define_op("read")
@dataclass
Expand Down
15 changes: 7 additions & 8 deletions shark_turbine/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .._support.indexing import IndexingContext, IndexSequence
from ...support.logging import get_logger
from .._support.tracing import CapturedTrace
from .utils import get_mma_dimensional_mapping
from .utils import get_mma_dimensional_mapping, specialize_index_sequence
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.expansion")
Expand Down Expand Up @@ -146,6 +146,7 @@ def compute_stride(
def set_node_index(
constraints: Sequence[Constraint],
mma_index: dict[IndexSymbol, int],
mma_slices: dict[IndexSymbol, list[fx.Node]],
dim_tile_size: dict[IndexSymbol, int],
custom: CustomOp,
dim_scaling: dict[IndexSymbol, int],
Expand Down Expand Up @@ -176,11 +177,7 @@ def set_node_index(
for dim in custom.indexing_dims:
index_seq = None
for constraint in sorted_constraints:
mma_check = (
isinstance(constraint, HardwareConstraint)
and dim in mma_index
and isinstance(custom, MMA)
)
mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index

vector_check = (
isinstance(constraint, HardwareConstraint)
Expand Down Expand Up @@ -222,6 +219,8 @@ def set_node_index(
index_seq = constraint.apply(
constraint_index, dim, elements_per_thread, stride
)
if mma_index:
index_seq = specialize_index_sequence(index_seq, mma_slices, custom)

else:
if index_seq is None:
Expand Down Expand Up @@ -251,10 +250,10 @@ def expand_graph(
dim_scaling = constraints_or_scaling
node_index_setter = lambda *args: None
else:
mma_index = get_mma_dimensional_mapping(trace)
mma_index, mma_slices = get_mma_dimensional_mapping(trace)
dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index)
node_index_setter = partial(
set_node_index, constraints_or_scaling, mma_index, dim_tile_size
set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size
)

# Start from the back and expand in the corresponding indexing dimensions of a node
Expand Down
2 changes: 1 addition & 1 deletion shark_turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_vector_shape(
hardware_constraint: HardwareConstraint,
symbolic_shape: list[IndexSymbol],
) -> list[int]:
mma_indices = get_mma_dimensional_mapping(trace)
mma_indices, _ = get_mma_dimensional_mapping(trace)
return [
get_hardware_vector_size(dim, hardware_constraint, mma_indices)
for dim in symbolic_shape
Expand Down
150 changes: 145 additions & 5 deletions shark_turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Expand All @@ -16,7 +15,16 @@
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence
from ..lang.global_symbols import *
from ..ops.wave_ops import get_custom, Output, Write, MMA
from ..ops.wave_ops import (
get_custom,
Output,
Write,
MMA,
CustomOp,
Reduction,
GetResult,
IterArg,
)
from .constraints import Constraint, HardwareConstraint, TilingConstraint
import torch.fx as fx
import shark_turbine.kernel.lang as tkl
Expand Down Expand Up @@ -145,7 +153,9 @@ def simplify_index(index: IndexExpr) -> IndexExpr:
return subs_idxc(index.subs(mapping))


def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]:
def get_mma_dimensional_mapping(
trace: CapturedTrace,
) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]:
"""
Given a trace, determine the MMA dimensional mapping for all the
MMA operations in the graph. For example, if we have
Expand All @@ -159,7 +169,8 @@ def is_mma(node):
return isinstance(get_custom(node), MMA)

mapping: dict[IndexSymbol, int] = {}
for node in trace.walk(is_mma):
mma_nodes = trace.walk(is_mma)
for node in mma_nodes:
custom: MMA = get_custom(node)
m, n = custom.acc_type.symbolic_shape[-2:]
lhs_shape = custom.lhs_type.symbolic_shape
Expand All @@ -170,7 +181,7 @@ def is_mma(node):
mapping[n] = 1
mapping[k] = 2

return mapping
return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes])


def get_hardware_vector_size(
Expand Down Expand Up @@ -378,3 +389,132 @@ def erase_graph(graph: fx.Graph):
for user in node.users:
graph.erase_node(user)
graph.erase_node(node)


def get_users(
node: fx.Node, reduction: fx.Node = None
) -> tuple[list[fx.Node], fx.Node]:
"""
Return the users of a node, propagating through reductions.
"""
users = []
for user in node.users:
custom = get_custom(user)
if isinstance(custom, Reduction):
# Map init arg to iter arg
reduction = custom
init_arg_idx = custom.init_args.index(node)
users.append(custom.iter_args[init_arg_idx])
continue
if isinstance(custom, Output) and reduction:
# Map output to get result
return_vals = custom.return_vals[0]
get_results = sorted(
[x for x in reduction.users if isinstance(get_custom(x), GetResult)],
lambda x: get_custom(x).res_idx,
)
if isinstance(return_vals, list):
output_idx = return_vals.index(node)
users.append(get_results[output_idx])
else:
users.append(get_results[0])
continue
users.append(user)
return users, reduction


def get_inputs(
node: fx.Node, reduction: fx.Node = None
) -> tuple[list[fx.Node], fx.Node]:
"""
Return the inputs of a node, propagating through reductions.
"""
inputs = []
for input in node.all_input_nodes:
custom = get_custom(input)
if isinstance(custom, GetResult):
reduction = custom.value
assert isinstance(
reduction, Reduction
), "GetResult must be used by a Reduction"
# Map get result to output
inputs.append(reduction.outputs[custom.res_idx])
continue
if isinstance(custom, IterArg):
# Map iter args to init args
iter_arg_idx = reduction.iter_args.index(node)
inputs.append(reduction.init_args[iter_arg_idx])
continue
inputs.append(input)
return inputs, reduction


def bfs(
node: fx.Node,
get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]],
) -> set[fx.Node]:
"""
Run BFS on the graph to capture the forward slice of a node.
"""
visited: set[fx.Node] = set()
queue: list[fx.Node] = []
visited.add(node)
queue.append(node)
reduction = None
while queue:
s = queue.pop(0)
neighbors, reduction = get_neighbors(s, reduction)
for neighbor in neighbors:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
return visited


def capture_forward_slice(node: fx.Node) -> set[fx.Node]:
"""
Run BFS on the graph to capture the forward slice of a node.
"""
return bfs(node, lambda x, y: get_users(x, y))


def capture_backward_slice(node: fx.Node) -> set[fx.Node]:
"""
Capture backward slice from a node and return the tree.
Assumes graph is directed.
"""
return bfs(node, lambda x, y: get_inputs(x, y))


def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]:
"""
Given an index sequence, specialize it to a LHS, RHS or ACC index sequence
based on whether the node is used as the LHS, RHS or ACC in the MMA node.
"""
mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]}
for mma in mma_nodes:
mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs)
mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs)
mma_slices[MMA_ACC] += capture_forward_slice(mma.acc)
return mma_slices


def specialize_index_sequence(
index_seq: IndexSequence,
mma_slices: dict[IndexSymbol, list[fx.Node]],
custom: CustomOp,
) -> IndexSequence:
"""
Given an index sequence, specialize it to a LHS, RHS or ACC index sequence
based on whether the node is used as the LHS, RHS or ACC in the MMA node.
If the node is not used as any of the operands, return the original index sequence
with all the MMA symbols zeroed out.
"""
if isinstance(custom, MMA):
return index_seq
operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
for key in mma_slices:
if custom.fx_node in mma_slices[key]:
operand_map[key] = 1
return index_seq.subs(operand_map)
return index_seq.subs(operand_map)

0 comments on commit 3e0e3b8

Please sign in to comment.