From 909411a5e47f41ff70368a1acba2de5cafea9440 Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Tue, 24 Sep 2024 16:45:27 -0700
Subject: [PATCH] [TKW] Fix indexing of Reduction and GetResult to enable post-tile op. (#162)

This PR introduces changes to handle elementwise or general arithmetic
operations that come after a tiled-loop-reduction ("Reduction") operation.

The main problem with the current stack is that the indexing_dims
information for Reduction relies on its users. That works when the
user/consumer is tkw.write, but other consumers such as BinaryPyOp or
UnaryPyOp lack that information. To make matters worse, BinaryPyOp/UnaryPyOp
depend on their src/producer for indexing dims, while the Reduction op
depends on its dst/consumer, which ends up causing an infinite loop between
UnaryPyOp/BinaryPyOp <-> Reduction.

This PR fixes the indexing dimension logic of Reduction and GetResult
(required for expanded Reduction) to be based on the reduction axis (for
Reduction) and on source/producer information.

---------

Signed-off-by: Stanley Winata
---
 lit_tests/kernel/wave/codegen.py       | 88 ++++++++++++++++++++++++++
 shark_turbine/kernel/ops/wave_ops.py   | 45 +++++++++----
 shark_turbine/kernel/wave/expansion.py |  5 ++
 3 files changed, 125 insertions(+), 13 deletions(-)

diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 0bba2384..b84cc271 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -756,6 +756,94 @@ def test(
     # CHECK: arith.addf {{.*}} : vector<1xf16>
 
 
+# This test is to ensure that the propagation of indexing_dims between reduction and operations
+# outside the reduction is working properly.
+@run_test
+def test_reduction_and_elemwise():
+    M = tkl.sym.M
+    N = tkl.sym.N
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+            vector_shapes={M: 1, N: BLOCK_N},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(N, N, 0)]
+    constraints += [tkw.TilingConstraint(N, BLOCK_N)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+    ):
+        init_max = tkl.Register[M, tkl.f16](-1e6)
+
+        @tkw.reduction(N, init_args=[init_max])
+        def repeat(
+            partial_max: tkl.Register[M, tkl.f16],
+        ) -> tkl.Register[M, tkl.f16]:
+            lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
+            partial_max = tkw.max(lhs, partial_max, dim=N)
+            return partial_max
+
+        result = repeat + repeat
+        tkw.write(result, c, elements_per_thread=1)
+
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    shape = (256, 512)
+    a = torch.randn(shape, dtype=torch.float16)
+    c = torch.zeros((shape[0],), dtype=torch.float16)
+    with tk.gen.TestLaunchContext(
+        {
+            M: shape[0],
+            N: shape[1],
+            BLOCK_M: 2,
+            BLOCK_N: 128,
+            ELEMS_PER_THREAD: 2,
+            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
+        },
+        canonicalize=True,
+    ):
+        print(test(a, c).module_op)
+        # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index
+        # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index
+        # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index
+        # CHECK-DAG: %[[INIT:.+]] = arith.constant dense<0xFC00> : vector<1xf16>
+
+        # Tile Reduction Loop
+        # CHECK: %[[TILED:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
+        # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT]], %[[ACC1:.+]] = %[[INIT]]) -> (vector<1xf16>, vector<1xf16>) {
+        # 1st Expanded Local Reduction
+        # CHECK: arith.maximumf {{.*}} : vector<1xf16>
+        # 1st Expanded Global Reduction
+        # CHECK-COUNT-6: gpu.shuffle xor
+        # 1st Expanded Accumulator Reduction
+        # CHECK: %[[ACC_REDUCE_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
+
+        # 2nd Expanded Local Reduction
+        # CHECK: arith.maximumf {{.*}} : vector<1xf16>
+        # 2nd Expanded Global Reduction
+        # CHECK-COUNT-6: gpu.shuffle xor
+        # 2nd Expanded Accumulator Reduction
+        # CHECK: %[[ACC_REDUCE_1:.+]] = arith.maximumf %[[ACC1]], %{{.*}}
+
+        # CHECK: scf.yield %[[ACC_REDUCE_0]], %[[ACC_REDUCE_1]] : vector<1xf16>, vector<1xf16>
+        # CHECK: %[[POST_TILE_ELEMWISE_0:.+]] = arith.addf %[[TILED]]#0, %[[TILED]]#0 : vector<1xf16>
+        # CHECK: %[[POST_TILE_ELEMWISE_1:.+]] = arith.addf %[[TILED]]#1, %[[TILED]]#1 : vector<1xf16>
+        # CHECK: vector.store %[[POST_TILE_ELEMWISE_0:.+]], %{{.*}}
+        # CHECK: vector.store %[[POST_TILE_ELEMWISE_1:.+]], %{{.*}}
+
+
 @run_test
 def test_tiled_reduce_max():
     M = tkl.sym.M
diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index 3a2d3d3b..ebadf0c4 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -861,12 +861,23 @@ def wrapper(f):
         return wrapper
 
     @property
-    def indexing_dims(self) -> list[IndexSymbol]:
+    def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]:
         expand_dims: list[IndexSymbol] = []
-        for user in self.users:
-            for indexing_dim in user.indexing_dims:
-                if indexing_dim not in expand_dims:
-                    expand_dims.append(indexing_dim)
+        return_node = [
+            nested_node
+            for nested_node in self.graph.subgraphs[self.subgraph_name].nodes
+            if isinstance(get_custom(nested_node), Output)
+        ]
+        assert len(return_node) == 1
+        return_vals = get_custom(return_node[0]).return_vals[0]
+        if not isinstance(return_vals, Sequence):
+            return_vals = [return_vals]
+        for return_val in return_vals:
+            return_dims = get_custom(return_val).indexing_dims
+            reduced_dims = [dims for dims in return_dims if dims != self.axis]
+            expand_dims.append(reduced_dims)
+        if len(expand_dims) == 1:
+            expand_dims = expand_dims[0]
         return expand_dims
 
     def iter_args(self, graph: fx.Graph) -> list[fx.Node]:
@@ -952,16 +963,24 @@ class GetResult(CustomOp):
 
     @property
     def type(self) -> "Memory":
-        return get_custom(self.value).type[self.res_idx]
+        src_type = get_custom(self.value).type
+        if isinstance(src_type, list):
+            return src_type[self.res_idx]
+        else:
+            return src_type
 
     @property
-    def indexing_dims(self) -> list[IndexSymbol]:
-        expand_dims: list[IndexSymbol] = []
-        for user in self.users:
-            for indexing_dim in user.indexing_dims:
-                if indexing_dim not in expand_dims:
-                    expand_dims.append(indexing_dim)
-        return expand_dims
+    def indexing_dims(self) -> list[IndexExpr]:
+        has_multiple_value = lambda x: all(isinstance(el, list) for el in x)
+        is_valid_indexing_dim = lambda x: isinstance(src_indexing, list) and all(
+            isinstance(el, IndexExpr) for el in x
+        )
+        src_indexing = get_custom(self.value).indexing_dims
+        if has_multiple_value(src_indexing):
+            assert self.res_idx <= len(src_indexing) - 1
+            src_indexing = src_indexing[self.res_idx]
+        assert is_valid_indexing_dim(src_indexing)
+        return src_indexing
 
     @property
     def index(self) -> dict[IndexSymbol, IndexSequence]:
diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py
index 53796682..2610f968 100644
--- a/shark_turbine/kernel/wave/expansion.py
+++ b/shark_turbine/kernel/wave/expansion.py
@@ -81,6 +81,11 @@ def get_indexed_dims(
     """
     if isinstance(nodeOrDims, CustomOp):
         nodeOrDims = nodeOrDims.indexing_dims
+    # Flatten dims for node with multiple values or expanded Reduction.
+    if all(isinstance(el, Sequence) for el in nodeOrDims):
+        flattened_dims = list(itertools.chain.from_iterable(nodeOrDims))
+        flatten_dims_set = dict.fromkeys(flattened_dims)
+        nodeOrDims = list(flatten_dims_set)
     return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims)
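
With this change, Reduction derives its indexing dims from the values returned by its subgraph (its producers), dropping the reduction axis and keeping one list per returned value, while GetResult simply selects the list matching its res_idx. The standalone sketch below restates that derivation outside the fx/TKW machinery; the helper names (toy_reduction_dims, toy_get_result_dims) and the use of plain strings instead of IndexSymbol are illustrative stand-ins, not part of the TKW API.

# Standalone illustration only: strings stand in for TKW's IndexSymbol,
# and these helpers are not part of the shark_turbine/TKW API.

def toy_reduction_dims(returned_dims, axis):
    # Mirrors the new Reduction.indexing_dims: derive dims from the values the
    # reduction subgraph returns (its producers), dropping the reduction axis,
    # instead of asking the reduction's consumers.
    if returned_dims and not isinstance(returned_dims[0], list):
        returned_dims = [returned_dims]  # single returned value
    expand_dims = [[d for d in dims if d != axis] for dims in returned_dims]
    # A single-result reduction collapses back to a flat list of dims.
    return expand_dims[0] if len(expand_dims) == 1 else expand_dims


def toy_get_result_dims(src_dims, res_idx):
    # Mirrors the new GetResult.indexing_dims: when the source (an expanded,
    # multi-result Reduction) carries one dim list per result, pick the list
    # for this result index; otherwise pass the dims through unchanged.
    if all(isinstance(el, list) for el in src_dims):
        assert res_idx < len(src_dims)
        return src_dims[res_idx]
    return src_dims


# Two results, each indexed by [M, N] and reduced over N -> [["M"], ["M"]].
dims = toy_reduction_dims([["M", "N"], ["M", "N"]], axis="N")
print(dims)                          # [['M'], ['M']]
print(toy_get_result_dims(dims, 1))  # ['M']
# A single result stays a flat list, which consumers such as tkw.write expect.
print(toy_reduction_dims(["M", "N"], axis="N"))  # ['M']

Deriving from the producer side is what breaks the BinaryPyOp/UnaryPyOp <-> Reduction cycle described in the commit message: a post-tile elementwise op keeps asking its producers for dims, and the Reduction now answers from its own subgraph instead of asking its consumers back.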
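
The expansion.py hunk then has to handle a node whose indexing_dims is a list of lists (a multi-result or expanded Reduction): it chains the inner lists and deduplicates them while preserving order via dict.fromkeys. A small sketch of just that step, again with strings standing in for IndexSymbol:

# Standalone illustration only: strings stand in for TKW's IndexSymbol.
import itertools
from collections.abc import Sequence


def flatten_indexing_dims(node_or_dims):
    # Mirrors the flattening added to get_indexed_dims: when every entry is
    # itself a sequence (one dim list per result), chain the inner lists and
    # deduplicate them while preserving first-seen order.
    # The real code only checks isinstance(el, Sequence); the extra str guard
    # here is needed solely because this toy version uses strings as symbols.
    if all(isinstance(el, Sequence) and not isinstance(el, str) for el in node_or_dims):
        flattened = itertools.chain.from_iterable(node_or_dims)
        # dict.fromkeys keeps insertion order, unlike a set.
        return list(dict.fromkeys(flattened))
    return list(node_or_dims)


print(flatten_indexing_dims([["M"], ["M"], ["M", "K"]]))  # ['M', 'K']
print(flatten_indexing_dims(["M", "K"]))                  # ['M', 'K'] (unchanged)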