[TKW] Scalarized local reduction for faster Max (#335)
Max reduction can make use of v_max3_f32, which is much faster than the
regular v_max_f32. However, to enable this we need to scalarize the
local reduction so that the LLVM compiler can apply this optimization
more effectively. (Drops latency from
4.5 ms to ~4.3 ms on dispatch146 (B0: 2, B1: 20, (M, K2): 1024, K1: 64).)
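
As a rough sketch of the idea (plain Python, not the TKW graph API; all names here are illustrative): max is associative and commutative, so folding each expanded source down to a scalar first and only then combining the scalars gives the same result as the old vector-first order, while leaving the backend chains of scalar maxes that it can pack into 3-operand v_max3_f32 instructions.

# Illustrative only: plain-Python model of the reassociation, assuming two
# expanded 16-lane sources. Nothing here is the actual TKW/fx machinery.
from functools import reduce

lhs = [float(i) for i in range(16)]
rhs = [float(15 - i) for i in range(16)]

# Old order: lane-wise max of the sources (a vector op), then one vector reduction.
old = reduce(max, [max(a, b) for a, b in zip(lhs, rhs)])

# New order: reduce each source to a scalar (chains of scalar maxes the backend
# can fuse into v_max3_f32), then combine the resulting scalars.
new = max(reduce(max, lhs), reduce(max, rhs))

assert old == new  # max is associative and commutative, so both orders agree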

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Dec 16, 2024
1 parent 91d8f59 commit 142c8a5
Showing 2 changed files with 91 additions and 16 deletions.
73 changes: 66 additions & 7 deletions iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -86,23 +86,77 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph):
def emit_sources_reduction(
    binary_fn: Callable, src: list[fx.Node], graph: fx.Graph
) -> fx.Node:
    """
    Reduces a list of fx.Node variables into one by applying binary_fn pairwise.
    """
    init = src[0]
    for i in range(1, len(src)):
        init = get_graph_node(binary_fn(init, src[i]), graph)
    init.index = src[0].index
    return init


def emit_local_reduction(
def emit_variable_reduction(
    binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int
) -> fx.Node:
    """
    Reduces a single vector-valued fx.Node variable into a scalar.
    """
    init = get_graph_node(Extract(src, [0]), graph)
    for i in range(1, local_reduction_size):
        cur_slice = get_graph_node(Extract(src, [i]), graph)
        init = get_graph_node(binary_fn(init, cur_slice), graph)
    return init


def emit_local_reduction(
    binary_fn: Callable,
    reduction_src: list[fx.Node],
    graph: fx.Graph,
    local_reduction_size,
):
    """
    Reduces all the elements carried by a ReductionOp at the local thread/SIMT
    level. This is done by first combining the expanded sources into a single
    variable, and then reducing that variable into a scalar.
    """
    src_reduction = emit_sources_reduction(binary_fn, reduction_src, graph)
    local_reduction = emit_variable_reduction(
        binary_fn, src_reduction, graph, local_reduction_size
    )
    return local_reduction
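
For orientation, a plain-Python stand-in for this default path (illustrative only, no fx graph involved): the expanded sources are combined lane-wise first, and the combined vector is then folded into a scalar.

# Reference model of the default local-reduction order (hypothetical helper,
# operating on plain lists instead of fx.Nodes).
def local_reduction_reference(binary_fn, sources):
    # Lane-wise combine of all expanded sources into one "vector".
    combined = sources[0]
    for src in sources[1:]:
        combined = [binary_fn(a, b) for a, b in zip(combined, src)]
    # Fold the combined vector down to a single scalar.
    acc = combined[0]
    for lane in combined[1:]:
        acc = binary_fn(acc, lane)
    return acc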


def emit_scalarized_local_reduction(
    binary_fn: Callable,
    reduction_src: list[fx.Node],
    graph: fx.Graph,
    local_reduction_size,
):
    """
    Special case of local reduction where we scalarize to get rid of most vector
    ops. This is useful for maximum, where it exposes more opportunities for
    v_max3_f32. We do this by first reducing each source in a scalar/iterative
    manner, and then reducing the already-reduced sources.
    e.g. we transform:
        %source_reduce = arith.maximumf %lhs, %rhs : vector<16xf32>
        %local_reduce = vector.reduction <maximumf>, %source_reduce : vector<16xf32> into f32
    into:
        %local_lhs_reduce = vector.reduction <maximumf>, %lhs : vector<16xf32> into f32
        %local_rhs_reduce = vector.reduction <maximumf>, %rhs : vector<16xf32> into f32
        %local_src_reduce = arith.maximumf %local_lhs_reduce, %local_rhs_reduce : f32
    """
    locally_reduced_sources = [
        emit_variable_reduction(binary_fn, arg, graph, local_reduction_size)
        for arg in reduction_src
    ]
    local_reduction = emit_sources_reduction(binary_fn, locally_reduced_sources, graph)
    return local_reduction
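
And a matching stand-in for the scalarized path (again illustrative, not the TKW API): each source is folded to a scalar on its own, and only the per-source scalars are then combined. For an associative, commutative op like maximum this computes the same value as the default path, but it emits per-source chains of scalar ops that map naturally onto v_max3_f32.

# Reference model of the scalarized local-reduction order (hypothetical helper).
def scalarized_local_reduction_reference(binary_fn, sources):
    scalars = []
    for src in sources:
        acc = src[0]
        for lane in src[1:]:
            acc = binary_fn(acc, lane)  # per-source scalar chain
        scalars.append(acc)
    # Combine the per-source scalars into the final result.
    result = scalars[0]
    for s in scalars[1:]:
        result = binary_fn(result, s)
    return result

# Both orderings agree for max, e.g.:
#   local_reduction_reference(max, [[1.0, 9.0, 3.0], [4.0, 2.0, 8.0]]) == 9.0
#   scalarized_local_reduction_reference(max, [[1.0, 9.0, 3.0], [4.0, 2.0, 8.0]]) == 9.0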


def emit_global_reduction(
    binary_fn: Callable,
    src: fx.Node,
@@ -111,6 +165,9 @@ def emit_global_reduction(
    cluster_size: int,
    cluster_stride: int,
) -> fx.Node:
    """
    Reduces data across the threads in a warp by doing a butterfly shuffle.
    """
    init = src
    num_steps = int(math.log2(float(cluster_size)))
    for _ in range(num_steps):
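
The shuffle loop body is not shown in this hunk; conceptually, each step combines every lane with its XOR partner at a doubling offset, so after log2(cluster_size) steps every lane holds the full reduction. A rough sketch (plain Python, hypothetical shuffle semantics, not the actual emitted ops):

# Conceptual model of a butterfly (xor) shuffle reduction over one cluster of
# lanes. `vals[i ^ offset]` plays the role of the hardware cross-lane shuffle.
def butterfly_reduce_reference(binary_fn, lane_values, cluster_size):
    vals = list(lane_values)
    assert len(vals) == cluster_size  # assumes cluster_size is a power of two
    offset = 1
    while offset < cluster_size:
        vals = [binary_fn(v, vals[i ^ offset]) for i, v in enumerate(vals)]
        offset *= 2
    return vals  # every lane now holds the same fully reduced value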
@@ -189,12 +246,14 @@ def decompose_reduce_ops(
                raise NotImplementedError(
                    "NYI: Expect all reduce_src to have same local reduce size."
                )
            src_reduction = emit_sources_reduction(
                binary_fn, reduction_src, custom.graph
            )
            local_reduction = emit_local_reduction(
                binary_fn, src_reduction, custom.graph, local_reduce_sizes[0]
            )
            if binary_fn == Maximum:
                local_reduction = emit_scalarized_local_reduction(
                    binary_fn, reduction_src, custom.graph, local_reduce_sizes[0]
                )
            else:
                local_reduction = emit_local_reduction(
                    binary_fn, reduction_src, custom.graph, local_reduce_sizes[0]
                )

            # Global Reduce
            cluster_size, cluster_stride = determine_shuffle_config(
34 changes: 25 additions & 9 deletions lit_tests/kernel/wave/attention.py
@@ -333,15 +333,16 @@ def repeat(
print(dynamic_attention_pipelined(q, k, v, output).module_op)

# CHECK-LABEL: func.func @dynamic_attention_pipelined
# CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}}
# CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}}
# CHECK-COUNT-14: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-7: {{.*}} = amdgpu.mfma
# CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-10: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-16: vector.maskedstore {{.*}}


@@ -461,12 +462,17 @@ def repeat(

# CHECK-LABEL: func.func @base_attention_pipelined
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-14: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-7: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-10: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}


@run_test
@@ -760,7 +766,17 @@ def repeat(
# CHECK-LABEL: func.func @base_attention_32x32x8
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}

# Test for reduction decomposition related to softmax.
# CHECK-NOT: arith.maximumf {{.*}}, {{.*}} : vector<16xf32>
# CHECK-COUNT-30: arith.maximumf {{.*}}, {{.*}} : vector<1xf32>
# CHECK: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: arith.maximumf {{.*}}, {{.*}} : vector<1xf32>
# CHECK: arith.addf {{.*}}, {{.*}} : vector<16xf32>
# CHECK-COUNT-14: arith.addf {{.*}}, {{.*}} : vector<1xf32>
# CHECK: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: arith.addf {{.*}}, {{.*}} : vector<1xf32>

# CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [0], sizes = [4], strides = [1]}
# CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [4], sizes = [4], strides = [1]}
# CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [8], sizes = [4], strides = [1]}
