diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index ca80f7d4..b1fbc7ff 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -102,6 +102,10 @@ def exp2(src: "Register") -> "Register":
     ...


+def log2(src: "Register") -> "Register":
+    ...
+
+
 def reciprocal(src: "Register") -> "Register":
     ...

@@ -653,6 +657,7 @@ def infer_type(self):
         self.type = broadcasted_type


+@define_interface_op("log2")
 @define_interface_op("exp2")
 @define_interface_op("reciprocal")
 @define_interface_op("abs")
diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index d6507a41..a719329d 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -14,6 +14,8 @@
 import torch.utils._pytree as pytree
 from collections import namedtuple

+from .symbolic_constraints import SymbolicAlias
+
 from ..compiler.ir import (
     Attribute,
     DenseElementsAttr,
@@ -55,6 +57,7 @@
     read,
     reduction,
     exp2,
+    log2,
     reciprocal,
     abs,
     maximum,
@@ -491,7 +494,7 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
         shape, dtype, value = node.args
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
-    get_thread_shape = lambda index: max(x.size for x in index.values())
+    get_thread_shape = lambda index: max(subs_idxc(x.size) for x in index.values())
     shape = [get_thread_shape(get_custom(node).index)]
     vector_shape = cast_py_literal(emitter, shape)
     element_type = IrType.parse(dtype.ir_type_asm())
@@ -1120,6 +1123,16 @@ def handle_exp2(source: Value) -> OpResult:
     return result


+@handle_unary_op(log2)
+def handle_log2(source: Value) -> OpResult:
+    element_type = get_type_or_element_type(source.type)
+    if _is_float_type(element_type):
+        result = math_d.log2(source)
+    else:
+        raise ValidationError(f"Found unhandled operand type for log2: {element_type}")
+    return result
+
+
 @handle_unary_op(reciprocal)
 def handle_reciprocal(source: Value) -> OpResult:
     element_type = get_type_or_element_type(source.type)
@@ -1351,6 +1364,8 @@ def handle_get_result(emitter: WaveEmitter, node: fx.Node):

 @handle_op(operator.getitem)
 def handle_getitem(emitter: WaveEmitter, node: fx.Node):
+    if not node.users:
+        return
     raise NotImplementedError("getitem: Currently only stub implementation")


diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index 346d7529..e3926433 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -9,6 +9,8 @@
 from typing import Any, TypeAlias, Sequence, Type, Callable
 from functools import partial

+from .symbolic_constraints import SymbolicAlias
+
 from .constraints import (
     Constraint,
     HardwareConstraint,
@@ -17,6 +19,7 @@
 )
 from ..ops.wave_ops import (
     Allocate,
+    BinaryPyOp,
     CustomOp,
     GetResult,
     Getitem,
@@ -647,11 +650,19 @@ def get_dim_scaling(
     if len(hardware_constraints) != 1:
         raise ValueError("Exactly one hardware constraint must be provided")

+    aliased_dims: list[IndexSymbol] = [
+        constraint.source
+        for constraint in constraints
+        if isinstance(constraint, SymbolicAlias)
+    ]
+
     idxc = IndexingContext.current()
     for constraint in constraints:
         if isinstance(constraint, WorkgroupConstraint) or isinstance(
             constraint, TilingConstraint
         ):
+            if constraint.dim in aliased_dims:
+                continue
             hw_cons = hardware_constraints[0]
             tile_size = idxc.get_static_value(constraint.tile_size)
             if constraint.dim not in node.vector_shapes:
@@ -740,7 +751,6 @@ def _handle_reduction_dim(
 ):
     # Rediscover iter args
     # TODO: 
Register iter args with the reduction initially so accessing them is easier - iter_args: list[CustomOp] = [] reduction_subgraph = trace.get_subgraph(reduction.subgraph_name) # TODO: Handle case where MMAs/ReduceOps do not have Output as direct consumer. diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index dbbd5330..0b49f68b 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -19,9 +19,15 @@ IterArg, CustomOp, Reshape, + Output, +) +from .constraints import ( + Constraint, + HardwareConstraint, + WorkgroupConstraint, ) -from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint from .assumptions import Assumption +from .symbolic_constraints import SymbolicAlias from .._support.tracing import CapturedTrace, IndexingContext from .._support.indexing import IndexSymbol, IndexSequence from ..lang.global_symbols import * @@ -34,7 +40,6 @@ get_inputs, get_users, get_largest_index_and_size, - capture_backward_slice, ) import torch.fx as fx import numpy as np @@ -57,6 +62,20 @@ def get_vector_shape( return vector_shapes +def _get_symbolic_shape_and_vector_shapes( + custom: CustomOp, + aliases: dict[IndexSymbol, SymbolicAlias], + hw_constraint: HardwareConstraint, +): + # When the memory type has symbolic aliases, use the memory type + # as it includes the aliased variables. + symbolic_shape = custom.register_type.symbolic_shape + vector_shapes = custom.vector_shapes + if any([x in custom.memory_type.symbolic_shape for x in aliases]): + symbolic_shape = custom.memory_type.symbolic_shape + return symbolic_shape, vector_shapes + + def partition_strided_operators(trace: CapturedTrace, constraints: list[Constraint]): """ This function analyzes the index sequence of operators in the graph @@ -92,6 +111,7 @@ def has_strided_access(node: fx.Node) -> bool: strided_operators = trace.walk(has_strided_access) hw_constraint = [c for c in constraints if isinstance(c, HardwareConstraint)][0] + aliases = {c.source: c for c in constraints if isinstance(c, SymbolicAlias)} for operator in strided_operators: custom = get_custom(operator) simplified_index = { @@ -99,9 +119,11 @@ def has_strided_access(node: fx.Node) -> bool: for dim in custom.index } - shape = get_vector_shape( - custom.vector_shapes, custom.register_type.symbolic_shape + symbolic_shape, vector_shapes = _get_symbolic_shape_and_vector_shapes( + custom, aliases, hw_constraint ) + + shape = get_vector_shape(vector_shapes, symbolic_shape) elements_per_thread = subs_idxc(custom.elements_per_thread) max_stride_dim, max_stride = max( [(dim, seq.stride) for dim, seq in simplified_index.items()], @@ -126,7 +148,7 @@ def has_strided_access(node: fx.Node) -> bool: dim: IndexSequence( simplified_index[dim].start.subs({GPR_NUM: 0}) + offset[j], 1, 1 ) - for j, dim in enumerate(custom.register_type.symbolic_shape) + for j, dim in enumerate(symbolic_shape) } ops_to_combine.append(write) @@ -312,12 +334,30 @@ def set_derived_index(trace): worklist.append((inp, new_index)) +def verify_nodes(trace: CapturedTrace): + """ + Verify that all the valid nodes have their index and vector shapes set. 
+    """
+    nodes = trace.walk(lambda x: x)
+    for node in nodes:
+        custom = get_custom(node)
+        if isinstance(custom, (Placeholder, Allocate)) and not isinstance(
+            custom, IterArg
+        ):
+            continue
+        if isinstance(custom, (Output, Reduction)):
+            continue
+        assert custom.index, f"Index not set for node {custom.fx_node}"
+        assert custom.vector_shapes, f"Vector shapes not set for node {custom.fx_node}"
+
+
 def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
     mma_index = get_mma_dimensional_mapping(trace, get_hardware_constraint(constraints))
     trace.walk(partial(set_thread_independent_index, constraints))
     set_thread_dependent_index(constraints, mma_index, trace)
     set_derived_index(trace)
     resolve_thread_shapes(trace, constraints)
+    verify_nodes(trace)


 def compute_stride(
@@ -417,8 +457,11 @@ def set_thread_independent_index(
     if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
         return

+    hw_cons = get_hardware_constraint(constraints)
     constraints = [
-        c for c in constraints if not isinstance(c, (HardwareConstraint, Assumption))
+        c
+        for c in constraints
+        if not isinstance(c, (HardwareConstraint, Assumption, SymbolicAlias))
     ]

     index = {}
@@ -577,11 +620,21 @@ def should_update_index(
     source: CustomOp,
     source_index: dict[IndexSymbol, IndexSequence],
     source_vector_shapes: dict[IndexSymbol, int],
+    symbolic_constraints: list[SymbolicAlias],
 ):
+    # Get symbolic shape without any aliased variables.
+    aliased_dims = [x.source for x in symbolic_constraints]
+    symbolic_shape = set(source.type.symbolic_shape).difference(aliased_dims)
+
+    # If the source vector shapes do not contain all of the source indexing
+    # dimensions, we should not update the index.
+    if not set(symbolic_shape).issubset(set(source_vector_shapes.keys())):
+        return False
+
     # Determine if we should update the idx based on the source.
     # We update the source only if the source index provides
     # information about all the non-batch dimensions of the source.
-    non_batch_dims = [x for x in source.indexing_dims if source_vector_shapes[x] > 1]
+    non_batch_dims = [x for x in symbolic_shape if source_vector_shapes[x] > 1]

     # If the source index is smaller than the non-batch dims, check if the
     # source index is a subset of the non-batch dims.
@@ -595,12 +648,28 @@
     return True


+def append_aliased_shapes(source: CustomOp, symbolic_constraints: list[SymbolicAlias]):
+    """
+    Append the aliased shapes to the vector shapes of the source, if they
+    are present in the source index.
+ """ + for constraint in symbolic_constraints: + if ( + constraint.target in source.vector_shapes + and constraint.source in source.index + ): + source.vector_shapes[constraint.source] = constraint.apply( + source.vector_shapes[constraint.target] + ) + + def propagate_index( node: CustomOp, hardware_constraint: HardwareConstraint, workgroup_constraints: list[WorkgroupConstraint], mma_index: dict[MMA, dict[IndexSymbol, int]], visited: set[CustomOp], + symbolic_constraints: list[SymbolicAlias], ): """ Propagate the index and vector shapes through the graph @@ -619,11 +688,14 @@ def propagate_index( if source in visited: continue if not isinstance(source, (Reduction, MMA)): - if not should_update_index(source, source_index, source_vector_shapes): + if not should_update_index( + source, source_index, source_vector_shapes, symbolic_constraints + ): continue source_index = source.transform_index(source_index) source.index = combine_indices(source.index, source_index) source.vector_shapes = source_vector_shapes + append_aliased_shapes(source, symbolic_constraints) visited.add(source) for func in [get_inputs, get_users]: sources, reduction = add_nodes_to_sources( @@ -656,11 +728,17 @@ def set_thread_dependent_index( workgroup_constraints = [ c for c in constraints if isinstance(c, WorkgroupConstraint) ] + symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)] for source in sources: visited = visited.union(set([x for x in sources])) visited.remove(source) visited = propagate_index( - source, hardware_constraint, workgroup_constraints, mma_index, visited + source, + hardware_constraint, + workgroup_constraints, + mma_index, + visited, + symbolic_constraints, ) @@ -696,7 +774,7 @@ def create_broadcast( binary_op.graph ) custom = get_custom(broadcasted) - custom.vector_shapes = to_broadcast.vector_shapes + custom.vector_shapes = binary_op.vector_shapes custom.index = deepcopy(target_node.index) custom.index[broadcast_dim].size = broadcast_size broadcast_idx = list(binary_op.node_args.values()).index(to_broadcast) @@ -712,14 +790,21 @@ def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]): Currently, the only mismatches that can be resolved are when one of the shapes is 1 and the other is > 1. """ + + def get_index(custom: CustomOp): + if isinstance(custom, MMA): + return custom.acc.index + return custom.index + binary_ops = trace.walk(lambda node: isinstance(get_custom(node), BinaryPyOp)) for binary_op in binary_ops: custom = get_custom(binary_op) # Get the largest dim and shape from the lhs and rhs. lhs = get_custom(custom.lhs) rhs = get_custom(custom.rhs) - lhs_dim, lhs_size = get_largest_index_and_size(lhs.index) - rhs_dim, rhs_size = get_largest_index_and_size(rhs.index) + + lhs_dim, lhs_size = get_largest_index_and_size(get_index(lhs)) + rhs_dim, rhs_size = get_largest_index_and_size(get_index(rhs)) # If they are equal we are done. if lhs_dim == rhs_dim and lhs_size == rhs_size: @@ -737,17 +822,17 @@ def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]): to_broadcast = rhs if broadcast_rhs else lhs broadcast_dim = lhs_dim if broadcast_rhs else rhs_dim broadcast_size = lhs_size if broadcast_rhs else rhs_size - target = lhs if broadcast_rhs else rhs + broadcasted = lhs if broadcast_rhs else rhs if lhs_dim != rhs_dim: # If the dimensions don't agree, we can still do this broadcast only if # the two nodes differ in shape along the broadcasting dimension and the # broadcasting dimension is the innermost dimension. 
-            missing_dims = set(target.indexing_dims).difference(
-                set(to_broadcast.indexing_dims)
+            missing_dims = set(broadcasted.type.symbolic_shape).difference(
+                set(to_broadcast.type.symbolic_shape)
             )
             is_only_missing_dim = missing_dims == {broadcast_dim}
-            is_innermost_dim = broadcast_dim == target.indexing_dims[-1]
+            is_innermost_dim = broadcast_dim == broadcasted.type.symbolic_shape[-1]

             if not is_only_missing_dim and not is_innermost_dim:
                 raise NotImplementedError(
@@ -755,4 +840,6 @@ def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]):
             )

         # Broadcast
-        create_broadcast(custom, to_broadcast, broadcast_dim, broadcast_size, target)
+        create_broadcast(
+            custom, to_broadcast, broadcast_dim, broadcast_size, broadcasted
+        )
diff --git a/iree/turbine/kernel/wave/symbolic_constraints.py b/iree/turbine/kernel/wave/symbolic_constraints.py
new file mode 100644
index 00000000..f0f97673
--- /dev/null
+++ b/iree/turbine/kernel/wave/symbolic_constraints.py
@@ -0,0 +1,70 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+from iree.turbine.kernel._support.indexing import IndexExpr, IndexSymbol
+from dataclasses import dataclass
+from typing import Callable
+from .utils import subs_idxc
+from .constraints import (
+    Constraint,
+    WorkgroupConstraint,
+    WaveConstraint,
+    TilingConstraint,
+)
+
+
+@dataclass
+class SymbolicAlias:
+    """
+    A constraint of the form `SymbolicAlias(K, SYMBOLIC_K, fn)` specifies
+    that the source symbol is computed from the target symbol as
+    source = source_to_target(target).
+
+    SymbolicAliases are modeled in the compiler as additional workgroup, wave,
+    and tiling constraints on the source that are derived from the target's
+    constraints. They are ignored during expansion and reuse the same
+    workgroup and wave ids as the target symbol.
+    """
+
+    source: IndexSymbol | IndexExpr
+    target: IndexSymbol | IndexExpr
+    source_to_target: Callable[[IndexSymbol | IndexExpr], IndexSymbol | IndexExpr]
+
+    def apply(self, target: IndexSymbol | IndexExpr) -> IndexSymbol | IndexExpr:
+        return subs_idxc(self.source_to_target(target))
+
+    def create_new_constraints(self, constraints: list[Constraint]) -> list[Constraint]:
+        """
+        Creates new source-symbol constraints from the given target-symbol
+        constraints, substituting tile sizes through the indexing context.
+        """
+ + """ + new_constraints = [] + if not constraints: + return new_constraints + match constraints[0]: + case WorkgroupConstraint(): + build_constraint = lambda x, y, z: WorkgroupConstraint(x, y, z) + id_fn = lambda x: x.workgroup_dim + case WaveConstraint(): + build_constraint = lambda x, y, z: WaveConstraint(x, y, z) + id_fn = lambda x: x.wave_id + case TilingConstraint(): + build_constraint = lambda x, y, z: TilingConstraint(x, y, z) + id_fn = lambda x: x.induction_var + for constraint in constraints: + if self.target == constraint.dim: + tile_size = self.apply(constraint.tile_size) + if tile_size.is_number and tile_size == 0: + continue + new_constraints.append( + build_constraint( + self.source, + self.apply(constraint.tile_size), + id_fn(constraint), + ) + ) + return new_constraints diff --git a/iree/turbine/kernel/wave/templates/decode_attention.py b/iree/turbine/kernel/wave/templates/decode_attention.py new file mode 100644 index 00000000..3cfd4551 --- /dev/null +++ b/iree/turbine/kernel/wave/templates/decode_attention.py @@ -0,0 +1,201 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel._support.dtype import DataType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from ..symbolic_constraints import SymbolicAlias +import sympy +from enum import Enum +import math + + +def get_decode_attention_kernels( + shape: tuple[int], + mfma_variant: MMAType, +): + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + U = tkl.sym.U + BLOCK_U = tkl.sym.BLOCK_U + + class Phase(Enum): + PHASE_0 = (0,) + PHASE_1 = (1,) + + M_WAVES = 2 + N_WAVES = 2 + K_WAVES = 2 + THREADS_PER_WAVE = 64 + PHASE_1_BLOCK_M = 128 + PHASE_1_ELEMS_PER_THREAD = PHASE_1_BLOCK_M // THREADS_PER_WAVE + PHASE_1_BLOCK_N = 1 + + def phase_0_constraints(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / M_WAVES)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / N_WAVES)] + constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 2)] + constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / K_WAVES)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 3)] + constraints += [ + SymbolicAlias(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES))) + ] + vector_shapes = {B: 0} + waves_per_block = (M_WAVES, N_WAVES, K_WAVES) + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=THREADS_PER_WAVE, + waves_per_block=waves_per_block, + mma_type=mfma_variant, + vector_shapes=vector_shapes, + ) + ] + return constraints + + def phase_1_constraints() -> list[tkw.Constraint]: + constraints: list[tkw.Constraint] = 
[tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(U, BLOCK_U)] + vector_shapes = { + B: 0, + M: BLOCK_M, + N: BLOCK_N, + U: 1, + } + waves_per_block = (1, 1, 1) + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=THREADS_PER_WAVE, + waves_per_block=waves_per_block, + mma_type=mfma_variant, + vector_shapes=vector_shapes, + ) + ] + return constraints + + def get_constraints(phase: Phase) -> list[tkw.Constraint]: + if phase == Phase.PHASE_0: + return phase_0_constraints() + else: + return phase_1_constraints() + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, + inputs={B: i, N: j, M: k}, + outputs={B: i, N: j, M: k}, + ) + + @tkw.wave(get_constraints(Phase.PHASE_0)) + def phase_0( + q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + output: tkl.Memory[U, B, N, M, GLOBAL_ADDRESS_SPACE, tkl.f32], + output_max: tkl.Memory[U, B, M, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + new_acc = tkl.Register[B, N, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(k_reg, q_reg, c_reg) + x_j = tkw.permute(acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, init_max, dim=K2) + e_delta_max = tkw.exp2(init_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = init_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(v_reg, imm_f16, new_acc) + res = acc / d_j + dm_j = m_j + tkw.log2(d_j) + tkw.write(dm_j, output_max, elements_per_thread=1) + tkw.write(res, output, elements_per_thread=STORE_ELEMS_PER_THREAD) + + @tkw.wave(get_constraints(Phase.PHASE_1)) + def phase_1( + logits: tkl.Memory[U, B, N, M, GLOBAL_ADDRESS_SPACE, tkl.f32], + logits_max: tkl.Memory[U, B, M, GLOBAL_ADDRESS_SPACE, tkl.f32], + output: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + @tkw.reduction(U, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ): + x_j = tkw.read(logits, elements_per_thread=PHASE_1_ELEMS_PER_THREAD) + xm_j = tkw.read(logits_max, elements_per_thread=PHASE_1_ELEMS_PER_THREAD) + m_j = tkw.maximum(xm_j, partial_max) + old_scale = tkw.exp2(partial_max - m_j) + new_scale = tkw.exp2(xm_j - m_j) + d_j = partial_sum * old_scale + new_scale + new_acc = acc * old_scale + term = new_scale * x_j + new_acc = new_acc + term + return m_j, d_j, new_acc + + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write( + res, output, mapping=mapping, elements_per_thread=PHASE_1_ELEMS_PER_THREAD + ) + + symbols_0 = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 
get_mfma_load_elems_per_thread(mfma_variant),
+        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
+        BLOCK_B: 1,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
+        BLOCK_K2: 32,
+        BLOCK_U: 1,
+        B: shape[0],
+        M: shape[1],
+        N: shape[2],
+        K1: shape[3],
+        K2: shape[4],
+    }
+    symbols_0[U] = math.ceil(symbols_0[K2] / (symbols_0[BLOCK_K2] / K_WAVES))
+    symbols_1 = dict(symbols_0)
+    symbols_1[BLOCK_M] = PHASE_1_BLOCK_M
+    symbols_1[BLOCK_N] = PHASE_1_BLOCK_N
+
+    return phase_0, phase_1, symbols_0, symbols_1
diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py
index b3e695ef..5fb4ff30 100644
--- a/iree/turbine/kernel/wave/utils.py
+++ b/iree/turbine/kernel/wave/utils.py
@@ -871,7 +871,8 @@ def bfs(


 def capture_forward_slice(
-    node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True
+    node: fx.Node,
+    filter_fn: Callable[[fx.Node], bool] = lambda x: True,
 ) -> set[fx.Node]:
     """
     Run BFS on the graph to capture the forward slice of a node.
diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py
index 1cd52ef4..51b1d051 100644
--- a/iree/turbine/kernel/wave/wave.py
+++ b/iree/turbine/kernel/wave/wave.py
@@ -8,6 +8,8 @@
 import torch.fx as fx
 import inspect

+from .symbolic_constraints import SymbolicAlias
+
 from ..compiler import builder, dispatch_codegen, kernel_codegen, host_codegen
 from ..compiler.ir import Context, Operation
 from .codegen import WaveEmitter
@@ -132,6 +134,14 @@ def hardware_constraints(self) -> list[HardwareConstraint]:
             if isinstance(constraint, HardwareConstraint)
         ]

+    @property
+    def symbolic_constraints(self) -> list[SymbolicAlias]:
+        return [
+            constraint
+            for constraint in self.constraints
+            if isinstance(constraint, SymbolicAlias)
+        ]
+
     def _trace(self) -> CapturedTrace:
         region_graph = KernelRegionGraph()
         with CompiledContext(region_graph, grid_type=self.grid_type) as context:
@@ -212,13 +222,43 @@ def initialize_reductions(self, trace: CapturedTrace) -> None:
             if tiling_constraint.dim == get_custom(reduction).axis:
                 reduction.count = subs_idxc(tiling_constraint.count)

+    def get_workgroup_dims(self) -> dict[int, WorkgroupConstraint]:
+        """
+        Returns the non-aliased workgroup constraints, keyed by workgroup dim.
+        """
+        # Ignore aliased variables. They will be handled separately.
+        aliased_dims = [
+            x.source for x in self.constraints if isinstance(x, SymbolicAlias)
+        ]
+        workgroup_dims = {
+            x.workgroup_dim: x
+            for x in self.workgroup_constraints
+            if x.dim not in aliased_dims
+        }
+        return workgroup_dims
+
+    def update_aliased_workgroup_constraints(
+        self, workgroup_dims: dict[int, WorkgroupConstraint]
+    ) -> None:
+        """
+        This function updates the wg_dim for aliased workgroup constraints.
+        """
+        aliased_dims = [
+            x.source for x in self.constraints if isinstance(x, SymbolicAlias)
+        ]
+        # Update the workgroup constraints for aliased sources.
+        for constraint in self.workgroup_constraints:
+            if constraint.dim in aliased_dims:
+                constraint.wg_dim = workgroup_dims[constraint.workgroup_dim].wg_dim
+
     def initialize_workgroup_constraints(self, trace: CapturedTrace) -> None:
         """
         For kernels that distribute more than three dimensions among workgroups,
         we need to update the workgroup constraints for dimensions >= 2
         with the appropriate workgroup index.
""" - workgroup_dims = {x.workgroup_dim: x for x in self.workgroup_constraints} + + workgroup_dims = self.get_workgroup_dims() if all(x <= 2 for x in workgroup_dims.keys()): return shape = [ @@ -228,6 +268,42 @@ def initialize_workgroup_constraints(self, trace: CapturedTrace) -> None: new_workgroup_dims = delinearize_index(WORKGROUP_2, shape) for i in range(2, max(workgroup_dims.keys()) + 1): workgroup_dims[i].wg_dim = new_workgroup_dims[i - 2] + self.update_aliased_workgroup_constraints(workgroup_dims) + + def initialize_symbolic_constraints(self, trace: CapturedTrace) -> None: + """ + For each symbolic constraint, create new constraints for the + related symbolic values with appropriate substitutions. + """ + new_wg_constraints, new_wave_constraints, new_tiling_constraints = [], [], [] + for symbolic_constraint in self.symbolic_constraints: + new_wg_constraints += symbolic_constraint.create_new_constraints( + self.workgroup_constraints + ) + new_wave_constraints += symbolic_constraint.create_new_constraints( + self.wave_constraints + ) + new_tiling_constraints += symbolic_constraint.create_new_constraints( + self.tiling_constraints + ) + # Remove wave constraints with same tile size as workgroup constraints + for wave_constraint in new_wave_constraints: + for workgroup_constraint in new_wg_constraints: + if ( + wave_constraint.dim == workgroup_constraint.dim + and wave_constraint.tile_size == workgroup_constraint.tile_size + ): + new_wave_constraints.remove(wave_constraint) + self.constraints += ( + new_wg_constraints + new_wave_constraints + new_tiling_constraints + ) + idxc = IndexingContext.current() + for constraint in self.symbolic_constraints: + if subs_idxc(constraint.target).is_number: + idxc._bind_symbol( + constraint.source, + subs_idxc(constraint.source_to_target(constraint.target)), + ) def _trace_and_get_kernel_signature( self, @@ -242,6 +318,7 @@ def _trace_and_get_kernel_signature( self.create_induction_vars(graph) self.initialize_wave_constraints(graph) self.initialize_reductions(graph) + self.initialize_symbolic_constraints(graph) self.initialize_workgroup_constraints(graph) idxc = IndexingContext.current() @@ -307,7 +384,10 @@ def _trace_and_get_kernel_signature( # Determine grid shape. 
self.grid_type.dims = [1, 1, 1] max_workgroup_dim = 2 + aliases = [x.source for x in self.constraints if isinstance(x, SymbolicAlias)] for constraint in self.workgroup_constraints: + if constraint.dim in aliases: + continue dim = ( constraint.workgroup_dim if constraint.workgroup_dim < max_workgroup_dim diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index f6271d38..a51cbe46 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -6,13 +6,18 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.symbolic_constraints import SymbolicAlias from iree.turbine.kernel.wave.utils import ( run_test, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, ) +from iree.turbine.kernel.wave.templates.decode_attention import ( + get_decode_attention_kernels, +) import torch -from enum import Enum +import sympy +import math # Input sizes B = tkl.sym.B @@ -479,174 +484,70 @@ def repeat( def test_flash_decoding(): shape = (8, 128, 128, 64, 256) mfma_variant = tkw.MMAType.F32_16x16x16_F16 - - class Phase(Enum): - QK = (0,) - SOFTMAX_V = (1,) - - def get_constraints(phase: Phase) -> list[tkw.Constraint]: - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - if phase == Phase.QK: - constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] - constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / 2)] - vector_shapes = {B: 0, M: 16, K2: 16} - else: - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - vector_shapes = {B: 0, M: 16, N: 16} - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes=vector_shapes, - ) - ] - return constraints - - # The first kernel computes Q @ K.T. - @tkw.wave(get_constraints(Phase.QK)) - def qk_kernel( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) - q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - acc = tkw.mma(k_reg, q_reg, c_reg) - x_j = tkw.permute(acc, target_shape=[B, M, K2]) - tkw.write(x_j, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - - # The second kernel computes the softmax and V @ softmax(Q @ K.T). - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + phase_0, phase_1, hyperparams_0, hyperparams_1 = get_decode_attention_kernels( + shape, mfma_variant ) - @tkw.wave(get_constraints(Phase.SOFTMAX_V)) - def softmax_v_kernel( - qk: tkl.Memory[B, M, K2, ADDRESS_SPACE, tkl.f32], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. 
- @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - BLOCK_B: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - B: shape[0], - M: shape[1], - N: shape[2], - K1: shape[3], - K2: shape[4], - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - torch.manual_seed(0) q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - qkt = torch.zeros(shape[0], shape[1], shape[4], dtype=torch.float32) + logits = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + logits_max = torch.zeros(shape[0], shape[1], dtype=torch.float32) output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) with tk.gen.TestLaunchContext( - hyperparams, + hyperparams_0, canonicalize=True, run=False, run_bench=False, schedule=False, use_scheduling_barriers=False, ): - print(qk_kernel(q, k, qkt).module_op) - - # CHECK: func.func @qk_kernel - # CHECK-NOT: {{.*}} = scf.for - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-8: {{.*}} = vector.load - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-8: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: vector.store + print(phase_0(q, k, v, logits, logits_max).module_op) + + # CHECK: func.func @phase_0 + # CHECK-NOT: {{.*}} = scf.for + # CHECK-COUNT-9: {{.*}} = vector.load + # CHECK-COUNT-1: vector.store + # CHECK-COUNT-4: {{.*}} = vector.load + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-2: {{.*}} = arith.subf + # CHECK-COUNT-2: {{.*}} = math.exp2 + # CHECK-COUNT-2: {{.*}} = arith.subf + # CHECK-COUNT-2: {{.*}} = math.exp2 + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-2: {{.*}} = amdgpu.mfma + # CHECK-COUNT-2: {{.*}} = arith.divf + # CHECK-COUNT-2: {{.*}} = math.log2 + # CHECK-COUNT-18: vector.store with tk.gen.TestLaunchContext( - hyperparams, + hyperparams_1, canonicalize=True, run=False, run_bench=False, schedule=False, use_scheduling_barriers=False, ): - print(softmax_v_kernel(qkt, v, 
output).module_op) - - # CHECK: func.func @softmax_v_kernel - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-4: {{.*}} = gpu.shuffle - # CHECK-COUNT-4: {{.*}} = arith.subf - # CHECK-COUNT-4: {{.*}} = math.exp2 - # CHECK-COUNT-4: {{.*}} = gpu.shuffle - # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + print(phase_1(logits, logits_max, output).module_op) + + # CHECK: func.func @phase_1 + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-2: {{.*}} = vector.load + # CHECK-COUNT-1: {{.*}} = arith.maximumf + # CHECK-COUNT-1: {{.*}} = arith.subf + # CHECK-COUNT-1: {{.*}} = math.exp2 + # CHECK-COUNT-1: {{.*}} = arith.subf + # CHECK-COUNT-1: {{.*}} = math.exp2 + # CHECK-COUNT-1: {{.*}} = arith.mulf + # CHECK-COUNT-1: {{.*}} = arith.addf + # CHECK-COUNT-2: {{.*}} = arith.mulf + # CHECK-COUNT-1: {{.*}} = arith.addf + # CHECK-COUNT-1: {{.*}} = arith.divf + # TODO: Remove vector.scatter when optimizing for performance + # CHECK-COUNT-1: vector.scatter @run_test diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 3c8a57df..5a687782 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -25,6 +25,9 @@ device_zeros, ) from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.templates.decode_attention import ( + get_decode_attention_kernels, +) import os import json from torch.testing import assert_close, assert_allclose @@ -930,136 +933,11 @@ def testFlashDecoding( ): run_bench = request.config.getoption("--runperf") dump_perf = request.config.getoption("--dump-perf-files-path") - # Input sizes - B = tkl.sym.B - M = tkl.sym.M - N = tkl.sym.N - K1 = tkl.sym.K1 - K2 = tkl.sym.K2 - # Workgroup tile sizes - BLOCK_B = tkl.sym.BLOCK_B - BLOCK_M = tkl.sym.BLOCK_M - BLOCK_N = tkl.sym.BLOCK_N - BLOCK_K2 = tkl.sym.BLOCK_K2 - # Address space (for GPU, shared(1) or global(0)) - ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE - # Other hyperparameters - LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD - STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD - - class Phase(Enum): - QK = (0,) - SOFTMAX_V = (1,) - - def get_constraints(phase: Phase) -> list[tkw.Constraint]: - if mfma_variant == MMAType.F32_16x16x16_F16: - Mvec = 16 - Nvec = 16 - if mfma_variant == MMAType.F32_32x32x8_F16: - Mvec = 32 - Nvec = 32 - ratio_m = 2 - ratio_n = 2 - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)] - if phase == Phase.QK: - constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] - constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / ratio_n)] - vector_shapes = {B: 0, M: Mvec, K2: Nvec} - else: - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / ratio_n)] - vector_shapes = {B: 0, M: Mvec, N: Nvec} - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(ratio_m, ratio_n, 1), - mma_type=mfma_variant, - vector_shapes=vector_shapes, - ) - ] - if dynamic_dims: - constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] - return constraints - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - 
num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + phase_0, phase_1, hyperparams_0, hyperparams_1 = get_decode_attention_kernels( + shape, mfma_variant ) - - # The first kernel computes K @ Q.T. - @tkw.wave(get_constraints(Phase.QK)) - def qk_kernel( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) - q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - acc = tkw.mma(k_reg, q_reg, c_reg) - x_j = tkw.permute(acc, target_shape=[B, M, K2]) - tkw.write(x_j, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - - # The second kernel computes the softmax and V @ softmax(K @ Q.T). - @tkw.wave(get_constraints(Phase.SOFTMAX_V)) - def softmax_v_kernel( - qk: tkl.Memory[B, M, K2, ADDRESS_SPACE, tkl.f32], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - BLOCK_B: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 64, - B: shape[0], - M: shape[1], - N: shape[2], - K1: shape[3], - K2: shape[4], - } - hyperparams.update(get_default_scheduling_params()) + hyperparams_0.update(get_default_scheduling_params()) + hyperparams_1.update(get_default_scheduling_params()) config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 @@ -1070,68 +948,57 @@ def repeat( dump_perf, "tk_" + perf_filename ) - dynamic_symbols = [] - dynamic_symbols_map = {} - if dynamic_dims: - dynamic_symbols_map[M] = hyperparams[M] - dynamic_symbols_map[N] = hyperparams[N] - dynamic_symbols_map[B] = hyperparams[B] - dynamic_symbols_map[K2] = hyperparams[K2] - dynamic_symbols.append(M) - dynamic_symbols.append(N) - dynamic_symbols.append(B) - dynamic_symbols.append(K2) - del hyperparams[M] - del hyperparams[N] - del hyperparams[B] - del hyperparams[K2] - torch.manual_seed(0) - q = device_randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = device_randn(shape[0], shape[4], shape[3], 
dtype=torch.float16) - qk = device_zeros(shape[0], shape[1], shape[4], dtype=torch.float32) - v = device_randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + B, M, N, K1, K2 = shape + U = hyperparams_0[index_symbol("U")] + q = device_randn(B, M, K1, dtype=torch.float16) + k = device_randn(B, K2, K1, dtype=torch.float16) + v = device_randn(B, K2, N, dtype=torch.float16) + phase_0_output = device_zeros(U, B, N, M, dtype=torch.float32) + phase_0_output_max = device_zeros(U, B, M, dtype=torch.float32) + output = device_zeros(B, M, N, dtype=torch.float32) log2e = 1.44269504089 - dk_sqrt = math.sqrt(1.0 / shape[3]) + dk_sqrt = math.sqrt(1.0 / K1) with tk.gen.TestLaunchContext( - hyperparams, + hyperparams_0, canonicalize=True, run=True, run_bench=run_bench, run_config=config, schedule=enable_scheduling, use_scheduling_barriers=enable_scheduling_barriers, - dynamic_symbols=dynamic_symbols, - dynamic_symbols_map=dynamic_symbols_map, ): # TODO: Add scaling of QK as part of kernel. - mb_qk = qk_kernel(q * dk_sqrt * log2e, k, qk) + mb_qk = phase_0( + q * dk_sqrt * log2e, + k, + v.permute([0, 2, 1]), + phase_0_output, + phase_0_output_max, + ) with tk.gen.TestLaunchContext( - hyperparams, + hyperparams_1, canonicalize=True, run=True, run_bench=run_bench, run_config=config, schedule=enable_scheduling, use_scheduling_barriers=enable_scheduling_barriers, - dynamic_symbols=dynamic_symbols, - dynamic_symbols_map=dynamic_symbols_map, ): # TODO: Add variant of non-transposed V attention kernel. - mb_sv = softmax_v_kernel(qk, v.permute([0, 2, 1]), output) + mb_sv = phase_1(phase_0_output, phase_0_output_max, output) torch_ref = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=None ) if test_dump_generated_mlir: - filename = f"wave_qk_kernel_{'x'.join(map(str, shape))}.mlir" + filename = f"wave_phase_0_kernel_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: f.write(mb_qk.module_op.get_asm()) - filename = f"wave_softmax_v_kernel_{'x'.join(map(str, shape))}.mlir" + filename = f"wave_phase_1_kernel_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: f.write(mb_sv.module_op.get_asm())
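
---

Two standalone Python sketches follow for the reader; they are illustrative additions, not part of the patch.

First, to make the alias arithmetic concrete: with the decode template's SymbolicAlias(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES))), create_new_constraints turns the WorkgroupConstraint on K2 into one on U that reuses K2's workgroup id, with tile size source_to_target(BLOCK_K2). A minimal sympy-only sanity check of that arithmetic (the standalone symbols below stand in for the tkl.sym ones):

import sympy

# Stand-ins for the tkl.sym symbols used in decode_attention.py (illustrative only).
U, K2, BLOCK_K2 = sympy.symbols("U K2 BLOCK_K2")
K_WAVES = 2

# source_to_target from the template: U = ceil(K2 / (BLOCK_K2 / K_WAVES)).
source_to_target = lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES))

# Tile size the derived WorkgroupConstraint on U inherits from
# WorkgroupConstraint(K2, BLOCK_K2, 2): one U step per wave-sized K2 tile.
assert source_to_target(BLOCK_K2) == K_WAVES

# Total split count, matching symbols_0[U] for the test shape (K2=256, BLOCK_K2=32).
assert source_to_target(256).subs(BLOCK_K2, 32) == 16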
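Second, a compact PyTorch reference for the math that phase_0 and phase_1 implement (a hypothetical helper, assuming q: [B, M, K1], k: [B, K2, K1], v: [B, K2, N]): phase_0 emits per-split normalized outputs plus dm_j = m_j + log2(d_j), and phase_1 merges the splits with a running base-2 log-sum-exp, as in the repeat loop:

import torch

def decode_attention_reference(q, k, v, num_splits):
    # Hypothetical reference (not part of the patch) for phase_0 + phase_1.
    B, M, K1 = q.shape
    # Fold in dk_sqrt * log2(e) so base-2 exponentials reproduce softmax,
    # mirroring the q scaling done in testFlashDecoding.
    scale = (1.0 / K1) ** 0.5 * 1.44269504089
    s = torch.einsum("bmk,bnk->bmn", q.float() * scale, k.float())  # [B, M, K2]
    # Phase 0: per-split base-2 softmax and normalized partial outputs.
    outs, dms = [], []
    for s_u, v_u in zip(s.chunk(num_splits, dim=-1), v.float().chunk(num_splits, dim=1)):
        m_u = s_u.amax(dim=-1)                     # m_j
        e_u = torch.exp2(s_u - m_u.unsqueeze(-1))  # e_delta
        d_u = e_u.sum(dim=-1)                      # d_j
        outs.append(torch.einsum("bmn,bnj->bmj", e_u, v_u) / d_u.unsqueeze(-1))
        dms.append(m_u + torch.log2(d_u))          # dm_j = m_j + log2(d_j)
    # Phase 1: running base-2 log-sum-exp combine over the U splits.
    m = torch.full_like(dms[0], -1e6)
    d = torch.zeros_like(m)
    acc = torch.zeros_like(outs[0])
    for dm_u, o_u in zip(dms, outs):
        new_m = torch.maximum(m, dm_u)
        old_scale = torch.exp2(m - new_m)
        new_scale = torch.exp2(dm_u - new_m)
        d = d * old_scale + new_scale
        acc = acc * old_scale.unsqueeze(-1) + new_scale.unsqueeze(-1) * o_u
        m = new_m
    return acc / d.unsqueeze(-1)  # [B, M, N]

Up to floating-point tolerance this should agree with torch.nn.functional.scaled_dot_product_attention(q, k, v), the torch_ref used in testFlashDecoding.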