Add code to construct pipelined loop from schedule #160

Merged (2 commits, Oct 1, 2024)
84 changes: 84 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -602,6 +602,90 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: return


@run_test
def test_gemm_pipelined():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 2, 1),
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def gemm_pipelined(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
M: 128,
N: 128,
K: 128,
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
},
canonicalize=True,
schedule=True,
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
print(gemm_pipelined(a, b, c).module_op)

# CHECK: func.func @gemm_pipelined
# CHECK-COUNT-2: vector.load
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-10: vector.load
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: scf.for
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-10: vector.load
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: scf.yield
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-8: vector.load
# CHECK-COUNT-8: amdgpu.mfma


@run_test
def test_add_float():
constraints: list[tkw.Constraint] = [
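Note on the expected IR in the test above: the CHECK lines follow the usual software-pipelining shape, with a prologue (global loads, shared-memory stores, and a first round of MFMAs before the scf.for), a steady-state loop body, and an epilogue after the scf.yield that drains the in-flight work. A minimal sketch of that structure in plain Python, assuming hypothetical load/compute callables (not Wave IR):

def pipelined_loop(load, compute, num_iters, acc=0):
    # Prologue: issue the first load so the loop body always has data ready.
    tile = load(0)
    for i in range(num_iters - 1):   # steady state, corresponds to the scf.for
        next_tile = load(i + 1)      # fetch data for the next iteration
        acc = compute(tile, acc)     # compute on the current iteration's data
        tile = next_tile
    # Epilogue: finish the last iteration that was loaded but not yet computed.
    return compute(tile, acc)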
227 changes: 227 additions & 0 deletions lit_tests/kernel/wave/scheduling.py
@@ -0,0 +1,227 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.promotion import promote_placeholders
from shark_turbine.kernel.wave.hoisting import hoist_allocs
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import *
from shark_turbine.kernel.wave.utils import run_test, print_subgraph
from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from shark_turbine.kernel.wave.shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
)
from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph


# Input sizes
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K

# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K

# Address space
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0

# Induction variable for dimension K
ARGK = tkl.sym.ARGK


@tkw.wave_trace_only()
def gemm_pipelined(
a: tkl.Memory[M, K, ADDRESS_SPACE_0, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE_0, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=4)
b_reg = tkw.read(b, elements_per_thread=4)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=4)


@run_test
def test_gemm_pipelined():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
]
with tk.gen.TestLaunchContext(
{
M: 128,
N: 256,
K: 128,
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
ADDRESS_SPACE_0: SHARED_ADDRESS_SPACE,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 2,
GLOBAL_MEMORY_UNITS: 2,
MMA_UNITS: 2,
}
):
trace: CapturedTrace = gemm_pipelined()
IndexingContext.current().finalize()
promote_placeholders(trace, constraints)
hoist_allocs(trace)
expand_graph(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
schedule_graph(trace, constraints)

print_subgraph(trace, "pipelined_reduction", False)
# CHECK: %acc_0_0_0
# CHECK-NEXT: %acc_0_1_0
# CHECK-NEXT: %acc_1_0_0
# CHECK-NEXT: %acc_1_1_0
# CHECK-NEXT: %rotating_reg_0
# CHECK-NEXT: %rotating_reg_1
# CHECK-NEXT: %rotating_reg_2
# CHECK-NEXT: %rotating_reg_3
# CHECK-NEXT: %rotating_reg_4
# CHECK-NEXT: %rotating_reg_5
# CHECK-NEXT: %rotating_reg_6
# CHECK-NEXT: %mma_1_1_1
# CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6)
# CHECK-NEXT: %read_shared_0_0_0
# CHECK-NEXT: %read_shared_0_0_1
# CHECK-NEXT: %read_4
# CHECK-NEXT: %read_5
# CHECK-NEXT: %read_shared_1_0_0
# CHECK-NEXT: %read_shared_1_0_1
# CHECK-NEXT: %mma_0_0_0
# CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0)
# CHECK-NEXT: %mma_0_1_0
# CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0)
# CHECK-NEXT: %mma_0_0_1
# CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0)
# CHECK-NEXT: %mma_1_0_0
# CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0)
# CHECK-NEXT: %write_2
# CHECK-NEXT: %write_3
# CHECK-NEXT: %mma_1_0_1
# CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0)
# CHECK-NEXT: %mma_0_1_1
# CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0)
# CHECK-NEXT: %read_shared_0_1_0
# CHECK-NEXT: %read_shared_0_1_1
# CHECK-NEXT: %mma_1_1_0
# CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1)
# CHECK-NEXT: %read_shared_0_0_2
# CHECK-NEXT: %read_shared_0_0_3
# CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0]

print_subgraph(trace, "region_1", False)
# CHECK: %a
# CHECK-NEXT: %b
# CHECK-NEXT: %c
# CHECK-NEXT: %register_0_0_0
# CHECK-NEXT: %register_1_1_0
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %allocate
# CHECK-NEXT: %allocate_1
# CHECK-NEXT: %read_4
# CHECK-NEXT: %read_5
# CHECK-NEXT: %write_2
# CHECK-NEXT: %write_3
# CHECK-NEXT: %read_shared_0_1_0
# CHECK-NEXT: %read_shared_0_1_1
# CHECK-NEXT: %read_shared_0_0_1
# CHECK-NEXT: %read_shared_0_0_2
# CHECK-NEXT: %read_shared_0_0_0
# CHECK-NEXT: %read_shared_0_0_3
# CHECK-NEXT: %read_6
# CHECK-NEXT: %read_7
# CHECK-NEXT: %read_shared_1_0_0
# CHECK-NEXT: %read_shared_1_0_1
# CHECK-NEXT: %mma_0_0_0
# CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0)
# CHECK-NEXT: %mma_0_1_0
# CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0)
# CHECK-NEXT: %mma_0_0_1
# CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, %mma_0_0_0)
# CHECK-NEXT: %mma_1_0_0
# CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0)
# CHECK-NEXT: %write_4
# CHECK-NEXT: %write_5
# CHECK-NEXT: %mma_1_0_1
# CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0)
# CHECK-NEXT: %mma_0_1_1
# CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0)
# CHECK-NEXT: %read_shared_0_1_2
# CHECK-NEXT: %read_shared_0_1_3
# CHECK-NEXT: %mma_1_1_0
# CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0)
# CHECK-NEXT: %read_shared_0_0_4
# CHECK-NEXT: %read_shared_0_0_5
# CHECK-NEXT: %reduction_1
# CHECK-NEXT: %getresult_1_1_0
# CHECK-NEXT: %getresult_1_0_0
# CHECK-NEXT: %getresult_0_1_0
# CHECK-NEXT: %getresult_0_0_0
# CHECK-NEXT: %get_result_4
# CHECK-NEXT: %get_result_5
# CHECK-NEXT: %get_result_6
# CHECK-NEXT: %get_result_7
# CHECK-NEXT: %get_result_8
# CHECK-NEXT: %get_result_9
# CHECK-NEXT: %get_result_10
# CHECK-NEXT: %mma_1_1_1
# CHECK-SAME: (%get_result_5, %get_result_9, %get_result_10)
# CHECK-NEXT: %read_shared_0_0_6
# CHECK-NEXT: %read_shared_0_0_7
# CHECK-NEXT: %read_shared_1_0_2
# CHECK-NEXT: %read_shared_1_0_3
# CHECK-NEXT: %mma_0_0_2
# CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0)
# CHECK-NEXT: %mma_0_1_2
# CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0)
# CHECK-NEXT: %mma_0_0_3
# CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2)
# CHECK-NEXT: %mma_1_0_2
# CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0)
# CHECK-NEXT: %mma_1_0_3
# CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2)
# CHECK-NEXT: %mma_0_1_3
# CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2)
# CHECK-NEXT: %mma_1_1_2
# CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1)
# CHECK-NEXT: %mma_1_1_3
# CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2)
# CHECK-NEXT: %write_0_0_0
# CHECK-NEXT: %write_1_1_0
# CHECK-NEXT: %write_1_0_0
# CHECK-NEXT: %write_0_1_0
# CHECK-NEXT: return None


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
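The %rotating_reg_* values in the expected pipelined_reduction subgraph carry results produced in one iteration of the modulo-scheduled loop into a later iteration. A minimal sketch of that rotating-buffer idea, assuming hypothetical produce/consume callables (plain Python, not the Wave scheduler):

from collections import deque

def run_with_rotating_regs(produce, consume, num_iters, distance=1):
    # Each produced value is consumed `distance` iterations after it is produced.
    rotating = deque()
    for i in range(distance):               # prologue: fill the rotating buffer
        rotating.append(produce(i))
    for i in range(distance, num_iters):    # steady-state kernel
        consume(rotating.popleft())         # value produced `distance` iterations ago
        rotating.append(produce(i))
    while rotating:                         # epilogue: drain in-flight values
        consume(rotating.popleft())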
3 changes: 3 additions & 0 deletions shark_turbine/kernel/_support/tracing.py
@@ -129,6 +129,9 @@ def __init__(self, region_graph: RegionGraph, root_graph: str):
def get_subgraph(self, name: str) -> fx.Graph:
return self.region_graph.subgraphs[name]

def add_subgraph(self, name: str, graph: fx.Graph):
self.region_graph.subgraphs[name] = graph

def get_root_graph(self) -> fx.Graph:
return self.get_subgraph(self.root_graph)

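The new add_subgraph hook is the counterpart of get_subgraph: it lets a transformation register an extra region, such as a pipelined loop body, on the captured trace. A minimal usage sketch, assuming a CapturedTrace named trace (the helper name is hypothetical):

import torch.fx as fx

def register_pipelined_body(trace, name="pipelined_reduction"):
    body = fx.Graph()               # fresh graph for the new loop body
    trace.add_subgraph(name, body)  # make it addressable on the trace
    assert trace.get_subgraph(name) is body
    return body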
18 changes: 17 additions & 1 deletion shark_turbine/kernel/ops/wave_ops.py
@@ -453,6 +453,8 @@ def index(self, value: Any):
self.fx_node.index = {}
for dim, key in value.items():
self.fx_node.index[dim] = key
elif isinstance(value, list):
self.fx_node.index = value
else:
raise ValueError("Index must be a dict")

@@ -691,7 +693,7 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool:
prev_node, found_src = prev_node.prev, prev_node == src
if not found_src:
return False
while next_node and not found_dst:
while next_node.next.op != "root" and not found_dst:
next_node, found_dst = next_node.next, next_node == dst
return found_dst
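The loop-condition change accounts for torch.fx storing nodes in a circular doubly linked list anchored by a sentinel node whose op is "root": node.next never becomes falsy, so the forward walk has to stop when it reaches the sentinel rather than waiting for None. A small sketch of a walk with that stopping condition (hypothetical helper, not the method above):

def nodes_after(node):
    # Collect every node between `node` and the end of its fx.Graph.
    following = []
    while node.next.op != "root":   # the sentinel closes the circular list
        node = node.next
        following.append(node)
    return following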

@@ -910,6 +912,20 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
else None
)

@index.setter
def index(self, value: Any):
CustomOp.index.fset(self, value)

@property
def count(self) -> int:
if hasattr(self.fx_node, "count"):
return self.fx_node.count
return None

@count.setter
def count(self, value: int):
self.fx_node.count = value


@define_op("write")
@dataclass
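The new count attribute mirrors the index pattern: the value is stashed directly on the underlying fx.Node, so it survives rebuilding the CustomOp wrapper, and it reads back as None until something sets it. A minimal round-trip sketch (the helper name and the meaning given to count are assumptions, e.g. how many times the scheduler repeats a node in the pipelined loop):

def tag_count(custom_op, value):
    custom_op.count = value            # stored as custom_op.fx_node.count
    assert custom_op.count == value    # read back through the property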