diff --git a/iree/turbine/kernel/compiler/kernel_codegen.py b/iree/turbine/kernel/compiler/kernel_codegen.py index 5574dcba..3c5f4982 100644 --- a/iree/turbine/kernel/compiler/kernel_codegen.py +++ b/iree/turbine/kernel/compiler/kernel_codegen.py @@ -129,9 +129,27 @@ def sym_to_dim_asm(s: IndexSymbol) -> str: else: # Unranked. Not well supported, but for completeness. spec_asm = element_type_asm - strides = strides_from_symbolic_shape( - idx_context, kb_t.symbolic_shape, allow_mixed_shapes=True - ) + # If strides have been specified in the type, that implies that they are + # not consistent with the dimensions of the tensor, so we default to + # dynamic dims for all shapes. + ref_type = self.reference[1].type + if ref_type.physical_layout: + # Strides are always present in the physical layout. + strides = [ + idx_context.get_static_value(s) + for s in ref_type.physical_layout["stride"] + ] + # Shapes are not always present in the physical layout. + if ref_type.physical_layout.get("shape", None): + shape_asm = "x".join( + sym_to_dim_asm(s) for s in ref_type.physical_layout["shape"] + ) + spec_asm = f"{shape_asm}x{element_type_asm}" + else: + strides = strides_from_symbolic_shape( + idx_context, kb_t.symbolic_shape, allow_mixed_shapes=True + ) + if strides is None: memref_asm = f"memref<{spec_asm}>" elif _is_symbolic(strides): diff --git a/iree/turbine/kernel/lang/kernel_buffer.py b/iree/turbine/kernel/lang/kernel_buffer.py index 4ab00ea7..6ebe4ee1 100644 --- a/iree/turbine/kernel/lang/kernel_buffer.py +++ b/iree/turbine/kernel/lang/kernel_buffer.py @@ -64,6 +64,7 @@ def new_subtype( address_space: AddressSpace | NotSetType = NotSet, symbolic_shape: tuple[IndexExpr, ...] | NotSetType = NotSet, dtype: DataType | NotSetType = NotSet, + physical_layout: dict[str, IndexExpr] | NotSetType = NotSet, usage: KernelBufferUsage | NotSetType = NotSet, ) -> Type[SubtypeT]: init_address_space = ( @@ -71,6 +72,7 @@ def new_subtype( ) init_symbolic_shape = symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape # type: ignore init_dtype = dtype if dtype is not NotSet else cls.dtype # type: ignore + init_physical_layout = physical_layout if physical_layout else None # type: ignore init_usage = usage if usage is not NotSet else cls.usage # type: ignore class SubType(cls): @@ -78,6 +80,7 @@ class SubType(cls): symbolic_shape = init_symbolic_shape rank = len(init_symbolic_shape) # type: ignore dtype = init_dtype + physical_layout = init_physical_layout usage = init_usage if name is not NotSet: @@ -104,6 +107,7 @@ class KernelBuffer(metaclass=KernelBufferMeta): symbolic_shape: ClassVar[tuple[IndexExpr, ...]] rank: ClassVar[int] dtype: ClassVar[DataType] + stride: ClassVar[tuple[IndexExpr, ...]] def __init__(self, tensor: torch.Tensor): assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}" diff --git a/iree/turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py index f87a9570..2598a1cc 100644 --- a/iree/turbine/kernel/lang/wave_types.py +++ b/iree/turbine/kernel/lang/wave_types.py @@ -14,7 +14,7 @@ from .._support.dtype import DataType from .._support.indexing import IndexExpr, IndexSymbol, index_symbol -from sympy import Symbol +from sympy import Symbol, Integer from sympy.core.expr import Expr from typing_extensions import Self @@ -41,6 +41,7 @@ class Memory(metaclass=KernelBufferMeta): symbolic_shape: ClassVar[tuple[IndexExpr, ...]] rank: ClassVar[int] dtype: ClassVar[DataType] + physical_layout: ClassVar[Optional[dict[str, IndexExpr]]] usage: 
ClassVar[Optional[KernelBufferUsage]] def __init__(self) -> None: @@ -55,9 +56,15 @@ def __class_getitem__( shift = 0 usage = KernelBufferUsage.NONE - if isinstance(shape_and_dtype[-1], KernelBufferUsage): - shift = 1 - usage = shape_and_dtype[-1] + last_dim = -1 + if isinstance(shape_and_dtype[last_dim], KernelBufferUsage): + shift += 1 + usage = shape_and_dtype[last_dim] + last_dim -= 1 + physical_layout = None + if isinstance(shape_and_dtype[last_dim], dict): + shift += 1 + physical_layout = shape_and_dtype[last_dim] shape = shape_and_dtype[: -2 - shift] addressSpace = shape_and_dtype[-2 - shift] dtype = shape_and_dtype[-1 - shift] @@ -85,6 +92,7 @@ def __class_getitem__( address_space=addressSpace, symbolic_shape=shape, dtype=dtype, + physical_layout=physical_layout, usage=usage, ) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index ffa73618..04ddea81 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -70,6 +70,7 @@ def read( elements_per_thread: Optional[IndexExpr | int] = None, mapping: Optional[IndexMapping] = None, mapping_dynamic_vals: "Register" | tuple["Register", ...] = (), + drop_dims: Optional[IndexExpr] = (), ) -> "Register": ... @@ -461,7 +462,7 @@ def node_args(self) -> dict[int, Any]: for i, arg in enumerate(self.fx_node.args): if isinstance(arg, fx.Node): custom_args[i] = get_custom(arg) - if isinstance(arg, list) and all(isinstance(x, fx.Node) for x in arg): + if isinstance(arg, Sequence) and all(isinstance(x, fx.Node) for x in arg): custom_args[i] = [get_custom(x) for x in arg] return custom_args @@ -965,6 +966,22 @@ class Read(CustomOp): mapping_dynamic_vals: tuple[fx.Node, ...] = () _write_dependency: Optional[list[fx.Node]] = None + """ + Note on drop_dims. + Consider the following loop: + + for b in range(B): + for k1 in range(K1): + for k2 in range(K2): + out[b, k1, k2] = in[b, 0, k1, k2] + + This is a slice where the output is a 3D tensor and the input is a 4D tensor. + The index mapping does not allow rank-reducing operations, since every symbol in the output must be + bound to an index variable. So we introduce a drop_dims field to specify which dimensions are dropped + after the mapping. 
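+
+    A hypothetical usage sketch (illustrative only; the memory handle, the mapping
+    and the dropped symbol S below are assumptions, not part of the current API):
+
+        # `in_mem` has symbolic shape [B, S, K1, K2]; `slice_mapping` pins S to 0,
+        # and drop_dims removes S so the resulting register is [B, K1, K2].
+        reg = tkw.read(in_mem, mapping=slice_mapping, drop_dims=S)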
+ + """ + @property def indexing_dims(self) -> list[IndexSymbol]: if self.mapping is not None: @@ -1013,7 +1030,10 @@ def transform_index_backwards( iters = self.mapping.iters mapping = self.mapping.dynamic_val_mappings[i] subs = {v: k for k, v in zip(iters, mapping.keys())} - return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()} + return { + k: v.apply_expr(subs[k], mapping[k]) if k in mapping else v + for k, v in index.items() + } return index @@ -1253,7 +1273,10 @@ def transform_index_backwards( iters = self.mapping.iters mapping = self.mapping.dynamic_val_mappings[i] subs = {v: k for k, v in zip(iters, mapping.keys())} - return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()} + return { + k: v.apply_expr(subs[k], mapping[k]) if k in mapping else v + for k, v in index.items() + } return index diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index d0236ebb..00ca5d3f 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -55,8 +55,6 @@ read, reduction, exp2, - reciprocal, - abs, maximum, get_custom, get_result, @@ -179,20 +177,6 @@ def bind_node_proxies(self, node: fx.Node, proxies: List[IRProxyValue]): ) self._node_values[node] = proxies - def get_induction_vars_and_syms(self) -> tuple[list[OpResult], list[IndexExpr]]: - induction_var_syms = [] - induction_vars = [] - if self.induction_vars: - for constraint in self.constraints: - if isinstance(constraint, TilingConstraint): - assert ( - constraint.dim in self.induction_vars - ), f"Could not find induction var for {constraint.dim} dimension" - induction_var_syms.append(constraint.induction_var) - induction_vars.append(self.induction_vars[constraint.dim]) - - return induction_vars, induction_var_syms - def get_type_or_element_type(operand_type: IrType): assert isinstance(operand_type, IrType) @@ -205,7 +189,16 @@ def get_type_or_element_type(operand_type: IrType): def add_emitter_subs( emitter: WaveEmitter, dynamic_values: dict[IndexExpr, Value] = {} ) -> dict[IndexSymbol, Value]: - induction_vars, induction_var_syms = emitter.get_induction_vars_and_syms() + induction_var_syms = [] + induction_vars = [] + if emitter.induction_vars: + for constraint in emitter.constraints: + if isinstance(constraint, TilingConstraint): + assert ( + constraint.dim in emitter.induction_vars + ), f"Could not find induction var for {constraint.dim} dimension" + induction_var_syms.append(constraint.induction_var) + induction_vars.append(emitter.induction_vars[constraint.dim]) # TODO: factor this out all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars @@ -548,20 +541,6 @@ def _build_start_indices( ] -def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]): - """ - This function takes in indices of a Node, extract their sizes - into a list, and then try do an argmax on it. In the case where - there are multipled max_vals we pick the fastest/most minor one. - """ - - index_sizes = [i.size for i in indices.values()] - # Find the maximum value - max_size = max(index_sizes) - # Find the fastest/most minor index of the maximum value. 
- return max(i for i, size in enumerate(index_sizes) if size == max_size) - - def _compute_offset(indices: list[IndexExpr], strides: list[IndexExpr]) -> IndexExpr: return sum(i * s for i, s in zip(indices, strides)) @@ -570,6 +549,23 @@ def _get_symbolic_shape(node: fx.Node) -> tuple[IndexExpr]: return get_custom(node).type.symbolic_shape +def _is_identity_mapping( + mapping: IndexMapping, + input_shape: Optional[tuple[IndexExpr]] = None, + output_shape: Optional[tuple[IndexExpr]] = None, +) -> bool: + if not mapping.is_identity(): + return False + + if input_shape is not None and mapping.input_shape != input_shape: + return False + + if output_shape is not None and mapping.output_shape != output_shape: + return False + + return True + + def _build_mask( emitter: WaveEmitter, index: Dict[IndexExpr, IndexExpr], elements_per_thread: int ) -> Optional[OpResult]: @@ -578,8 +574,7 @@ def _build_mask( return None idxc = IndexingContext.current() - fastest_dim = _get_fastest_index(index) - last_dim = list(index)[fastest_dim] + last_dim = tuple(index.keys())[-1] new_index = {k: _get_start_index(v) for k, v in index.items()} new_index[last_dim] = new_index[last_dim] + idxc.iota(elements_per_thread) @@ -604,7 +599,6 @@ def _construct_gather_scatter_indices( elements_per_thread: int, is_read: bool, dynamic_vals: tuple[Any, ...], - is_contiguous: bool, ) -> tuple[OpResult, OpResult, OpResult]: # Apply symbolc_shape order to indices, e.g. if original mapping is # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will @@ -635,30 +629,12 @@ def _construct_gather_scatter_indices( # expanded index. result_index = {key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)} - mask = _build_mask(emitter, index, elements_per_thread) - if mask is None: - mask_vec_type = VectorType.get( - [elements_per_thread], IntegerType.get_signless(1) - ) - mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread]) - - def extract0(src): - static_pos = [0] * src.type.rank - return vector_d.extract(src, static_position=static_pos, dynamic_position=[]) - - dynamic_vals_map_start = { - sym: extract0(val) - for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals) - } - if is_contiguous: - start_indices = _build_start_indices( - emitter, result_index, dynamic_vals_map_start - ) - return start_indices, None, mask + strides = strides_from_symbolic_shape(idxc, symbolc_shape, allow_mixed_shapes=True) + offsets = [] start_indices = _get_start_indices(result_index) start_indices_orig = _get_start_indices(index) - fastest_dim = _get_fastest_index(index) + need_dynamic_offsets = False for val in dynamic_vals: shape = val.type.shape @@ -669,14 +645,19 @@ def extract0(src): if shape[0] > 1: need_dynamic_offsets = True - offsets = [] - strides = strides_from_symbolic_shape(idxc, symbolc_shape, allow_mixed_shapes=True) + mask = _build_mask(emitter, index, elements_per_thread) + if mask is None: + mask_vec_type = VectorType.get( + [elements_per_thread], IntegerType.get_signless(1) + ) + mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread]) + start_indices_offset = _compute_offset(start_indices, strides) for i in range(elements_per_thread): - # Update fastest dim, i.e. in case of identity mapping it will + # Update most-minor dim, i.e. 
in case of identity mapping it will # be equivalent to just vector.load subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)] - subs[fastest_dim] = (subs[fastest_dim][0], start_indices_orig[fastest_dim] + i) + subs[-1] = (subs[-1][0], start_indices_orig[-1] + i) indices = [i.subs(subs) for i in index_mapping] # First, we build indices as if resulting gather/scatter `start_indices` @@ -701,6 +682,14 @@ def extract0(src): offsets.append(offset) + def extract0(src): + static_pos = [0] * src.type.rank + return vector_d.extract(src, static_position=static_pos, dynamic_position=[]) + + dynamic_vals_map_start = { + sym: extract0(val) + for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals) + } offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get()) if need_dynamic_offsets: # In case we need dynamic `offsets_vec`, set all `start_indices` to 0 @@ -763,7 +752,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): element_type = kb_ir_type.element_type vector_type = VectorType.get(vector_shape, element_type) input_shape = _get_symbolic_shape(memory) - if get_custom(node).has_identity_mapping(): + if mapping is None or _is_identity_mapping(mapping, input_shape=input_shape): start_indices = _build_start_indices(emitter, index) mask = _build_mask( emitter, index, cast_py_literal(emitter, elements_per_thread) @@ -790,7 +779,6 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): elements_per_thread=cast_py_literal(emitter, elements_per_thread), is_read=True, dynamic_vals=dyn_vals, - is_contiguous=get_custom(node).is_contiguous_vec(), ) zero = get_constant_attr(0, element_type) @@ -834,7 +822,9 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): index = node.index input_shape = _get_symbolic_shape(register) output_shape = _get_symbolic_shape(memory) - if get_custom(node).has_identity_mapping(): + if mapping is None or _is_identity_mapping( + mapping, input_shape=input_shape, output_shape=output_shape + ): start_indices = _build_start_indices(emitter, index) mask = _build_mask( emitter, index, cast_py_literal(emitter, elements_per_thread) @@ -859,7 +849,6 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): elements_per_thread=cast_py_literal(emitter, elements_per_thread), is_read=False, dynamic_vals=dyn_vals, - is_contiguous=get_custom(node).is_contiguous_vec(), ) if offsets_vec is None: @@ -1120,34 +1109,6 @@ def handle_exp2(source: Value) -> OpResult: return result -@handle_unary_op(reciprocal) -def handle_reciprocal(source: Value) -> OpResult: - element_type = get_type_or_element_type(source.type) - if _is_float_type(element_type): - splat_ones = DenseElementsAttr.get_splat( - source.type, get_constant_attr(1.0, element_type) - ) - ones = arith_d.ConstantOp(source.type, splat_ones) - reciprocal = arith_d.divf(ones, source) - else: - raise ValidationError( - f"Found unhandled operand type for reciprocal: {element_type}" - ) - return reciprocal - - -@handle_unary_op(abs) -def handle_abs(source: Value) -> OpResult: - element_type = get_type_or_element_type(source.type) - if _is_float_type(element_type): - abs = math_d.absf(source) - elif _is_integer_like_type(element_type): - abs = math_d.absi(source) - else: - raise ValidationError(f"Found unhandled operand type for abs: {element_type}") - return abs - - ############################################################################### # Control Flow ops ############################################################################### diff --git a/iree/turbine/kernel/wave/expansion.py 
b/iree/turbine/kernel/wave/expansion.py index 346d7529..11344e54 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -283,7 +283,7 @@ def _expand_node( for i, arg in node.node_args.items(): arg_list = arg unpack = lambda x: x - if isinstance(arg, list): + if isinstance(arg, Sequence): if not all(is_expandable(a) for a in arg): continue else: diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 24c08fcf..5ba7064f 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -217,9 +217,11 @@ def has_gpr_offsets(node: fx.Node) -> bool: {GPR_NUM: cur_gpr_start_id} ), gpr_size, - 1 - if output_mapping[-1] == gpr_offset_dim - else simplified_index[gpr_offset_dim].stride, + ( + 1 + if output_mapping[-1] == gpr_offset_dim + else simplified_index[gpr_offset_dim].stride + ), ) updated_index_with_gpr_offset[ gpr_offset_dim @@ -274,7 +276,8 @@ def combine_derived_index( new_index = copy(src_index) for dim, new_idx in dst_index.items(): - assert dim in src_index, f"Dim {dim} not in index {src_index}" + if dim not in src_index: + continue old_idx = src_index[dim] if old_idx == new_idx: continue diff --git a/iree/turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py index 62b623a2..82793cee 100644 --- a/iree/turbine/kernel/wave/minimize_global_loads.py +++ b/iree/turbine/kernel/wave/minimize_global_loads.py @@ -125,9 +125,12 @@ def add_optimized_nodes( access_pattern: dict[IndexSymbol, IndexSequence] = custom.index for i in range(expected_number_of_loads): with custom.graph.inserting_before(custom.fx_node): - read = Read(memory, load_elems_per_thread, custom.mapping).add_to_graph( - custom.graph - ) + read = Read( + memory, + load_elems_per_thread, + custom.mapping, + custom.mapping_dynamic_vals, + ).add_to_graph(custom.graph) global_offset = ( hardware_constraint.linearized_thread_id * load_elems_per_thread + i * max_elements_per_load diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index 320d9014..e2bcdb61 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -13,6 +13,7 @@ ) import torch from enum import Enum +import sympy # Input sizes B = tkl.sym.B @@ -36,442 +37,624 @@ STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD -@run_test -def test_evoformer(): - # B, BN, K2, H, K1, M, N - shape = (1, 256, 256, 4, 32, 256, 32) - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.WorkgroupConstraint(BN, BLOCK_BN, 3)] - constraints += [tkw.WorkgroupConstraint(H, BLOCK_H, 4)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - - mfma_variant = tkw.MMAType.F32_16x16x16_F16 - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes={B: 0, BN: 0, H: 0, M: 16, N: 16}, - ) - ] - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - l = tkw.IndexMapping.iterator(3) - m = tkw.IndexMapping.iterator(4) - # [B, BN, M, H, K1] -> [B, BN, H, M, K1] - q_mapping = tkw.IndexMapping( - 
num_iterators=5, - inputs={B: i, BN: j, H: k, M: l, K1: m}, - outputs={B: i, BN: j, H: k, M: l, K1: m}, - ) - # [B, BN, K2, H, K1] -> [B, BN, H, K2, K1] - k_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, K2: l, K1: m}, - outputs={B: i, BN: j, H: k, K2: l, K1: m}, - ) - # [B, BN, K2, H, N] -> [B, BN, H, N, K2] - v_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, N: l, K2: m}, - outputs={B: i, BN: j, H: k, N: l, K2: m}, - ) - # [B, BN, H, N, M] -> [B, BN, M, H, N] - o_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, N: l, M: m}, - outputs={B: i, BN: j, H: k, N: l, M: m}, - ) - - @tkw.wave(constraints) - def evoformer( - q: tkl.Memory[B, BN, M, H, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, BN, K2, H, K1, ADDRESS_SPACE, tkl.f16], - v: tkl.Memory[B, BN, K2, H, N, ADDRESS_SPACE, tkl.f16], - mask: tkl.Memory[B, BN, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], - bias: tkl.Memory[B, H, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, BN, M, H, N, GLOBAL_ADDRESS_SPACE, tkl.f16], - ): - c_reg = tkl.Register[B, BN, H, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, BN, H, M, tkl.f32](0.0) - init_max = tkl.Register[B, BN, H, M, tkl.f32](-1e6) - - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, BN, H, M, tkl.f32], - partial_sum: tkl.Register[B, BN, H, M, tkl.f32], - acc: tkl.Register[B, BN, H, N, M, tkl.f32], - ) -> ( - tkl.Register[B, BN, H, M, tkl.f32], - tkl.Register[B, BN, H, M, tkl.f32], - tkl.Register[B, BN, H, N, M, tkl.f32], - ): - imm_reg = tkl.Register[B, BN, H, K2, M, tkl.f32](0.0) - q_reg = tkw.read( - q, mapping=q_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD - ) - k_reg = tkw.read( - k, mapping=k_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD - ) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) - x_j = tkw.permute(inner_acc, target_shape=[B, BN, H, M, K2]) - mask_reg = tkw.read(mask, elements_per_thread=STORE_ELEMS_PER_THREAD) - casted_mask_reg = tkw.cast(mask_reg, tkl.f32) - y_j = x_j + casted_mask_reg - bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD) - casted_bias_reg = tkw.cast(bias_reg, tkl.f32) - z_j = y_j + casted_bias_reg - m_j = tkw.max(z_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(z_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read( - v, mapping=v_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD - ) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - casted = tkw.cast(res, tkl.f16) - tkw.write( - casted, c, mapping=o_mapping, elements_per_thread=STORE_ELEMS_PER_THREAD - ) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - B: shape[0], - BN: shape[1], - K2: shape[2], - H: shape[3], - K1: shape[4], - M: shape[5], - N: shape[6], - BLOCK_B: 1, - BLOCK_BN: 1, - BLOCK_H: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - - 
with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=False, - use_scheduling_barriers=False, - ): - torch.manual_seed(0) - q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) - v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) - print(evoformer(q, k, v, output).module_op) - - # CHECK: func.func @evoformer - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-5: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store {{.*}} - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-8: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = vector.load - # CHECK-COUNT-2: {{.*}} = arith.extf - # CHECK-COUNT-4: {{.*}} = arith.addf - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-4: {{.*}} = arith.extf - # CHECK-COUNT-4: {{.*}} = arith.addf - - -# This test sets all the dimensions except K1 to be dynamic. -# The reason why we can't set K1 to be dynamic is because K1 is the -# tile size we use for expanding the K1 MMA. We could set K1 to be -# dynamic if we tiled the K1 dimension with a tile size of BLOCK_K1. -@run_test -def test_dynamic_attention_pipelined(): - shape = (8, 128, 128, 64, 256) - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - - mfma_variant = tkw.MMAType.F32_16x16x16_F16 - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes={B: 0, M: 16, N: 16}, - ) - ] - - constraints += [tkw.Assumption(K2 > 4 * BLOCK_K2)] - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} - ) - - @tkw.wave(constraints) - def dynamic_attention_pipelined( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. 
- @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) - q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) - x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - K1: shape[3], - BLOCK_B: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=True, - use_scheduling_barriers=False, - dynamic_symbols=(B, M, N, K2), - dynamic_symbol_map={ - B: shape[0], - M: shape[1], - N: shape[2], - K2: shape[4], - }, - ): - torch.manual_seed(0) - q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) - v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) - print(dynamic_attention_pipelined(q, k, v, output).module_op) - - # CHECK-LABEL: func.func @dynamic_attention_pipelined - # CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}} - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}} - # CHECK-COUNT-14: {{.*}} = amdgpu.mfma - # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-7: {{.*}} = amdgpu.mfma - # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma - # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-16: vector.maskedstore {{.*}} - - -@run_test -def test_attention_pipelined(): - shape = (8, 128, 128, 64, 256) - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - - mfma_variant = tkw.MMAType.F32_16x16x16_F16 - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes={B: 0, M: 16, N: 16}, - ) - 
] - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} - ) - - @tkw.wave(constraints) - def base_attention_pipelined( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) - q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) - x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - BLOCK_B: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - B: shape[0], - M: shape[1], - N: shape[2], - K1: shape[3], - K2: shape[4], - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=True, - use_scheduling_barriers=False, - ): - torch.manual_seed(0) - q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) - v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) - print(base_attention_pipelined(q, k, v, output).module_op) - - # CHECK-LABEL: func.func @base_attention_pipelined - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-14: {{.*}} = amdgpu.mfma - # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-7: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma - # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} +# @run_test +# def test_evoformer(): +# # B, BN, K2, H, K1, M, N +# shape = (1, 256, 256, 4, 32, 256, 32) +# # Expose user-constraints +# 
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.WorkgroupConstraint(BN, BLOCK_BN, 3)] +# constraints += [tkw.WorkgroupConstraint(H, BLOCK_H, 4)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes={B: 0, BN: 0, H: 0, M: 16, N: 16}, +# ) +# ] +# +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# l = tkw.IndexMapping.iterator(3) +# m = tkw.IndexMapping.iterator(4) +# # [B, BN, M, H, K1] -> [B, BN, H, M, K1] +# q_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, M: l, K1: m}, +# outputs={B: i, BN: j, H: k, M: l, K1: m}, +# ) +# # [B, BN, K2, H, K1] -> [B, BN, H, K2, K1] +# k_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, K2: l, K1: m}, +# outputs={B: i, BN: j, H: k, K2: l, K1: m}, +# ) +# # [B, BN, K2, H, N] -> [B, BN, H, N, K2] +# v_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, N: l, K2: m}, +# outputs={B: i, BN: j, H: k, N: l, K2: m}, +# ) +# # [B, BN, H, N, M] -> [B, BN, M, H, N] +# o_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, N: l, M: m}, +# outputs={B: i, BN: j, H: k, N: l, M: m}, +# ) +# +# @tkw.wave(constraints) +# def evoformer( +# q: tkl.Memory[B, BN, M, H, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, BN, K2, H, K1, ADDRESS_SPACE, tkl.f16], +# v: tkl.Memory[B, BN, K2, H, N, ADDRESS_SPACE, tkl.f16], +# mask: tkl.Memory[B, BN, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], +# bias: tkl.Memory[B, H, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, BN, M, H, N, GLOBAL_ADDRESS_SPACE, tkl.f16], +# ): +# c_reg = tkl.Register[B, BN, H, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, BN, H, M, tkl.f32](0.0) +# init_max = tkl.Register[B, BN, H, M, tkl.f32](-1e6) +# +# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, BN, H, M, tkl.f32], +# partial_sum: tkl.Register[B, BN, H, M, tkl.f32], +# acc: tkl.Register[B, BN, H, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, BN, H, M, tkl.f32], +# tkl.Register[B, BN, H, M, tkl.f32], +# tkl.Register[B, BN, H, N, M, tkl.f32], +# ): +# imm_reg = tkl.Register[B, BN, H, K2, M, tkl.f32](0.0) +# q_reg = tkw.read( +# q, mapping=q_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD +# ) +# k_reg = tkw.read( +# k, mapping=k_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD +# ) +# inner_acc = tkw.mma(k_reg, q_reg, imm_reg) +# x_j = tkw.permute(inner_acc, target_shape=[B, BN, H, M, K2]) +# mask_reg = tkw.read(mask, elements_per_thread=STORE_ELEMS_PER_THREAD) +# casted_mask_reg = tkw.cast(mask_reg, tkl.f32) +# y_j = x_j + casted_mask_reg +# bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD) +# casted_bias_reg = tkw.cast(bias_reg, tkl.f32) +# z_j = y_j + casted_bias_reg +# m_j = tkw.max(z_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(z_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read( +# v, 
mapping=v_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD +# ) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# casted = tkw.cast(res, tkl.f16) +# tkw.write( +# casted, c, mapping=o_mapping, elements_per_thread=STORE_ELEMS_PER_THREAD +# ) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# B: shape[0], +# BN: shape[1], +# K2: shape[2], +# H: shape[3], +# K1: shape[4], +# M: shape[5], +# N: shape[6], +# BLOCK_B: 1, +# BLOCK_BN: 1, +# BLOCK_H: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=False, +# use_scheduling_barriers=False, +# ): +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# print(evoformer(q, k, v, output).module_op) +# +# # CHECK: func.func @evoformer +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-5: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store {{.*}} +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-2: {{.*}} = vector.load +# # CHECK-COUNT-2: {{.*}} = arith.extf +# # CHECK-COUNT-4: {{.*}} = arith.addf +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-4: {{.*}} = arith.extf +# # CHECK-COUNT-4: {{.*}} = arith.addf +# +# +## This test sets all the dimensions except K1 to be dynamic. +## The reason why we can't set K1 to be dynamic is because K1 is the +## tile size we use for expanding the K1 MMA. We could set K1 to be +## dynamic if we tiled the K1 dimension with a tile size of BLOCK_K1. 
+# @run_test +# def test_dynamic_attention_pipelined(): +# shape = (8, 128, 128, 64, 256) +# # Expose user-constraints +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes={B: 0, M: 16, N: 16}, +# ) +# ] +# +# constraints += [tkw.Assumption(K2 > 4 * BLOCK_K2)] +# +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# mapping = tkw.IndexMapping( +# num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} +# ) +# +# @tkw.wave(constraints) +# def dynamic_attention_pipelined( +# q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], +# v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, M, tkl.f32](0.0) +# init_max = tkl.Register[B, M, tkl.f32](-1e6) +# +# # This microkernel encodes the fact that if the reduction +# # dimension were tiled, then we would need to materialize a loop. +# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, M, tkl.f32], +# partial_sum: tkl.Register[B, M, tkl.f32], +# acc: tkl.Register[B, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, N, M, tkl.f32], +# ): +# imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) +# q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# inner_acc = tkw.mma(k_reg, q_reg, imm_reg) +# x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) +# m_j = tkw.max(x_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(x_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# K1: shape[3], +# BLOCK_B: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=True, +# use_scheduling_barriers=False, +# dynamic_symbols=(B, M, N, K2), +# dynamic_symbol_map={ +# B: shape[0], +# M: 
shape[1], +# N: shape[2], +# K2: shape[4], +# }, +# ): +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# print(dynamic_attention_pipelined(q, k, v, output).module_op) +# +# # CHECK-LABEL: func.func @dynamic_attention_pipelined +# # CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}} +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}} +# # CHECK-COUNT-14: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-7: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-2: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-16: vector.maskedstore {{.*}} +# +# +# @run_test +# def test_attention_pipelined(): +# shape = (8, 128, 128, 64, 256) +# # Expose user-constraints +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes={B: 0, M: 16, N: 16}, +# ) +# ] +# +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# mapping = tkw.IndexMapping( +# num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} +# ) +# +# @tkw.wave(constraints) +# def base_attention_pipelined( +# q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], +# v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, M, tkl.f32](0.0) +# init_max = tkl.Register[B, M, tkl.f32](-1e6) +# +# # This microkernel encodes the fact that if the reduction +# # dimension were tiled, then we would need to materialize a loop. 
+# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, M, tkl.f32], +# partial_sum: tkl.Register[B, M, tkl.f32], +# acc: tkl.Register[B, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, N, M, tkl.f32], +# ): +# imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) +# q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# inner_acc = tkw.mma(k_reg, q_reg, imm_reg) +# x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) +# m_j = tkw.max(x_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(x_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# BLOCK_B: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# B: shape[0], +# M: shape[1], +# N: shape[2], +# K1: shape[3], +# K2: shape[4], +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=True, +# use_scheduling_barriers=False, +# ): +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# print(base_attention_pipelined(q, k, v, output).module_op) +# +# # CHECK-LABEL: func.func @base_attention_pipelined +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-14: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-7: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-2: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} +# +# +# @run_test +# def test_flash_decoding(): +# shape = (8, 128, 128, 64, 256) +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# +# class Phase(Enum): +# QK = (0,) +# SOFTMAX_V = (1,) +# +# def get_constraints(phase: Phase) -> list[tkw.Constraint]: +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# if phase == Phase.QK: +# constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] +# constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / 2)] +# vector_shapes = {B: 0, M: 16, K2: 16} +# else: +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# 
vector_shapes = {B: 0, M: 16, N: 16} +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes=vector_shapes, +# ) +# ] +# return constraints +# +# # The first kernel computes Q @ K.T. +# @tkw.wave(get_constraints(Phase.QK)) +# def qk_kernel( +# q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) +# q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# acc = tkw.mma(k_reg, q_reg, c_reg) +# x_j = tkw.permute(acc, target_shape=[B, M, K2]) +# tkw.write(x_j, c, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# # The second kernel computes the softmax and V @ softmax(Q @ K.T). +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# mapping = tkw.IndexMapping( +# num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} +# ) +# +# @tkw.wave(get_constraints(Phase.SOFTMAX_V)) +# def softmax_v_kernel( +# qk: tkl.Memory[B, M, K2, ADDRESS_SPACE, tkl.f32], +# v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, M, tkl.f32](0.0) +# init_max = tkl.Register[B, M, tkl.f32](-1e6) +# +# # This microkernel encodes the fact that if the reduction +# # dimension were tiled, then we would need to materialize a loop. +# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, M, tkl.f32], +# partial_sum: tkl.Register[B, M, tkl.f32], +# acc: tkl.Register[B, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, N, M, tkl.f32], +# ): +# x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) +# m_j = tkw.max(x_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(x_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# BLOCK_B: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# B: shape[0], +# M: shape[1], +# N: shape[2], +# K1: shape[3], +# K2: shape[4], +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# qkt = torch.zeros(shape[0], shape[1], shape[4], 
dtype=torch.float32) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=False, +# use_scheduling_barriers=False, +# ): +# print(qk_kernel(q, k, qkt).module_op) +# +# # CHECK: func.func @qk_kernel +# # CHECK-NOT: {{.*}} = scf.for +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-8: {{.*}} = vector.load +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-2: vector.store +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=False, +# use_scheduling_barriers=False, +# ): +# print(softmax_v_kernel(qkt, v, output).module_op) +# +# # CHECK: func.func @softmax_v_kernel +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle +# # CHECK-COUNT-4: {{.*}} = arith.subf +# # CHECK-COUNT-4: {{.*}} = math.exp2 +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle +# # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +# @run_test -def test_flash_decoding(): - shape = (8, 128, 128, 64, 256) +def test_flash_paged_decoding(): + shape = (32, 6, 128, 128, 1025) + + R0 = 4097 + R1 = 8196 + SEQ_LEN = 1024 + NUM_HEADS = 6 + HEAD_TILE_SIZE = 16 + BK = shape[0] * NUM_HEADS mfma_variant = tkw.MMAType.F32_16x16x16_F16 class Phase(Enum): @@ -481,95 +664,189 @@ class Phase(Enum): def get_constraints(phase: Phase) -> list[tkw.Constraint]: constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + m_ratio = 1 + n_ratio = 2 if phase == Phase.QK: + # Distribute the batch, head and sequence length dimensions across the workgroups. constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] - constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / 2)] + constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / n_ratio)] vector_shapes = {B: 0, M: 16, K2: 16} else: constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / n_ratio)] vector_shapes = {B: 0, M: 16, N: 16} constraints += [ tkw.HardwareConstraint( threads_per_wave=64, - waves_per_block=(2, 2, 1), + waves_per_block=(m_ratio, n_ratio, 1), mma_type=mfma_variant, vector_shapes=vector_shapes, ) ] return constraints - # The first kernel computes Q @ K.T. + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + l = tkw.IndexMapping.iterator(3) + d0 = tkw.IndexMapping.dynamic_val(0) + + # Load a specific element from the request_to_tokens matrix. + # The request_to_tokens matrix has shape [R0, R1] and we are loading a single element. + request_to_tokens_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={B: d0, K2: j}, + outputs={B: i, K2: j}, + dynamic_val_mappings={B: i}, + ) + + # The k-v cache has logical dimensions [B, M, K1] and we + # broadcast it to [B, K2, K1]. 
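+    # In the k_mapping below this shows up as the input K2 coordinate being pinned to 0,
+    # with the actual row for each K2 position supplied through the dynamic k_offsets
+    # folded into the B coordinate, so the broadcast is expressed purely in the mapping.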
+    # TODO: Explicitly adding WORKGROUP_0 here only works if we
+    # are not distributing to waves. In general, we will want to replace
+    # this with the index corresponding to dimension M.
+    k_mapping = tkw.IndexMapping(
+        num_iterators=3,
+        inputs={
+            B: d0 + (WORKGROUP_0 / sympy.ceiling(NUM_HEADS / HEAD_TILE_SIZE)),
+            K2: sympy.Integer(0),
+            K1: k,
+        },
+        # Potentially, something like
+        # constants={M: l}, and then we add M to the inputs.
+        outputs={B: i, K2: j, K1: k},
+        dynamic_val_mappings={B: i // LOAD_ELEMS_PER_THREAD, K2: j},
+    )
+
+    # Broadcast the offset along the batch dimension.
+    output_mapping = tkw.IndexMapping(
+        num_iterators=3,
+        inputs={B: i, M: j, K2: k},
+        outputs={B: d0, M: j, K2: k},
+        dynamic_val_mappings={B: i // STORE_ELEMS_PER_THREAD},
+    )
+
+    # The first kernel computes Q @ K.T, after loading K from the K-cache.
+    # Say the batch dimension B = 32, num heads M = 6 and head dimension K1 = 128.
+    # The shape of the Q matrix is [32, 6, 128] / [16, 128] (BLOCK_B: 1, BLOCK_M: 16).
+    # The K cache has a much larger first dimension that could be on the order of
+    # ~O(10^7). Since we will always be accessing this dimension using a dynamic variable,
+    # and will always be loading [K2, K1] from the K cache, we can set the first dimension
+    # to be B.
+    # After loading the K matrix, it would be of shape: [32, ?, 128] / [64, 128]
+    # (with a BLOCK_K2 = 64) where we are using the same batch dimension for all the K vectors.
+
+    # The request_to_tokens matrix maps from logical to physical indices in the K cache.
+    # In general it has shape [R0, R1], but since we are only loading B offsets we represent it
+    # with shape [B] and use R0, R1 in the mapping to access the appropriate index. In general,
+    # we can always represent an n-D matrix as a 1-D matrix.
+    # The request indices and offsets have shape [B] and are used to index into the
+    # request_to_tokens matrix.
+
+    # Finally, the output matrix is of shape [B, M, K2] where K2 is the maximum sequence length.
+    # (So it would be of shape [32, 6, 128] / [16, 64]).
+
+    # Define the physical storage layout of the input memory buffers. They do not have to
+    # be identical to the logical storage layout specified by the symbolic dimensions.
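+    # As a rough illustration of the layouts defined below: with
+    # req_to_tokens_layout = {"stride": (R1, 1), "shape": (R0, R1)}, the element read
+    # for logical (b, k2) sits at flat offset request_indices[b] * R1 + k2 of the
+    # physical [R0, R1] matrix, since the read's mapping substitutes the dynamic
+    # request index for the B coordinate.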
+ q_layout = {"stride": (SEQ_LEN, K1, 1)} + k_layout = {"stride": (K1, K1, 1), "shape": (BK, 1, K1)} + req_to_tokens_layout = {"stride": (R1, 1), "shape": (R0, R1)} + @tkw.wave(get_constraints(Phase.QK)) def qk_kernel( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16, q_layout], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16, k_layout], + req_to_tokens: tkl.Memory[ + B, K2, GLOBAL_ADDRESS_SPACE, tkl.i32, req_to_tokens_layout + ], + request_indices: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + request_offsets: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + output: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], ): c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + req_idx_reg = tkw.read(request_indices, elements_per_thread=1) q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_offsets = tkw.read( + req_to_tokens, + elements_per_thread=1, + mapping=request_to_tokens_mapping, + mapping_dynamic_vals=(req_idx_reg,), + ) + k_reg = tkw.read( + k, + elements_per_thread=LOAD_ELEMS_PER_THREAD, + mapping=k_mapping, + mapping_dynamic_vals=(k_offsets,), + ) acc = tkw.mma(k_reg, q_reg, c_reg) x_j = tkw.permute(acc, target_shape=[B, M, K2]) - tkw.write(x_j, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - - # The second kernel computes the softmax and V @ softmax(Q @ K.T). - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} - ) - - @tkw.wave(get_constraints(Phase.SOFTMAX_V)) - def softmax_v_kernel( - qk: tkl.Memory[B, M, K2, ADDRESS_SPACE, tkl.f32], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc + req_off_reg = tkw.read(request_offsets, elements_per_thread=1) + tkw.write( + x_j, + output, + elements_per_thread=STORE_ELEMS_PER_THREAD, + mapping=output_mapping, + mapping_dynamic_vals=(req_off_reg,), + ) - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + ## The second kernel computes the softmax and V @ softmax(Q @ K.T). 
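Before moving on to the (currently commented-out) second kernel, here is what the mapped reads and the mapped write in qk_kernel above do in element terms. This is a toy illustration of the index mappings only, with hypothetical sizes; the per-workgroup head-tile term of k_mapping is omitted, and none of this is the wave API itself:

import torch

# Toy sizes; purely illustrative (not the shapes used by the test).
B_, K2_, K1_, CACHE_ROWS = 4, 8, 16, 64
req_to_tokens = torch.randint(0, CACHE_ROWS, (B_, K2_), dtype=torch.int64)
request_indices = torch.arange(B_)
k_cache = torch.randn(CACHE_ROWS, 1, K1_)

# request_to_tokens_mapping: the dynamic value d0 is the request index, so each
# batch entry gathers one row of the page table.
k_offsets = req_to_tokens[request_indices]     # [B_, K2_]

# k_mapping: every loaded K row is the cache row addressed by the page-table entry,
# with the cache's middle dimension pinned to 0.
k_rows = k_cache[k_offsets, 0, :]              # [B_, K2_, K1_]

# output_mapping: the mapped write scatters along the batch dimension, i.e.
# output[request_offsets[b], m, k2] receives the result computed for batch entry b.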
+ # i = tkw.IndexMapping.iterator(0) + # j = tkw.IndexMapping.iterator(1) + # k = tkw.IndexMapping.iterator(2) + # mapping = tkw.IndexMapping( + # num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + # ) + + # @tkw.wave(get_constraints(Phase.SOFTMAX_V)) + # def softmax_v_kernel( + # qk: tkl.Memory[B, H, M, K2, ADDRESS_SPACE, tkl.f32], + # v: tkl.Memory[BV, K2, ADDRESS_SPACE, tkl.f16], + # request_to_tokens: tkl.Memory[R0, R1, ADDRESS_SPACE, tkl.f16], + # request_indices: tkl.Memory[B, ADDRESS_SPACE, tkl.i32], + # request_offsets: tkl.Memory[B, ADDRESS_SPACE, tkl.i32], + # output: tkl.Memory[B, H, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + # ): + # c_reg = tkl.Register[B, H, N, M, tkl.f32](0.0) + # init_sum = tkl.Register[B, H, M, tkl.f32](0.0) + # init_max = tkl.Register[B, H, M, tkl.f32](-1e6) + + # @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + # def repeat( + # partial_max: tkl.Register[B, H, M, tkl.f32], + # partial_sum: tkl.Register[B, H, M, tkl.f32], + # acc: tkl.Register[B, H, N, M, tkl.f32], + # ) -> ( + # tkl.Register[B, H, M, tkl.f32], + # tkl.Register[B, H, M, tkl.f32], + # tkl.Register[B, H, N, M, tkl.f32], + # ): + # x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) + # m_j = tkw.max(x_j, partial_max, dim=K2) + # e_delta_max = tkw.exp2(partial_max - m_j) + # e_delta = tkw.exp2(x_j - m_j) + # e_init = partial_sum * e_delta_max + # d_j = tkw.sum(e_delta, e_init, dim=K2) + # imm_f16 = tkw.cast(e_delta, tkl.f16) + # v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # new_acc = acc * e_delta_max + # acc = tkw.mma(v_reg, imm_f16, new_acc) + # return m_j, d_j, acc + + # # repeat represents the results of the loop + # res_max, res_sum, res_mm = repeat + # res = res_mm / res_sum + # tkw.write( + # res, output, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD + # ) hyperparams = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), BLOCK_B: 1, - BLOCK_M: 64, + BLOCK_M: 16, # Read heads in blocks of 16 BLOCK_N: 64, - BLOCK_K2: 32, + BLOCK_K2: 64, # Sequence length in blocks of 64 B: shape[0], M: shape[1], N: shape[2], @@ -619,28 +896,28 @@ def repeat( # CHECK-COUNT-8: {{.*}} = amdgpu.mfma # CHECK-COUNT-2: vector.store - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=False, - use_scheduling_barriers=False, - ): - print(softmax_v_kernel(qkt, v, output).module_op) - - # CHECK: func.func @softmax_v_kernel - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-4: {{.*}} = gpu.shuffle - # CHECK-COUNT-4: {{.*}} = arith.subf - # CHECK-COUNT-4: {{.*}} = math.exp2 - # CHECK-COUNT-4: {{.*}} = gpu.shuffle - # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + # with tk.gen.TestLaunchContext( + # hyperparams, + # canonicalize=True, + # run=False, + # run_bench=False, + # schedule=False, + # use_scheduling_barriers=False, + # ): + # print(softmax_v_kernel(qkt, v, output).module_op) + + # CHECK: func.func @softmax_v_kernel + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-1: {{.*}} = vector.load + # CHECK-COUNT-1: vector.store + # CHECK-COUNT-1: {{.*}} = vector.load + # CHECK-COUNT-1: vector.store + # CHECK-COUNT-4: {{.*}} = vector.load + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-4: 
{{.*}} = arith.subf + # CHECK-COUNT-4: {{.*}} = math.exp2 + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma @run_test diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 3c8a57df..c61ecd46 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -23,12 +23,15 @@ get_mfma_store_elems_per_thread, device_randn, device_zeros, + device_randint, + to_default_device, ) from iree.turbine.kernel.wave.constraints import MMAType import os import json from torch.testing import assert_close, assert_allclose from enum import Enum +import sympy _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) @@ -53,6 +56,7 @@ perf_test(x) for x in default_test_shapes["test_attention"] ] +default_test_shapes["test_paged_decoding"] = [(32, 6, 0, 128, 1024)] user_specified_test_shapes = "" @@ -1136,3 +1140,300 @@ def repeat( f.write(mb_sv.module_op.get_asm()) assert_allclose(output, torch_ref) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_paged_decoding")) +@pytest.mark.parametrize("enable_scheduling", [False]) +@pytest.mark.parametrize("dynamic_dims", [False]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.F32_16x16x16_F16, + ], +) +def testPagedFlashDecoding( + shape: tuple[int], + enable_scheduling: bool, + dynamic_dims: bool, + mfma_variant: MMAType, + request, +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + class Phase(Enum): + QK = (0,) + SOFTMAX_V = (1,) + + def get_constraints(phase: Phase) -> list[tkw.Constraint]: + if mfma_variant == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + ratio_m = 1 + ratio_n = 2 + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + # constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)] + if phase == Phase.QK: + constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] + constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / ratio_n)] + vector_shapes = {B: 0, M: Mvec, K2: Nvec} + else: + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / ratio_n)] + vector_shapes = {B: 0, M: Mvec, N: Nvec} + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(ratio_m, ratio_n, 1), + mma_type=mfma_variant, + vector_shapes=vector_shapes, + ) + ] + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + return constraints + + # Shape of logical to physical mapping table. 
+ R0 = 4097 + R1 = 8196 + BATCH = shape[0] + NUM_HEADS = shape[1] + HEAD_DIM = shape[3] + SEQ_LEN = shape[4] + HEAD_TILE_SIZE = 16 + BK = BATCH * NUM_HEADS + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + l = tkw.IndexMapping.iterator(3) + d0 = tkw.IndexMapping.dynamic_val(0) + + # Load a specific element from the request_to_tokens matrix. + # The request_to_tokens matrix has shape [R0, R1] and we are loading a single element. + request_to_tokens_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={B: d0, K2: j}, + outputs={B: i, K2: j}, + dynamic_val_mappings={B: i}, + ) + + # Broadcast the offset along the batch dimension. + k_mapping = tkw.IndexMapping( + num_iterators=3, + inputs={ + B: d0 + WORKGROUP_0 / sympy.ceiling(NUM_HEADS / HEAD_TILE_SIZE), + K2: sympy.Integer(0), + K1: k, + }, + outputs={B: i, K2: j, K1: k}, + dynamic_val_mappings={B: i // LOAD_ELEMS_PER_THREAD, K2: j}, + ) + + # Broadcast the offset along the batch dimension. + output_mapping = tkw.IndexMapping( + num_iterators=3, + inputs={B: i, M: j, K2: k}, + outputs={B: d0, M: j, K2: k}, + dynamic_val_mappings={B: i // STORE_ELEMS_PER_THREAD}, + ) + + q_layout = {"stride": (SEQ_LEN, K1, 1)} + k_layout = {"stride": (K1, K1, 1), "shape": (BK, 1, K1)} + req_to_tokens_layout = {"stride": (R1, 1), "shape": (R0, R1)} + + # The first kernel computes K @ Q.T. + @tkw.wave(get_constraints(Phase.QK)) + def qk_kernel( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16, q_layout], + k_cache: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16, k_layout], + request_to_tokens: tkl.Memory[ + B, K2, GLOBAL_ADDRESS_SPACE, tkl.i32, req_to_tokens_layout + ], + request_indices: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + request_offsets: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + output: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + req_idx_reg = tkw.read(request_indices, elements_per_thread=1) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_offsets = tkw.read( + request_to_tokens, + elements_per_thread=1, + mapping=request_to_tokens_mapping, + mapping_dynamic_vals=(req_idx_reg,), + ) + k_reg = tkw.read( + k_cache, + elements_per_thread=LOAD_ELEMS_PER_THREAD, + mapping=k_mapping, + mapping_dynamic_vals=(k_offsets,), + ) + acc = tkw.mma(k_reg, q_reg, c_reg) + x_j = tkw.permute(acc, target_shape=[B, M, K2]) + req_off_reg = tkw.read(request_offsets, elements_per_thread=1) + tkw.write( + x_j, + output, + elements_per_thread=STORE_ELEMS_PER_THREAD, + mapping=output_mapping, + mapping_dynamic_vals=(req_off_reg,), + ) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 16, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + } + hyperparams.update(get_default_scheduling_params()) + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + 
dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + torch.manual_seed(0) + + # Construct query matrix with correct stride. + desired_shape = (BATCH, NUM_HEADS, HEAD_DIM) + desired_stride = (SEQ_LEN, HEAD_DIM, 1) + space_required = sum([a * b for a, b in zip(desired_shape, desired_stride)]) + q = device_randn(space_required, dtype=torch.float16) + q = torch.as_strided(q, desired_shape, desired_stride) + + # Construct synthetic page tables for key matrix. + total_entries = BATCH * NUM_HEADS + request_to_tokens = device_zeros(R0, R1, dtype=torch.int32) + request_to_tokens[0:BATCH, 0:SEQ_LEN] = ( + torch.arange(BATCH * SEQ_LEN).reshape(BATCH, SEQ_LEN) + 1 + ) + request_to_tokens[0:BATCH, SEQ_LEN] = torch.arange( + BATCH * SEQ_LEN, BATCH * (SEQ_LEN + 1) + ) + request_indices = to_default_device(torch.arange(BATCH, dtype=torch.int32)) + request_offsets = to_default_device(torch.arange(BATCH, dtype=torch.int32)) + desired_stride = (HEAD_DIM, HEAD_DIM, 1) + desired_shape = (BATCH * NUM_HEADS, 1, HEAD_DIM) + space_required = sum([a * b for a, b in zip(desired_shape, desired_stride)]) + k_cache = device_randn(space_required, dtype=torch.float16) + k_cache = torch.as_strided(k_cache, desired_shape, desired_stride) + + # def extract_page_table_entries(request_to_tokens, request_indices, request_offsets): + # entries = [] + # for request in request_indices: + # token_idx = request_to_tokens[request] + # entries.append(k_cache[token_idx]) + # # TODO: Broadcast entries to shape [B, K2, K1]. + # return torch.cat(entries, dim=0) + + # k = extract_page_table_entries(request_to_tokens, request_indices, request_offsets) + qk = device_zeros(shape[0], shape[1], shape[4], dtype=torch.float32) + + # v = device_randn(shape[0], shape[4], shape[2], dtype=torch.float16) + # output = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape[3]) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + # TODO: Add scaling of QK as part of kernel. + mb_qk = qk_kernel( + q * dk_sqrt * log2e, + k_cache, + request_to_tokens, + request_indices, + request_offsets, + qk, + ) + + # torch_ref = torch.matmul(q, k.permute([0, 2, 1])) * dk_sqrt * log2e + # assert_allclose(qk, torch_ref.permute([0, 2, 1])) + + # with tk.gen.TestLaunchContext( + # hyperparams, + # canonicalize=True, + # run=True, + # run_bench=run_bench, + # run_config=config, + # schedule=enable_scheduling, + # use_scheduling_barriers=enable_scheduling_barriers, + # dynamic_symbols=dynamic_symbols, + # dynamic_symbols_map=dynamic_symbols_map, + # ): + # # TODO: Add variant of non-transposed V attention kernel.
+ # mb_sv = softmax_v_kernel(qk, v.permute([0, 2, 1]), output) + + # torch_ref = torch.nn.functional.scaled_dot_product_attention( + # q, k, v, attn_mask=None + # ) + + if test_dump_generated_mlir: + filename = f"wave_qk_kernel_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_qk.module_op.get_asm()) + # filename = f"wave_softmax_v_kernel_{'x'.join(map(str, shape))}.mlir" + # with open(filename, "w") as f: + # f.write(mb_sv.module_op.get_asm()) + + # assert_allclose(output, torch_ref)
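The commented-out check above compares against a dense Q @ K.T. Once the paged path is exercised end to end, a CPU reference that follows the page table could look like the sketch below. The helper is hypothetical and not part of this patch; it ignores the head-tile offset in k_mapping and assumes the page table holds valid k_cache row indices for the first SEQ_LEN positions:

import torch

def paged_qk_reference(q, k_cache, request_to_tokens, request_indices, request_offsets, seq_len):
    # q:                 [B, M, K1] (already scaled by dk_sqrt * log2e, as in the launch above)
    # k_cache:           [BATCH * NUM_HEADS, 1, K1] physical buffer, addressed by token id
    # request_to_tokens: [R0, R1] page table; row selected by request_indices[b]
    B, M, K1 = q.shape
    out = torch.zeros(B, M, seq_len, dtype=torch.float32, device=q.device)
    for b in range(B):
        token_ids = request_to_tokens[request_indices[b].long(), :seq_len].long()  # logical -> physical
        k_rows = k_cache[token_ids, 0, :].float()                                  # [seq_len, K1]
        out[int(request_offsets[b])] = q[b].float() @ k_rows.t()                   # [M, seq_len]
    return out

# Hypothetical usage (assumes the page table indexes valid k_cache rows):
# torch_ref = paged_qk_reference(q * dk_sqrt * log2e, k_cache, request_to_tokens,
#                                request_indices, request_offsets, SEQ_LEN)
# assert_allclose(qk[:, :, :SEQ_LEN], torch_ref)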