Skip to content

Commit

Permalink
[TKW] Test multiple igemm layouts (#201)
Browse files Browse the repository at this point in the history
* Test both the `nchw_fchw` and `nhwc_hwcf` igemm conv layouts.
* Perf tests use only the `nhwc_hwcf` layout, since IREE seems to produce
its best results for it.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 authored Oct 7, 2024
1 parent 5ef7ff2 commit 21801d2
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 149 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ def compile_and_invoke(
if config.get("print_ir_after_all", False):
flags.append("--mlir-print-ir-after-all")

preprocessing_pipeline = config.get("iree_preprocessing_pass_pipeline", None)
if preprocessing_pipeline is not None:
flags.append(f"--iree-preprocessing-pass-pipeline={preprocessing_pipeline}")

if "dump_intermediates" in config:
intermediates_path = config.get("dump_intermediates")
flags.append(
Expand Down
12 changes: 10 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,22 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "perf_only: performance test, runs only with '--runperf'"
)
config.addinivalue_line(
"markers", "validate_only: validation test, never runs with '--runperf'"
)


def _has_marker(item, marker):
return next(item.iter_markers(marker), None) is not None


def pytest_collection_modifyitems(config, items):
run_perf = config.getoption("--runperf")
for item in items:
is_perf_only = next(item.iter_markers("perf_only"), None) is not None
is_validate_only = _has_marker(item, "validate_only")
is_perf_only = _has_marker(item, "perf_only")
if run_perf:
if not is_perf_only:
if not is_perf_only or is_validate_only:
item.add_marker(pytest.mark.skip("skip non-perf test"))
else:
if is_perf_only:
Expand Down
190 changes: 43 additions & 147 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,162 +657,36 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:


_igemm_cases = [
(4, 5, 5, 10, 2, 2, 16, 3),
(2, 5, 5, 10, 2, 2, 16, 3),
(1, 5, 5, 10, 2, 2, 16, 3),
(4, 5, 5, 4, 2, 2, 16, 3),
(1, 5, 5, 4, 2, 2, 16, 3),
(1, 5, 5, 3, 2, 2, 16, 3),
(2, 5, 5, 1, 2, 2, 16, 3),
(4, 5, 5, 10, 2, 2, 2, 3),
(2, 5, 5, 10, 2, 2, 2, 3),
(1, 5, 5, 10, 2, 2, 2, 3),
(4, 5, 5, 4, 2, 2, 2, 3),
(2, 5, 5, 4, 2, 2, 2, 3),
(1, 5, 5, 3, 2, 2, 2, 3),
(2, 5, 5, 1, 2, 2, 2, 3),
(1, 5, 5, 1, 2, 2, 2, 3),
(4, 5, 5, 10, 2, 2, 1, 3),
(2, 5, 5, 10, 2, 2, 1, 3),
(1, 5, 5, 10, 2, 2, 1, 3),
(4, 5, 5, 4, 2, 2, 1, 3),
(2, 5, 5, 4, 2, 2, 1, 3),
(1, 5, 5, 4, 2, 2, 1, 3),
(2, 5, 5, 3, 2, 2, 1, 3),
(4, 5, 5, 1, 2, 2, 1, 3),
(2, 5, 5, 1, 2, 2, 1, 3),
(1, 5, 5, 1, 2, 2, 1, 3),
(4, 5, 5, 10, 2, 2, 16, 2),
(2, 5, 5, 10, 2, 2, 16, 2),
(1, 5, 5, 10, 2, 2, 16, 2),
(4, 5, 5, 4, 2, 2, 16, 2),
(1, 5, 5, 4, 2, 2, 16, 2),
(4, 5, 5, 3, 2, 2, 16, 2),
(4, 5, 5, 1, 2, 2, 16, 2),
(1, 5, 5, 1, 2, 2, 16, 2),
(4, 5, 5, 10, 2, 2, 2, 2),
(2, 5, 5, 10, 2, 2, 2, 2),
(1, 5, 5, 10, 2, 2, 2, 2),
(4, 5, 5, 4, 2, 2, 2, 2),
(2, 5, 5, 4, 2, 2, 2, 2),
(2, 5, 5, 3, 2, 2, 2, 2),
(2, 5, 5, 1, 2, 2, 2, 2),
(1, 5, 5, 1, 2, 2, 2, 2),
(4, 5, 5, 10, 2, 2, 1, 2),
(2, 5, 5, 10, 2, 2, 1, 2),
(1, 5, 5, 10, 2, 2, 1, 2),
(4, 5, 5, 4, 2, 2, 1, 2),
(2, 5, 5, 4, 2, 2, 1, 2),
(1, 5, 5, 4, 2, 2, 1, 2),
(4, 5, 5, 1, 2, 2, 1, 2),
(1, 5, 5, 1, 2, 2, 1, 2),
(4, 5, 5, 10, 2, 2, 16, 1),
(2, 5, 5, 10, 2, 2, 16, 1),
(4, 5, 5, 4, 2, 2, 16, 1),
(2, 5, 5, 4, 2, 2, 16, 1),
(1, 5, 5, 4, 2, 2, 16, 1),
(4, 5, 5, 3, 2, 2, 16, 1),
(1, 5, 5, 3, 2, 2, 16, 1),
(2, 5, 5, 1, 2, 2, 16, 1),
(1, 5, 5, 1, 2, 2, 16, 1),
(2, 5, 5, 3, 2, 2, 1, 1),
(4, 5, 5, 10, 2, 2, 2, 1),
(2, 5, 5, 10, 2, 2, 1, 1),
(2, 5, 5, 10, 2, 2, 2, 1),
(1, 5, 5, 10, 2, 2, 2, 1),
(4, 5, 5, 4, 2, 2, 2, 1),
(2, 5, 5, 4, 2, 2, 2, 1),
(1, 5, 5, 10, 2, 2, 16, 1),
(1, 5, 5, 10, 2, 2, 1, 2),
(1, 5, 5, 4, 2, 2, 2, 1),
(1, 5, 5, 3, 2, 2, 2, 1),
(2, 5, 5, 1, 2, 2, 2, 1),
(1, 5, 5, 1, 2, 2, 2, 1),
(4, 5, 5, 10, 2, 2, 1, 1),
(2, 5, 5, 10, 2, 2, 1, 1),
(4, 5, 5, 4, 2, 2, 1, 1),
(2, 5, 5, 4, 2, 2, 1, 1),
(1, 5, 5, 4, 2, 2, 1, 1),
(2, 5, 5, 1, 2, 2, 1, 1),
(1, 5, 5, 1, 2, 2, 1, 1),
(4, 5, 5, 10, 2, 2, 16, 3),
(2, 5, 5, 10, 2, 2, 16, 3),
(1, 5, 5, 10, 2, 2, 16, 3),
(4, 5, 5, 4, 2, 2, 16, 3),
(2, 5, 5, 4, 2, 2, 16, 3),
(1, 5, 5, 4, 2, 2, 16, 3),
(4, 5, 5, 1, 2, 2, 16, 3),
(1, 5, 5, 1, 2, 2, 16, 3),
(4, 5, 5, 10, 2, 2, 2, 3),
(1, 5, 5, 10, 2, 2, 2, 3),
(2, 5, 5, 4, 2, 2, 2, 3),
(1, 5, 5, 4, 2, 2, 2, 3),
(2, 5, 5, 3, 2, 2, 2, 3),
(4, 5, 5, 1, 2, 2, 2, 3),
(2, 5, 5, 1, 2, 2, 2, 3),
(1, 5, 5, 1, 2, 2, 2, 3),
(4, 5, 5, 10, 2, 2, 1, 3),
(2, 5, 5, 10, 2, 2, 1, 3),
(1, 5, 5, 10, 2, 2, 1, 3),
(4, 5, 5, 4, 2, 2, 1, 3),
(2, 5, 5, 4, 2, 2, 1, 3),
(1, 5, 5, 4, 2, 2, 1, 3),
(4, 5, 5, 1, 2, 2, 1, 3),
(2, 5, 5, 1, 2, 2, 1, 3),
(1, 5, 5, 1, 2, 2, 1, 3),
(4, 5, 5, 10, 2, 2, 16, 2),
(2, 5, 5, 10, 2, 2, 16, 2),
(1, 5, 5, 10, 2, 2, 16, 2),
(4, 5, 5, 4, 2, 2, 16, 2),
(1, 5, 5, 4, 2, 2, 16, 2),
(4, 5, 5, 1, 2, 2, 16, 2),
(2, 5, 5, 1, 2, 2, 16, 2),
(4, 5, 5, 10, 2, 2, 2, 2),
(2, 5, 5, 10, 2, 2, 2, 2),
(1, 5, 5, 10, 2, 2, 2, 2),
(4, 5, 5, 4, 2, 2, 2, 2),
(2, 5, 5, 4, 2, 2, 2, 2),
(1, 5, 5, 4, 2, 2, 2, 2),
(1, 5, 5, 3, 2, 2, 2, 2),
(2, 5, 5, 1, 2, 2, 2, 2),
(1, 5, 5, 1, 2, 2, 2, 2),
(2, 5, 5, 10, 2, 2, 1, 2),
(1, 5, 5, 10, 2, 2, 1, 2),
(4, 5, 5, 4, 2, 2, 1, 2),
(2, 5, 5, 4, 2, 2, 1, 2),
(1, 5, 5, 4, 2, 2, 1, 2),
(1, 5, 5, 3, 2, 2, 1, 2),
(2, 5, 5, 1, 2, 2, 1, 2),
(1, 5, 5, 1, 2, 2, 1, 2),
(4, 5, 5, 10, 2, 2, 16, 1),
(2, 5, 5, 10, 2, 2, 16, 1),
(1, 5, 5, 10, 2, 2, 16, 1),
(4, 5, 5, 4, 2, 2, 16, 1),
(2, 5, 5, 4, 2, 2, 16, 1),
(1, 5, 5, 4, 2, 2, 16, 1),
(2, 5, 5, 3, 2, 2, 16, 1),
(1, 5, 5, 3, 2, 2, 16, 1),
(4, 5, 5, 1, 2, 2, 16, 1),
(1, 5, 5, 1, 2, 2, 16, 1),
(2, 5, 5, 4, 2, 2, 1, 3),
(2, 5, 5, 4, 2, 2, 2, 1),
(1, 5, 5, 10, 2, 2, 16, 3),
(4, 5, 5, 4, 2, 2, 16, 2),
(4, 5, 5, 10, 2, 2, 2, 1),
(1, 5, 5, 10, 2, 2, 2, 1),
(4, 5, 5, 3, 2, 2, 1, 1),
(4, 5, 5, 4, 2, 2, 2, 1),
(2, 5, 5, 4, 2, 2, 2, 1),
(1, 5, 5, 4, 2, 2, 2, 1),
(4, 5, 5, 3, 2, 2, 2, 1),
(4, 5, 5, 1, 2, 2, 2, 1),
(2, 5, 5, 1, 2, 2, 2, 1),
(1, 5, 5, 1, 2, 2, 2, 1),
(4, 5, 5, 10, 2, 2, 1, 1),
(2, 5, 5, 10, 2, 2, 1, 1),
(4, 5, 5, 4, 2, 2, 1, 1),
(2, 5, 5, 4, 2, 2, 1, 1),
(1, 5, 5, 4, 2, 2, 1, 1),
(4, 5, 5, 3, 2, 2, 1, 1),
(2, 5, 5, 3, 2, 2, 1, 1),
(1, 5, 5, 3, 2, 2, 1, 1),
(2, 5, 5, 1, 2, 2, 1, 1),
(1, 5, 5, 1, 2, 2, 1, 1),
(2, 5, 5, 1, 2, 2, 1, 3),
(2, 5, 5, 4, 2, 2, 2, 1),
(2, 5, 5, 10, 2, 2, 16, 1),
(1, 5, 5, 1, 3, 3, 1, 1),
]

# Wrap case arguments in a pytest.param carrying the `perf_only` marker
# (registered in conftest as: runs only with '--runperf').
perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only)
# Wrap case arguments in a pytest.param carrying the `validate_only` marker
# (registered in conftest as: never runs with '--runperf').
validation_test = lambda *a: pytest.param(*a, marks=pytest.mark.validate_only)

_igemm_cases += [
perf_test(2, 128, 128, 16, 3, 3, 320, 1),
Expand All @@ -838,15 +712,21 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
]

_mem_spaces = [
pytest.param(GLOBAL_ADDRESS_SPACE, id="global"),
pytest.param(GLOBAL_ADDRESS_SPACE, id="global", marks=pytest.mark.validate_only),
pytest.param(SHARED_ADDRESS_SPACE, id="shared"),
]

# Convolution layouts exercised by test_igemm_conv.
# `nchw_fchw` is marked validate_only, so it is skipped under '--runperf';
# perf runs therefore cover only `nhwc_hwcf`.
_layouts = [
pytest.param("nchw_fchw", marks=pytest.mark.validate_only),
pytest.param("nhwc_hwcf"),
]


@require_e2e
@pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases)
@pytest.mark.parametrize("mem_space", _mem_spaces)
def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, request):
@pytest.mark.parametrize("layout", _layouts)
def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
cf = c
padding = 0 # TODO: only pad=0 is supported for now

Expand Down Expand Up @@ -907,6 +787,20 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, request):
# Other hyperparameters
ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

if layout == "nchw_fchw":
x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16]
we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16]
out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32]
elif layout == "nhwc_hwcf":
x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16]
we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16]
out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32]
x = torch.permute(x, (0, 2, 3, 1)).contiguous()
we = torch.permute(we, (2, 3, 1, 0)).contiguous()
out_ref = torch.permute(out_ref, (0, 2, 3, 1)).contiguous()
else:
raise ValueError(f"Invalid layout: {layout}")

# Expose user-constraints
constraints: list[tkw.Constraint] = []
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
Expand All @@ -924,9 +818,9 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, request):

@tkw.wave(constraints)
def conv(
x: tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16],
we: tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16],
out: tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32],
x: x_type,
we: we_type,
out: out_type,
):
c_reg = tkl.Register[M, NF, tkl.f32](0.0)

Expand All @@ -949,8 +843,6 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD
)

out = torch.zeros_like(out_ref)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

run_bench = request.config.getoption("--runperf")
Expand Down Expand Up @@ -984,6 +876,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
run_bench=run_bench,
run_config=config,
):
out = torch.zeros_like(out_ref)
conv(x, we, out)
assert_allclose(out, out_ref, rtol=1e-03, atol=1e-03)

Expand All @@ -993,9 +886,12 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
dump_perf, "iree_" + perf_filename
)

config[
"iree_preprocessing_pass_pipeline"
] = "builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)"
iree_ref = torch.zeros_like(out_ref)
generate_iree_ref(
"conv_2d_nchw_fchw",
"conv_2d_" + layout,
[x, we],
[iree_ref],
config,
Expand Down

0 comments on commit 21801d2

Please sign in to comment.