Update run configurations for gemm test
Also add the ability to output schedule files
and use user-modified schedule files.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 9, 2024
1 parent f207ca5 commit 0f76d54
Showing 5 changed files with 180 additions and 33 deletions.
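The new export_schedule_file and user_specified_schedule_file arguments enable a save/edit/reload workflow for schedules. The sketch below is illustrative only: it assumes a trace and constraints prepared exactly as in tests/kernel/wave/scheduling_test.py, and fresh_trace stands in for a second capture of the kernel (the test uses a separate gemm2 kernel for this).

    from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph

    # Run 1: let the modulo scheduler solve for a schedule and dump it to disk.
    schedule_graph(trace, constraints, False, "schedule.csv", None)

    # (Optionally hand-edit the "cycle" column of schedule.csv.)

    # Run 2: on a freshly captured trace, bypass the solver, load the edited
    # schedule, and export a new copy so the round trip can be compared.
    schedule_graph(fresh_trace, constraints, False, "new_schedule.csv", "schedule.csv")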
57 changes: 57 additions & 0 deletions iree/turbine/kernel/wave/scheduling/modulo_scheduling.py
@@ -19,6 +19,7 @@
from typing import Callable
import numpy as np
import math
import csv

logger = get_logger("turbine.wave.modulo_scheduling")

@@ -272,3 +273,59 @@ def num_stages(self) -> int:
"""
max_cycle = max([t for t in self.schedule.values()])
return math.ceil(max_cycle / self.initiation_interval)

def load_schedule(self, path: str, graph: fx.Graph) -> None:
"""
Load a schedule into the scheduler.
The schedule consists of a mapping from nodes in the graph to cycles.
The nodes must be in the same order as the graph.
"""
self._initiation_interval = 0
self.schedule: dict[fx.Node, int] = {}
print(f"Loading schedule from: {path}.\n")
data = []
with open(path, "r") as file:
schedule_reader = csv.reader(file, delimiter=",")
for row in schedule_reader:
data.append(
{
"cycle": int(row[0]),
"name": row[1],
"initiation_interval": int(row[2]),
"relative_cycle": int(row[3]),
"graph_index": int(row[4]),
}
)
data.sort(key=lambda x: x["graph_index"])
for i, node in enumerate(graph.nodes):
self.schedule[node] = data[i]["cycle"]
self._initiation_interval = data[i]["initiation_interval"]
print(f"Loaded schedule for node: {node.name} -> {self.schedule[node]}.")
print(f"Set initiation interval: {self._initiation_interval}.")

def save_schedule(self, path: str, graph: fx.Graph) -> None:
"""
Save the schedule to a file. First, assign an index to each node in the graph
to specify the order in which they should be loaded. The schedule format is:
# cycle, node name, initiation interval, cycle % initiation interval, graph index.
Only the cycle should be modified by users.
"""
nodes_with_metadata = []
for i, node in enumerate(graph.nodes):
nodes_with_metadata.append((i, node, self.schedule[node]))
nodes_with_metadata.sort(key=lambda x: x[2])

with open(path, "w") as file:
schedule_writer = csv.writer(file, delimiter=",")
for graph_index, node, cycle in nodes_with_metadata:
assert node in self.schedule, f"Node {node} not scheduled."
schedule_writer.writerow(
[
cycle,
node.name,
self._initiation_interval,
cycle % self._initiation_interval,
graph_index,
]
)
logger.info(f"Saved schedule to: {path}.")
23 changes: 21 additions & 2 deletions iree/turbine/kernel/wave/scheduling/schedule.py
@@ -28,6 +28,8 @@ def schedule_reduction(
trace: CapturedTrace,
constraints: list[Constraint],
use_scheduling_barriers: bool = False,
export_schedule_file: str = None,
user_specified_schedule_file: str = None,
):
"""
Clones the reduction graph and does the following:
@@ -49,7 +51,17 @@
visualize_graph(graph, "scheduling_fx_graph.png")

scheduler = ModuloScheduler(graph, edges, get_available_resources())
schedule, success = scheduler.schedule_graph()

if user_specified_schedule_file:
scheduler.load_schedule(user_specified_schedule_file, graph)
success = True
schedule = scheduler.schedule
else:
schedule, success = scheduler.schedule_graph()

if export_schedule_file:
scheduler.save_schedule(export_schedule_file, graph)

if not success:
raise ValueError("Scheduling failed.")
if visualize:
@@ -107,6 +119,8 @@ def schedule_graph(
trace: CapturedTrace,
constraints: list[Constraint],
use_scheduling_barriers: bool = False,
export_schedule_file: str = None,
user_specified_schedule_file: str = None,
):
"""
Given a graph, pipelines the reductions in the graph.
@@ -121,5 +135,10 @@ def is_reduction(node: fx.Node) -> bool:

for reduction_node in reduction_nodes:
schedule_reduction(
get_custom(reduction_node), trace, constraints, use_scheduling_barriers
get_custom(reduction_node),
trace,
constraints,
use_scheduling_barriers,
export_schedule_file,
user_specified_schedule_file,
)
12 changes: 11 additions & 1 deletion iree/turbine/kernel/wave/wave.py
@@ -254,7 +254,17 @@ def _trace_and_get_kernel_signature(
if kwargs.get("schedule", False):
use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)
schedule_graph(graph, self.constraints, use_scheduling_barriers)
export_schedule_file = kwargs.get("export_schedule_file", None)
user_specified_schedule_file = kwargs.get(
"user_specified_schedule_file", None
)
schedule_graph(
graph,
self.constraints,
use_scheduling_barriers,
export_schedule_file,
user_specified_schedule_file,
)

# Add shared memory barriers.
add_shared_memory_barriers(graph)
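The gemm test below already routes schedule and use_scheduling_barriers through the launch context as keyword arguments, so the new file options can presumably be supplied the same way. A hedged sketch, assuming the context forwards these kwargs unchanged and reusing hyperparams, gemm, a, b, and c from that test:

    # Hypothetical invocation; not part of this commit's tests.
    with tk.gen.TestLaunchContext(
        hyperparams,
        schedule=True,
        use_scheduling_barriers=enable_scheduling_barriers,
        export_schedule_file="schedule.csv",  # dump the computed schedule
        user_specified_schedule_file=None,    # or a previously edited file
    ):
        gemm(a, b, c)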
40 changes: 39 additions & 1 deletion tests/kernel/wave/scheduling_test.py
@@ -32,6 +32,9 @@
from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph
from iree.turbine.kernel.ops.wave_ops import get_custom
from copy import deepcopy
import filecmp
import os


class SchedulingTest(unittest.TestCase):
@@ -271,13 +274,15 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
MMA_UNITS: 2,
}
with tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True):

trace: CapturedTrace = gemm()
IndexingContext.current().finalize()
promote_placeholders(trace, constraints)
hoist_allocs(trace)
expand_graph(trace, constraints)
minimize_global_loads(trace, constraints)
schedule_graph(trace, constraints)

schedule_graph(trace, constraints, False, "schedule.csv", None)
subgraph = trace.get_subgraph("region_0")
initiation_interval = 5
correct_schedule = {
@@ -425,6 +430,39 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
if custom.name in correct_schedule:
assert custom.scheduling_parameters == correct_schedule[custom.name]

# TODO: Debug why gemm2 is necessary. Using gemm for a second trace
# causes a node mismatch when loading the schedule.
@tkw.wave_trace_only(constraints)
def gemm2(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True):

trace: CapturedTrace = gemm2()
IndexingContext.current().finalize()
promote_placeholders(trace, constraints)
hoist_allocs(trace)
expand_graph(trace, constraints)
minimize_global_loads(trace, constraints)

schedule_graph(
trace, constraints, False, "new_schedule.csv", "schedule.csv"
)
assert filecmp.cmp("schedule.csv", "new_schedule.csv", shallow=False)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
81 changes: 52 additions & 29 deletions tests/kernel/wave/wave_gemm_test.py
@@ -24,15 +24,18 @@
# Whether to use scheduling group barriers (needs LLVM fix).
enable_scheduling_barriers = int(os.environ.get("WAVE_USE_SCHED_BARRIERS", 0))

default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)]

default_test_shapes = [
(2048, 10240, 1280, 128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2),
(2048, 1280, 1280, 64, 64, 64, 2, 2, 1, 2, 1, 1, 1, 1, 2),
(2048, 1280, 5120, 128, 80, 128, 4, 1, 1, 4, 2, 2, 1, 1, 2),
(128, 1280, 2048, 64, 64, 128, 2, 2, 1, 8, 2, 2, 1, 1, 2),
(8192, 5120, 640, 128, 128, 32, 2, 2, 1, 4, 2, 2, 1, 1, 2),
]

perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only)

default_test_shapes += [
perf_test((1024, 5120, 640)),
perf_test((2048, 10240, 1280)),
perf_test((4096, 20480, 2560)),
]
default_test_shapes += [perf_test(x) for x in default_test_shapes]

user_specified_test_shapes = ""

@@ -50,9 +53,26 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
@pytest.mark.parametrize("params", get_test_shapes("test_gemm"))
@pytest.mark.parametrize("enable_scheduling", [False, True])
def testGemm(shape: tuple[int], enable_scheduling: bool, request):
def testGemm(params: tuple[int], enable_scheduling: bool, request):
(
m,
n,
k,
block_m,
block_n,
block_k,
ratio_m,
ratio_n,
waves_per_eu,
mma_units,
shared_units,
global_units,
delay_mma,
delay_shared,
delay_global,
) = params
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
# Input sizes
@@ -73,11 +93,13 @@ def testGemm(shape: tuple[int], enable_scheduling: bool, request):
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / ratio_n)]

constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(ratio_m, ratio_n, 1)
)
]

# Wave-level micro-kernel.
@@ -113,20 +135,20 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
M: shape[0],
N: shape[1],
K: shape[2],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
BLOCK_M: block_m,
BLOCK_N: block_n,
BLOCK_K: block_k,
M: m,
N: n,
K: k,
READ_SHARED_DELAY: delay_shared,
WRITE_SHARED_DELAY: delay_shared,
READ_GLOBAL_DELAY: delay_global,
WRITE_GLOBAL_DELAY: delay_global,
MMA_DELAY: delay_mma,
SHARED_MEMORY_UNITS: shared_units,
GLOBAL_MEMORY_UNITS: global_units,
MMA_UNITS: mma_units,
}
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
if run_bench:
@@ -147,12 +169,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
schedule=enable_scheduling,
use_scheduling_barriers=enable_scheduling_barriers,
):
a = torch.randn(shape[0], shape[2], dtype=torch.float16)
b = torch.randn(shape[1], shape[2], dtype=torch.float16)
c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
a = torch.randn(m, k, dtype=torch.float16)
b = torch.randn(n, k, dtype=torch.float16)
c = torch.zeros(m, n, dtype=torch.float32)
mb = gemm(a, b, c)

if test_dump_generated_mlir:
shape = [m, n, k]
filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
@@ -162,6 +185,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
config["benchmark_results_file"] = os.path.join(
dump_perf, "iree_" + perf_filename
)
iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32)
iree_ref = torch.zeros(m, n, dtype=torch.float32)
generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench)
assert_close(c, iree_ref)
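Each test tuple now packs the problem size together with per-shape tuning and scheduling parameters; the field order is fixed by the unpacking at the top of testGemm. The GemmParams helper below is a hypothetical illustration (not part of this commit) that simply names those fields:

    from typing import NamedTuple

    class GemmParams(NamedTuple):
        # Field order matches the tuple unpacking in testGemm.
        m: int
        n: int
        k: int
        block_m: int
        block_n: int
        block_k: int
        ratio_m: int
        ratio_n: int
        waves_per_eu: int
        mma_units: int
        shared_units: int
        global_units: int
        delay_mma: int
        delay_shared: int
        delay_global: int

    # Example: the first default test shape.
    params = GemmParams(2048, 10240, 1280, 128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)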
