From 5c2c6e99df46e4dad9e5b9cb5c898378a9bfc823 Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Mon, 22 Apr 2024 20:52:28 +0000
Subject: [PATCH 1/5] [SME] Add scalable fp16->fp32 dense schedule

This commit extends the functionality of the SME dense and matmul
schedules to support operations with fp16 inputs and an fp32 output,
where `transpose_a=False` and `transpose_b=True`.

For convenience, it also adds a utility called `get_vscale_factor`
which creates the correct multiplier for `vscale` given a data type,
reflecting ideas from an early design of the
[SVE](https://github.com/apache/tvm-rfcs/pull/104) RFC.

Change-Id: I8c00bc6baf2df6015fa41200a238781126c73589
---
 python/tvm/relay/op/strategy/arm_cpu.py       |  25 +-
 python/tvm/testing/aot.py                     |   2 +
 python/tvm/tir/__init__.py                    |   2 +-
 python/tvm/tir/op.py                          |  18 +-
 python/tvm/tir/tensor_intrin/arm_cpu.py       | 219 ++++++++++++++++--
 python/tvm/topi/arm_cpu/dense_alter_op.py     |  24 +-
 python/tvm/topi/arm_cpu/matmul.py             | 124 ++++++++--
 .../codegen/test_target_codegen_aarch64.py    |   6 +-
 tests/python/relay/aot/aprofile_aem.mk        |   1 +
 .../relay/strategy/arm_cpu/test_dense.py      |  17 +-
 .../relay/strategy/arm_cpu/test_matmul.py     |  37 +--
 .../python/relay/test_pass_alter_op_layout.py |  32 ++-
 12 files changed, 410 insertions(+), 97 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 9974d2691d4b..5e94b38772a8 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -21,7 +21,6 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import re

-import tvm
 from tvm import relay, topi, tir
 from tvm.tir.schedule.analysis import has_block

@@ -684,9 +683,9 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):

     if (
         target.features.has_sme
-        and data.dtype in ["float32"]
-        and weight.dtype in ["float32"]
-        and out_type.dtype in ["float32"]
+        and data.dtype in ["float32", "float16"]
+        and weight.dtype == data.dtype
+        and out_type.dtype == "float32"
         # The schedule uses tensorization which does not work when the
         # reduction axis has unit iters. See
         # https://github.com/apache/tvm/issues/16566
@@ -724,10 +723,12 @@ def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):

     if (
         target.features.has_sme
-        and data.dtype in ["float32"]
-        and weight.dtype in ["float32"]
-        and out_type.dtype in ["float32"]
-        and not (attrs.transpose_a or attrs.transpose_b)
+        and data.dtype in ["float32", "float16"]
+        and weight.dtype == data.dtype
+        and out_type.dtype == "float32"
+        and not attrs.transpose_a
+        and not (data.dtype == "float16" and not attrs.transpose_b)
+        and not (data.dtype == "float32" and attrs.transpose_b)
         and len(data.shape) == 2
         # The schedule uses tensorization which does not work when the
         # reduction axis has unit iters. See
@@ -796,9 +797,13 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
     """
    Strategy for arm_cpu STIR schedules.
""" - current_target = tvm.target.Target.current() + matmul_block = None + if has_block(sch, "T_matmul_NN"): + matmul_block = sch.get_block("T_matmul_NN") + elif has_block(sch, "T_matmul_NT"): + matmul_block = sch.get_block("T_matmul_NT") - if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"): + if matmul_block and sch.get(matmul_block).annotations.get("schedule_type", "") == "sme": topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch) return True diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 609c429c2211..36fdad789d96 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -45,6 +45,8 @@ "uint16": "uint16_t", "int32": "int32_t", "uint32": "uint32_t", + # See: https://gcc.gnu.org/onlinedocs/gcc/Half-Precision.html + "float16": "_Float16", "float32": "float", } diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 24ba4ccd2e58..4ecac98cde20 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,7 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic -from .op import vscale, get_active_lane_mask +from .op import vscale, get_active_lane_mask, get_vscale_factor from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index db52bec598b1..c086bebafa4c 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Operators used in TIR expression.""" -from typing import Any, Optional +from typing import Any, Optional, Union import tvm._ffi from tvm.ir import Array, Op, PrimExpr @@ -3370,6 +3370,22 @@ def get_active_lane_mask(dtype, base, limit): return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) +def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr: + """ + Create a datatype dependent scalable expression. + + Parameters + ---------- + dtype : tvm.DataType + Element data type. + min_size : int + The minimum size of the scalable vector. + """ + if isinstance(dtype, str): + dtype = tvm.DataType(dtype) + return min_size // dtype.bits * vscale() + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 90af1e05b172..9ffdbc659729 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name,missing-function-docstring,unused-import """Intrinsics for ARM tensorization.""" + +from tvm import tir from tvm.script import tir as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder.tir import prim_func as build_prim_func @@ -167,7 +169,14 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None: return dot_prod_desc, dot_prod_impl -def get_sme_transpose_interleave_2svlx2svl_intrin(): +def _create_ptrue_mask(dtype): + """ + Creates a mask that enables all lanes of a scalable vector. 
+ """ + return T.broadcast(T.IntImm("int1", 1), tir.get_vscale_factor(dtype)) + + +def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(): """ Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using the Scalable Matrix Extension (SME). @@ -176,8 +185,6 @@ def get_sme_transpose_interleave_2svlx2svl_intrin(): then storing the columns. The SME accumulator tile is divided into a series of sub-tiles which must be loaded to / stored from independently. - Note: currently only supports the fp32 datatype. - Example ------- An example case for float32. In this instance the accumulator tile is divided into 4 @@ -206,7 +213,7 @@ def get_sme_transpose_interleave_2svlx2svl_intrin(): The SME TensorIntrin that can be used in tensorizing a schedule. """ - SVF = 4 * T.vscale() + SVF = tir.get_vscale_factor("float32") SVF2 = 2 * SVF @T.prim_func @@ -222,7 +229,6 @@ def desc(a: T.handle, a_t: T.handle) -> None: A_t[v_k, v_m] = A[v_m, v_k] def impl(): - # Accumulation sub-tile count. For fp32 it is 4 sub_tile_count = 4 with IRBuilder() as ib: @@ -242,7 +248,7 @@ def impl(): ) # Disable predication - ptrue = T.broadcast(T.IntImm("int1", 1), T.vscale() * 4) + ptrue = _create_ptrue_mask("float32") with T.block("root"): T.reads(A[0:SVF2, 0:SVF2]) @@ -295,7 +301,151 @@ def impl(): return desc, impl() -def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K): +def get_sme_transpose_interleave_block2_2svl_fp16_intrin(): + # pylint: disable=line-too-long + """ + Transpose and block pack a matrix of size 2SVL x 1SVL (where 'SVL' is the Scalable Vector + Length for the fp16 datatype) using the Scalable Matrix Extension (SME). + + Rows of the fp16 input matrix are loaded into the accumulator tile and columns are stored + as fp32 SVL length vectors to the output matrix. When loading, the accumulator tile is + interpreted to be of shape 2 * 8 * vscale x 8 * vscale. When storing, we interpret the + accumulator tile to be of shape 2 * 4 * vscale x 2 * 4 * vscale. + + Example + ------- + In the fp16 instance, the accumulator tile consists of two sub-tiles numbered 0-1. Rows + of A are loaded onto the accumulator tile by interleaving rows in the first half (0, SVL//2] + of the tile and rows in the second half (SVL//2, SVL]. Columns of fp32 values are stored + into the output buffer. The fp32 store is used to group pairs of consecutive values together, + resulting in the arrangement displayed below. + + A: Accumulator tile: + +----------------+ +----------------+ + |-------0a-------| |-------0a-------| + |-------0b-------| |-------0x-------| + | ... | |-------0b-------| A_t: + |-------0x-------| |-------0y-------| +------------------------------------------------+ + |-------0y-------| | ... | |0a.0 0a.1 0b.0 0b.1 | 1a.0 1a.1 1b.0 1b.1 | + | ... | ld1h.horiz | | st1w.vert |0x.0 0x.1 0y.0 0y.1 | 1x.0 1x.1 1y.0 1y.1 | + |================| ====> |================| ====> |0a.2 0a.3 0b.2 0b.3 ...| 1a.2 1a.3 1b.2 1b.3 ...| + |-------1a-------| |-------1a-------| |0x.2 0x.3 0y.2 0y.3 | 1x.2 1x.3 1y.2 1y.3 | + |-------1b-------| |-------1x-------| |... ... ... ... | ... ... ... ... | + | ... | |-------1b-------| +------------------------------------------------+ + |-------1x-------| |-------1y-------| + |-------1y-------| | ... | + | ... | | | + +----------------+ +----------------+ + + In the A_t output matrix in the diagram above, .x is used to denote the offset into the + labelled row. + + Returns + ------- + intrin : TensorIntrin + The SME TensorIntrin that can be used in tensorizing a schedule. 
+ + """ + # pylint: enable=line-too-long + SVF = tir.get_vscale_factor("float16") + SVF2 = 2 * SVF + + @T.prim_func + def desc(a: T.handle, a_t: T.handle) -> None: + A = T.match_buffer(a, (SVF2, SVF), dtype="float16", offset_factor=1) + A_t = T.match_buffer(a_t, (SVF, SVF2), dtype="float16", offset_factor=1) + with T.block("root"): + T.reads(A[0:SVF2, 0:SVF]) + T.writes(A_t[0:SVF, 0:SVF2]) + for k, m in T.grid(SVF, SVF2): + with T.block("transpose"): + v_m, v_k = T.axis.remap("SS", [m, k]) + A_t[v_k, v_m] = A[v_m, v_k] + + def impl(): + with IRBuilder() as ib: + with build_prim_func(): + a = T.arg("a", T.handle()) + a_t = T.arg("a_t", T.handle()) + + A = T.match_buffer( + a, (SVF2, SVF), "float16", offset_factor=1, strides=[T.int32(), 1] + ) + A_t = T.match_buffer( + a_t, (SVF, SVF2), "float16", offset_factor=1, strides=[T.int32(), 1] + ) + + ptrue_fp16 = _create_ptrue_mask("float16") + ptrue_fp32 = _create_ptrue_mask("float32") + + with T.block("root"): + T.reads(A[0:SVF2, 0:SVF]) + T.writes(A_t[0:SVF, 0:SVF2]) + + # Load rows of the input matrix + with T.serial(SVF // 2) as slice_idx: + for sub_tile_idx in range(2): + offset = slice_idx * A.strides[0] + (SVF * A.strides[0] * sub_tile_idx) + input_ptr = A.access_ptr("r", offset=offset) + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.ld1h.horiz", + T.uint32(4), + ptrue_fp16, + input_ptr, + sub_tile_idx, + slice_idx * 2, + ) + ) + input_ptr = A.access_ptr("r", offset=offset + (SVF // 2) * A.strides[0]) + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.ld1h.horiz", + T.uint32(4), + ptrue_fp16, + input_ptr, + sub_tile_idx, + slice_idx * 2 + 1, + ) + ) + + # Store columns to the output matrix + with T.serial(SVF // 2) as slice_idx: + for sub_tile_idx in range(2): + offset = slice_idx * 2 * A_t.strides[0] + (SVF * sub_tile_idx) + output_ptr = A_t.access_ptr("w", offset=offset) + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.st1w.vert", + T.uint32(4), + ptrue_fp32, + output_ptr, + sub_tile_idx, + slice_idx, + ) + ) + output_ptr = A_t.access_ptr("w", offset=offset + A_t.strides[0]) + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.st1w.vert", + T.uint32(4), + ptrue_fp32, + output_ptr, + sub_tile_idx + 2, + slice_idx, + ) + ) + + return ib.get() + + return desc, impl() + + +def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype): """ Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length using outer product operations from the Scalable Matrix Extension (SME). @@ -312,7 +462,6 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K): repeated K times. Finally, the results of the accumulation are stored. Note: The input tensor 'A' must be transpose-interleaved. - Note: Currently only supports the fp32 datatype. Example ------- @@ -383,13 +532,16 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K): The SME TensorIntrin that can be used in tensorizing a schedule. 
""" - SVF = 4 * T.vscale() + SVF = tir.get_vscale_factor("float32") SVF2 = 2 * SVF + fmopa_intrin = ( + "llvm.aarch64.sme.mopa" if in_dtype == "float32" else "llvm.aarch64.sme.mopa.wide" + ) @T.prim_func def desc(a: T.handle, b: T.handle, c: T.handle): - A = T.match_buffer(a, (K, SVF2), dtype="float32", offset_factor=1) - B = T.match_buffer(b, (K, SVF2), dtype="float32", offset_factor=1) + A = T.match_buffer(a, (K, SVF2), dtype=in_dtype, offset_factor=1) + B = T.match_buffer(b, (K, SVF2), dtype=in_dtype, offset_factor=1) C = T.match_buffer(c, (SVF2, SVF2), dtype="float32", offset_factor=1) with T.block("root"): @@ -398,10 +550,9 @@ def desc(a: T.handle, b: T.handle, c: T.handle): for m, n, k in T.grid(SVF2, SVF2, K): with T.block("gemm"): v_m, v_n, v_k = T.axis.remap("SSR", [m, n, k]) - C[v_m, v_n] += A[v_k, v_m] * B[v_k, v_n] + C[v_m, v_n] += T.Cast("float32", A[v_k, v_m]) * T.Cast("float32", B[v_k, v_n]) def impl(): - # Accumulation sub-tile count. For fp32 it is 4 sub_tile_count = 4 with IRBuilder() as ib: @@ -410,24 +561,33 @@ def impl(): b = T.arg("b", T.handle()) c = T.arg("c", T.handle()) - A = T.match_buffer(a, (K, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1]) - B = T.match_buffer(b, (K, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1]) + A = T.match_buffer(a, (K, SVF2), in_dtype, offset_factor=1, strides=[T.int32(), 1]) + B = T.match_buffer(b, (K, SVF2), in_dtype, offset_factor=1, strides=[T.int32(), 1]) C = T.match_buffer( c, (SVF2, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1] ) - ptrue = T.broadcast(T.IntImm("int1", 1), T.vscale() * 4) + ptrue = _create_ptrue_mask(in_dtype) with T.block("root"): T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2]) T.writes(C[0:SVF2, 0:SVF2]) # Iterate over the reduction axis applying outer product and accumulate - with T.serial(K) as k: - a_low = T.BufferLoad(A, [k, T.Ramp(0, 1, T.vscale() * 4)]) - a_high = T.BufferLoad(A, [k, T.Ramp(SVF, 1, T.vscale() * 4)]) - b_low = T.BufferLoad(B, [k, T.Ramp(0, 1, T.vscale() * 4)]) - b_high = T.BufferLoad(B, [k, T.Ramp(SVF, 1, T.vscale() * 4)]) + rows_per_iter = 1 if in_dtype == "float32" else 2 + with T.serial(T.ceildiv(K, rows_per_iter)) as k: + k_row = k * rows_per_iter + in_dtype_svf = tir.get_vscale_factor(in_dtype) + + a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]) + b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]) + + if in_dtype == "float32": + a_high = T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]) + b_high = T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]) + else: + a_high = T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]) + b_high = T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]) input_combinations = [ (a_low, b_low), @@ -443,7 +603,7 @@ def impl(): T.evaluate( T.call_llvm_intrin( "void", - "llvm.aarch64.sme.mopa.nxv4f32", + fmopa_intrin, T.uint32(5), sub_tile, ptrue, @@ -466,7 +626,7 @@ def impl(): "void", "llvm.aarch64.sme.st1w.horiz", T.uint32(4), - ptrue, + _create_ptrue_mask("float32"), output_ptr, T.int32(sub_tile_idx), T.int32(slice_idx), @@ -520,14 +680,23 @@ def impl(c: T.handle) -> None: TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32")) ARM_SME_INIT = "sme_init" -ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE = "sme_2svlx2svl_transpose_interleave" +ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE = "sme_2svlx2svl_fp32_transpose_interleave" +ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE = ( + "sme_block2_2svlx1svl_fp16_transpose_interleave" +) 
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA = "sme_2svlx2svl_gemm_interleaved_mopa" + # The following tensor intrinsics use LLVM intrinsics that are only available # in versions of LLVM >= 15. Installations with older versions of LLVM will # not be able to use them. if llvm_version_major() >= 15: TensorIntrin.register( - ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, *get_sme_transpose_interleave_2svlx2svl_intrin() + ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, + *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(), + ) + TensorIntrin.register( + ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE, + *get_sme_transpose_interleave_block2_2svl_fp16_intrin(), ) TensorIntrin.register(ARM_SME_INIT, *get_sme_init_intrin()) diff --git a/python/tvm/topi/arm_cpu/dense_alter_op.py b/python/tvm/topi/arm_cpu/dense_alter_op.py index 208b923e68e4..398f8398af1c 100644 --- a/python/tvm/topi/arm_cpu/dense_alter_op.py +++ b/python/tvm/topi/arm_cpu/dense_alter_op.py @@ -27,6 +27,8 @@ @dense_alter_layout.register("arm_cpu") def _alter_dense(attrs, inputs, tinfos, out_type): + from tvm.relay.op.nn import _make # pylint: disable=import-outside-toplevel + target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current @@ -52,23 +54,25 @@ def _alter_dense(attrs, inputs, tinfos, out_type): ), "matmul_sme.arm_cpu requires weights be a Relay Constant" weight_dtype = tinfos[1].dtype - weight_data = inputs[1].data.numpy() - interleaved = weight_data.transpose() - encoded_weight = relay.const(interleaved, weight_dtype) + encoded_weight = inputs[1] + transpose_b = True + if weight_dtype == "float32": + encoded_weight = relay.const(encoded_weight.data.numpy().transpose(), weight_dtype) + transpose_b = False - new_weight = te.placeholder((weight_data.shape), dtype=weight_dtype) + new_weight = te.placeholder((encoded_weight.data.shape), dtype=weight_dtype) new_workload = autotvm.task.args_to_workload( - [tinfos[0], new_weight, None, out_type.dtype], topi_impl + [tinfos[0], new_weight, None, out_type.dtype, False, transpose_b], topi_impl ) dispatch_ctx.update(target, new_workload, cfg) - return relay.nn.matmul( + return _make.matmul( inputs[0], encoded_weight, - units=attrs.units, - out_dtype=attrs.out_dtype, - transpose_a=False, - transpose_b=False, + attrs.units, + attrs.out_dtype, + False, + transpose_b, ) # x86 schedules are used as a fallback diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py index ea8b27cabcf6..439bed2361b1 100644 --- a/python/tvm/topi/arm_cpu/matmul.py +++ b/python/tvm/topi/arm_cpu/matmul.py @@ -29,41 +29,85 @@ @autotvm.register_topi_compute("matmul.arm_cpu.sme") -def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, transpose_b=False): +def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, transpose_b=True): """ SME Matmul compute definition. """ - assert ( - transpose_a == transpose_b == False - ), "Compute definition currently does not support transposed inputs." + assert transpose_a is False, "Transposed lhs not currently supported." + if data_b.dtype == "float16": + assert transpose_b is True, "Rhs must be transposed when dtype is float16." 
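+
+    # Layout summary: lhs is always (M, K); rhs is (N, K) when transpose_b is
+    # set and (K, N) otherwise. The fp16 path requires a transposed rhs
+    # because the fp16 transpose-interleave intrinsic rearranges it to (K, N)
+    # at runtime.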
M, K = get_const_tuple(data_a.shape) - N = get_const_tuple(data_b.shape)[1] + if transpose_b: + N = get_const_tuple(data_b.shape)[0] + else: + N = get_const_tuple(data_b.shape)[1] if not out_dtype: out_dtype = data_a.dtype - tile_m = 2 * 4 * tvm.tir.vscale() - tile_n = 2 * 4 * tvm.tir.vscale() + tile_m = 2 * tvm.tir.get_vscale_factor(data_a.dtype) + tile_k = tvm.tir.get_vscale_factor(data_a.dtype) + if data_a.dtype == "float32": + tile_k *= 2 + tile_n = 2 * tvm.tir.get_vscale_factor(data_a.dtype) M_padded, pad_M = pad_dim_to_multiple(M, tile_m) + _, pad_K = pad_dim_to_multiple(K, tile_k) N_padded, pad_N = pad_dim_to_multiple(N, tile_n) + + m_pad_after = (pad_M, pad_K) + n_pad_after = (pad_K, pad_N) + if transpose_b: + n_pad_after = (pad_N, pad_K) + if pad_M != 0: - data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=(pad_M, 0)) + data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after) if pad_N != 0: - data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=(0, pad_N)) + data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after) + + if out_dtype is None: + out_dtype = data_a.dtype k = te.reduce_axis((0, K), name="k") + + def compute(*indices): + i, j = indices[-2:] + a_indices = (k, i) if transpose_a else (i, k) + b_indices = (j, k) if transpose_b else (k, j) + return te.sum( + data_a[a_indices].astype(out_dtype) * data_b[b_indices].astype(out_dtype), axis=k + ) + + compute_name = { + (True, True): "T_matmul_TT", + (True, False): "T_matmul_TN", + (False, True): "T_matmul_NT", + (False, False): "T_matmul_NN", + }[(transpose_a, transpose_b)] + C = te.compute( (M_padded, N_padded), - lambda m, n: te.sum( - data_a[m, k].astype(data_a.dtype) * data_b[k, n].astype(data_b.dtype), - axis=k, - ).astype(out_dtype), - name="matmul_sme_gemm", + compute, + name=compute_name, + attrs={"schedule_type": "sme"}, + ) + return te.compute((M, N), lambda m, n: C[m, n]) + + +def _get_transpose_interleave_intrin_name(in_dtype, out_dtype): + # pylint: disable=import-outside-toplevel + from tvm.tir.tensor_intrin.arm_cpu import ( + ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, + ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE, ) - C = te.compute((M, N), lambda m, n: C[m, n]) - return C + + if in_dtype == "float32" and out_dtype == "float32": + return ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + elif in_dtype == "float16" and out_dtype == "float32": + return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE + else: + raise ValueError("Input/output data type combination not supported.") def tir_schedule_matmul_sme(sch): @@ -72,21 +116,37 @@ def tir_schedule_matmul_sme(sch): """ # pylint: disable=import-outside-toplevel from tvm.tir.tensor_intrin.arm_cpu import ( - ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, ARM_SME_INIT, get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, ) - gemm_block = sch.get_block("matmul_sme_gemm") + main_func = sch.mod["main"] + data_handle = main_func.params[0] + in_dtype = main_func.buffer_map[data_handle].dtype + out_dtype = "float32" + + root_block = sch.get_block(main_func.body.block.name_hint) + gemm_block = sch.get_child_blocks(root_block)[-2] + + gemm_block_name = sch.get(gemm_block).name_hint + transpose = gemm_block_name.split("_")[-1] + transpose_b = transpose[1] == "T" + m, n, k = sch.get_loops(gemm_block) extent_m = sch.get(m).extent extent_k = sch.get(k).extent + extent_n = sch.get(n).extent - tile_m = T.cast(2 * 4 * T.vscale(), extent_m.dtype) - tile_k = T.cast(2 * 4 * T.vscale(), extent_k.dtype) - tile_n = T.cast(2 * 4 * 
T.vscale(), sch.get(n).extent.dtype) + if in_dtype == "float16": + tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype) + tile_k = T.cast(tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype) + tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype) + else: + tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype) + tile_k = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype) + tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype) # Interleave the input utilizing the matrix tile interleave_a_block = sch.cache_read(gemm_block, 0, "global") @@ -95,9 +155,23 @@ def tir_schedule_matmul_sme(sch): outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True) outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True) sch.reorder(outer_k, outer_m, inner_k, inner_m) - sch.tensorize(inner_k, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) + + transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(in_dtype, out_dtype) + sch.tensorize(inner_k, transpose_interleave_intrin_name) + + # Interleave the weights utilizing the matrix tile + if transpose_b: + interleave_b_block = sch.cache_read(gemm_block, 1, "global") + sch.transform_layout(interleave_b_block, ("write", 0), lambda n, k: (k, n)) + n, k = sch.get_loops(interleave_b_block) + outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True) + outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True) + sch.reorder(outer_k, outer_n, inner_k, inner_n) + sch.tensorize(inner_k, transpose_interleave_intrin_name) # Split and reorder the loops of the GeMM for tensorization + tile_m = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_m.dtype) + tile_n = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_n.dtype) m, n, k = sch.get_loops(gemm_block) outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True) outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True) @@ -108,10 +182,12 @@ def tir_schedule_matmul_sme(sch): sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT) # Tensorize the GeMM update - sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}" + sme_gemm_interleaved_intrin_name = ( + ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}_{in_dtype}" + ) tvm.tir.TensorIntrin.register( sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k), + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k, in_dtype), override=True, ) sch.tensorize(inner_m, sme_gemm_interleaved_intrin_name) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index f73d96e7c916..6c6af5bfb544 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -501,7 +501,7 @@ def main(A: T.Buffer((5,), "int32")): @pytest.mark.skipif( llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM" ) -@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize("dtype", ["float32", "float16"]) def test_matmul_sme(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme" @@ -510,7 +510,9 @@ def check_correct_assembly(dtype): B = te.placeholder((32, 32), dtype=dtype, name="B") with tvm.target.Target(target): - C = tvm.topi.arm_cpu.matmul.compute_matmul_sme(A, B, None, dtype, False, False) + C = 
tvm.topi.arm_cpu.matmul.compute_matmul_sme( + A, B, None, "float32", False, dtype == "float16" + ) prim_func = te.create_prim_func([A, B, C]) sch = tvm.tir.Schedule(prim_func) diff --git a/tests/python/relay/aot/aprofile_aem.mk b/tests/python/relay/aot/aprofile_aem.mk index 54be216eb6dd..a8d4445e266e 100644 --- a/tests/python/relay/aot/aprofile_aem.mk +++ b/tests/python/relay/aot/aprofile_aem.mk @@ -72,6 +72,7 @@ run: $(build_dir)/aot_test_runner -C SVE.ScalableVectorExtension.has_sme=1 \ -C SVE.ScalableVectorExtension.has_sve2=1 \ -C SVE.ScalableVectorExtension.enable_at_reset=1 \ + -C cluster0.has_arm_v9-2=1 \ -C bp.secure_memory=false \ -C bp.terminal_0.start_telnet=0 \ -C bp.terminal_1.start_telnet=0 \ diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py index b9384e532e7d..eff465b32ee5 100644 --- a/tests/python/relay/strategy/arm_cpu/test_dense.py +++ b/tests/python/relay/strategy/arm_cpu/test_dense.py @@ -102,21 +102,22 @@ class TestDense(BasicDenseTests): "data_shape,weight_shape", [ ((32, 32), (32, 32)), - ((2, 35), (6, 35)), ((3, 3), (68, 3)), + ((2, 35), (6, 35)), ((79, 65), (152, 65)), ], ) -@pytest.mark.parametrize("dtype", ["float32"]) -def test_sme_dense(data_shape, weight_shape, dtype): +@pytest.mark.parametrize("in_dtype", ["float32", "float16"]) +def test_sme_dense(data_shape, weight_shape, in_dtype): np.random.seed(0) + out_dtype = "float32" - input_data = np.random.uniform(size=data_shape).astype(dtype) - inp = relay.var("data", shape=data_shape, dtype=dtype) - weight_data = np.random.uniform(size=weight_shape).astype(dtype) - weight = relay.const(weight_data, dtype=dtype) + input_data = np.random.uniform(size=data_shape).astype(in_dtype) + inp = relay.var("data", shape=data_shape, dtype=in_dtype) + weight_data = np.random.uniform(size=weight_shape).astype(in_dtype) + weight = relay.const(weight_data, dtype=in_dtype) - dense = relay.nn.dense(inp, weight) + dense = relay.nn.dense(inp, weight, out_dtype=out_dtype) func = relay.Function(relay.analysis.free_vars(dense), dense) ir_mod = tvm.IRModule.from_expr(func) diff --git a/tests/python/relay/strategy/arm_cpu/test_matmul.py b/tests/python/relay/strategy/arm_cpu/test_matmul.py index 3b46c8019a65..6a9a50272c73 100644 --- a/tests/python/relay/strategy/arm_cpu/test_matmul.py +++ b/tests/python/relay/strategy/arm_cpu/test_matmul.py @@ -38,33 +38,40 @@ ) @tvm.testing.requires_aprofile_aem_fvp @pytest.mark.parametrize( - "data_shape,weight_shape,transpose_a,transpose_b", + "data_shape,weight_shape,transpose_a,transpose_b,in_dtype", [ - ((4, 63), (63, 10), False, False), - ((64, 32), (32, 32), False, True), - ((96, 64), (64, 32), False, False), - ((62, 3), (3, 3), False, False), - ((4, 5), (79, 5), False, True), - ((134, 36), (36, 111), False, False), - ((3, 10), (10, 72), False, False), + ((4, 63), (63, 10), False, False, "float32"), + ((64, 32), (32, 32), False, True, "float32"), + ((96, 64), (64, 32), False, False, "float32"), + ((62, 3), (3, 3), False, False, "float32"), + ((4, 5), (79, 5), False, True, "float32"), + ((134, 36), (36, 111), False, False, "float32"), + ((3, 10), (10, 72), False, False, "float32"), + ((4, 63), (10, 63), False, True, "float16"), + ((96, 64), (32, 64), False, True, "float16"), + ((62, 3), (3, 3), False, True, "float16"), + ((4, 5), (79, 5), False, True, "float16"), + ((134, 36), (111, 36), False, True, "float16"), # Tensorization does not work when the reduction axis has unit iters. 
# See https://github.com/apache/tvm/issues/16566 # ((5, 1), (1, 5), False, False), ], ) -@pytest.mark.parametrize("dtype", ["float32"]) -def test_sme_matmul_with_const_b(data_shape, weight_shape, transpose_a, transpose_b, dtype): +def test_sme_matmul_with_const_b(data_shape, weight_shape, transpose_a, transpose_b, in_dtype): """ Execution tests for matmul Scalable Matrix Extension (SME) schedule. """ np.random.seed(0) + out_dtype = "float32" - input_data = np.random.uniform(size=data_shape).astype(dtype) - inp = relay.var("data", shape=data_shape, dtype=dtype) - weight_data = np.random.uniform(size=weight_shape).astype(dtype) - weight = relay.const(weight_data, dtype=dtype) + input_data = np.random.uniform(size=data_shape).astype(in_dtype) + inp = relay.var("data", shape=data_shape, dtype=in_dtype) + weight_data = np.random.uniform(size=weight_shape).astype(in_dtype) + weight = relay.const(weight_data, dtype=in_dtype) - matmul = relay.nn.matmul(inp, weight, transpose_a=transpose_a, transpose_b=transpose_b) + matmul = relay.nn.matmul( + inp, weight, out_dtype=out_dtype, transpose_a=transpose_a, transpose_b=transpose_b + ) func = relay.Function(relay.analysis.free_vars(matmul), matmul) ir_mod = tvm.IRModule.from_expr(func) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index f74b31157ae2..eb57f795e238 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1455,7 +1455,7 @@ def expected(): @pytest.mark.skipif( llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" ) -def test_alter_op_dense_arm_cpu_sme(): +def test_alter_op_dense_arm_cpu_sme_float32(): np.random.seed(0) y_data = np.random.uniform(size=(64, 32)).astype("float32") @@ -1478,6 +1478,36 @@ def expected(): assert tvm.ir.structural_equal(a, b) +@pytest.mark.skipif( + llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" +) +def test_alter_op_dense_arm_cpu_sme_float16_float32(): + from tvm.relay.op.nn import _make # pylint: disable-top-level-import + + np.random.seed(0) + y_data = np.random.uniform(size=(64, 32)).astype("float16") + + def before(): + x = relay.var("x", shape=(32, 32), dtype="float16") + y = relay.const(y_data, dtype="float16") + dense = relay.nn.dense(x, y, out_dtype="float32") + return relay.Function(analysis.free_vars(dense), dense) + + def expected(): + x = relay.var("x", shape=(32, 32), dtype="float16") + y = relay.const(y_data, dtype="float16") + # Cannot make using the public API (relay.nn.matmul) since it will + # create an nn.dense op instead + matmul = _make.matmul(x, y, None, "float32", False, True) + return relay.Function(analysis.free_vars(matmul), matmul) + + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme"): + with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.arm_cpu.dense_alter_op._alter_dense): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + @pytest.mark.skipif( llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" ) From 1fe9baccbeab4b7d0a5250ee4b489585b40d06ac Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 15 May 2024 12:03:45 +0000 Subject: [PATCH 2/5] Fix failing asserts Change-Id: Ie7fb7a0a76119aa5c82e03ea0b2cc10de9f15f5e --- python/tvm/topi/arm_cpu/matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py index 439bed2361b1..42db54137f08 100644 --- a/python/tvm/topi/arm_cpu/matmul.py +++ b/python/tvm/topi/arm_cpu/matmul.py @@ -33,9 +33,9 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra """ SME Matmul compute definition. """ - assert transpose_a is False, "Transposed lhs not currently supported." + assert bool(transpose_a) is False, "Transposed lhs not currently supported." if data_b.dtype == "float16": - assert transpose_b is True, "Rhs must be transposed when dtype is float16." + assert bool(transpose_b) is True, "Rhs must be transposed when dtype is float16." M, K = get_const_tuple(data_a.shape) if transpose_b: From 7d102686e92531a2c8d3829fd0982fa0014e55b0 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 15 May 2024 12:36:05 +0000 Subject: [PATCH 3/5] Change ptrue predicate to use boolean values Change-Id: I0e9e45b285082b42676e53e74158e11d7e08608b --- python/tvm/tir/tensor_intrin/arm_cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 9ffdbc659729..97d8a304c981 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -173,7 +173,7 @@ def _create_ptrue_mask(dtype): """ Creates a mask that enables all lanes of a scalable vector. """ - return T.broadcast(T.IntImm("int1", 1), tir.get_vscale_factor(dtype)) + return T.broadcast(T.bool(True), tir.get_vscale_factor(dtype)) def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(): From 7363127a00b45f362b9be6b5d8040d8f11753879 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 15 May 2024 15:44:17 +0000 Subject: [PATCH 4/5] Fix topi_matmul test and avoid scalable expression warnings Change-Id: I32273241ae7569b65e082759e4f2ca4355ac6933 --- .../relay/strategy/arm_cpu/test_dense.py | 2 +- .../relay/strategy/arm_cpu/test_matmul.py | 2 +- tests/python/topi/test_topi_matmul.py | 31 ++++++++++++++----- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py index eff465b32ee5..0419d14201f0 100644 --- a/tests/python/relay/strategy/arm_cpu/test_dense.py +++ b/tests/python/relay/strategy/arm_cpu/test_dense.py @@ -139,7 +139,7 @@ def test_sme_dense(data_shape, weight_shape, in_dtype): with tvm.transform.PassContext( opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config - ), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): + ), target, meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): executor_factory = tvm.relay.build( ir_mod, target=target, diff --git a/tests/python/relay/strategy/arm_cpu/test_matmul.py b/tests/python/relay/strategy/arm_cpu/test_matmul.py index 6a9a50272c73..83f9ac1da5ba 100644 --- a/tests/python/relay/strategy/arm_cpu/test_matmul.py +++ b/tests/python/relay/strategy/arm_cpu/test_matmul.py @@ -92,7 +92,7 @@ def test_sme_matmul_with_const_b(data_shape, weight_shape, transpose_a, transpos ) with tvm.transform.PassContext( opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config - ), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): + ), target, meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): executor_factory = tvm.relay.build( ir_mod, target=target, diff --git a/tests/python/topi/test_topi_matmul.py b/tests/python/topi/test_topi_matmul.py index a7b3965aeed3..d4abcd49d0ee 100644 --- 
a/tests/python/topi/test_topi_matmul.py +++ b/tests/python/topi/test_topi_matmul.py @@ -152,15 +152,30 @@ def test_tensordot(): verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1))) -@pytest.mark.parametrize("transpose_a,transpose_b", [(True, False), (False, True)]) -def test_unsupported_sme_matmul_compute_transpose(transpose_a, transpose_b): - """ - SME matmul compute does not support transposed inputs for now. - """ - err_msg = "Compute definition currently does not support transposed inputs." - with pytest.raises(AssertionError, match=err_msg) as e: +@pytest.mark.parametrize("in_dtype", ["float32", "float16"]) +def test_unsupported_sme_matmul_compute_transpose_a(in_dtype): + err_msg = "Transposed lhs not currently supported." + with pytest.raises(AssertionError, match=err_msg): + compute_matmul_sme( + te.placeholder((32, 32), dtype=in_dtype), + te.placeholder((32, 32), dtype=in_dtype), + None, + None, + True, + False, + ) + + +def test_unsupported_sme_matmul_compute_transpose_b(): + err_msg = "Rhs must be transposed when dtype is float16." + with pytest.raises(AssertionError, match=err_msg): compute_matmul_sme( - te.placeholder((32, 32)), te.placeholder((32, 32)), None, None, transpose_a, transpose_b + te.placeholder((32, 32), dtype="float16"), + te.placeholder((32, 32), dtype="float16"), + None, + None, + False, + False, ) From bc02e4758a5c0c5f5d27af7e2bf4a732993db268 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 22 May 2024 14:23:03 +0000 Subject: [PATCH 5/5] Address comments Change-Id: I237b4c5cb5ca22e33529d98cbd75177b94904857 --- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 6 ++--- python/tvm/tir/tensor_intrin/arm_cpu.py | 10 ++++----- python/tvm/topi/arm_cpu/dense_alter_op.py | 8 +++++++ python/tvm/topi/arm_cpu/matmul.py | 22 +++++++++---------- .../relay/strategy/arm_cpu/test_dense.py | 2 +- 6 files changed, 29 insertions(+), 21 deletions(-) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 4ecac98cde20..0fee976eb130 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,7 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic -from .op import vscale, get_active_lane_mask, get_vscale_factor +from .op import vscale, get_active_lane_mask, get_vscale_expr from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index c086bebafa4c..95a85ab77d36 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3370,16 +3370,16 @@ def get_active_lane_mask(dtype, base, limit): return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) -def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr: +def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr: """ Create a datatype dependent scalable expression. Parameters ---------- - dtype : tvm.DataType + dtype : Union[str, tvm.DataType] Element data type. min_size : int - The minimum size of the scalable vector. + The minimum size of the scalable vector in bits. 
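+
+    For example, ``get_vscale_expr("float16")`` with the default ``min_size``
+    of 128 bits gives ``8 * vscale()``, while "float32" gives ``4 * vscale()``.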
""" if isinstance(dtype, str): dtype = tvm.DataType(dtype) diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 97d8a304c981..3a3430af514f 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -173,7 +173,7 @@ def _create_ptrue_mask(dtype): """ Creates a mask that enables all lanes of a scalable vector. """ - return T.broadcast(T.bool(True), tir.get_vscale_factor(dtype)) + return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype)) def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(): @@ -213,7 +213,7 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(): The SME TensorIntrin that can be used in tensorizing a schedule. """ - SVF = tir.get_vscale_factor("float32") + SVF = tir.get_vscale_expr("float32") SVF2 = 2 * SVF @T.prim_func @@ -347,7 +347,7 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin(): """ # pylint: enable=line-too-long - SVF = tir.get_vscale_factor("float16") + SVF = tir.get_vscale_expr("float16") SVF2 = 2 * SVF @T.prim_func @@ -532,7 +532,7 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype): The SME TensorIntrin that can be used in tensorizing a schedule. """ - SVF = tir.get_vscale_factor("float32") + SVF = tir.get_vscale_expr("float32") SVF2 = 2 * SVF fmopa_intrin = ( "llvm.aarch64.sme.mopa" if in_dtype == "float32" else "llvm.aarch64.sme.mopa.wide" @@ -577,7 +577,7 @@ def impl(): rows_per_iter = 1 if in_dtype == "float32" else 2 with T.serial(T.ceildiv(K, rows_per_iter)) as k: k_row = k * rows_per_iter - in_dtype_svf = tir.get_vscale_factor(in_dtype) + in_dtype_svf = tir.get_vscale_expr(in_dtype) a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]) b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]) diff --git a/python/tvm/topi/arm_cpu/dense_alter_op.py b/python/tvm/topi/arm_cpu/dense_alter_op.py index 398f8398af1c..0ad878b7412e 100644 --- a/python/tvm/topi/arm_cpu/dense_alter_op.py +++ b/python/tvm/topi/arm_cpu/dense_alter_op.py @@ -55,7 +55,15 @@ def _alter_dense(attrs, inputs, tinfos, out_type): weight_dtype = tinfos[1].dtype encoded_weight = inputs[1] + + # For dense the weights (rhs) are provided in transposed format, + # i.e. they are of the shape (n, k). transpose_b = True + + # The SME schedule expects the rhs to be in the format (k, n). We can do this + # transformation at compile time in the case of float32. Note: For the + # float16->float32 schedule the transformation currently happens at runtime + # with the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic. 
if weight_dtype == "float32": encoded_weight = relay.const(encoded_weight.data.numpy().transpose(), weight_dtype) transpose_b = False diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py index 42db54137f08..2f09e24c87a2 100644 --- a/python/tvm/topi/arm_cpu/matmul.py +++ b/python/tvm/topi/arm_cpu/matmul.py @@ -46,11 +46,11 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra if not out_dtype: out_dtype = data_a.dtype - tile_m = 2 * tvm.tir.get_vscale_factor(data_a.dtype) - tile_k = tvm.tir.get_vscale_factor(data_a.dtype) + tile_m = 2 * tvm.tir.get_vscale_expr(data_a.dtype) + tile_k = tvm.tir.get_vscale_expr(data_a.dtype) if data_a.dtype == "float32": tile_k *= 2 - tile_n = 2 * tvm.tir.get_vscale_factor(data_a.dtype) + tile_n = 2 * tvm.tir.get_vscale_expr(data_a.dtype) M_padded, pad_M = pad_dim_to_multiple(M, tile_m) _, pad_K = pad_dim_to_multiple(K, tile_k) @@ -140,13 +140,13 @@ def tir_schedule_matmul_sme(sch): extent_n = sch.get(n).extent if in_dtype == "float16": - tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype) - tile_k = T.cast(tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype) - tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype) + tile_m = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_m.dtype) + tile_k = T.cast(tvm.tir.get_vscale_expr(in_dtype), extent_k.dtype) + tile_n = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_n.dtype) else: - tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype) - tile_k = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype) - tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype) + tile_m = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_m.dtype) + tile_k = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_k.dtype) + tile_n = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_n.dtype) # Interleave the input utilizing the matrix tile interleave_a_block = sch.cache_read(gemm_block, 0, "global") @@ -170,8 +170,8 @@ def tir_schedule_matmul_sme(sch): sch.tensorize(inner_k, transpose_interleave_intrin_name) # Split and reorder the loops of the GeMM for tensorization - tile_m = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_m.dtype) - tile_n = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_n.dtype) + tile_m = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_m.dtype) + tile_n = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_n.dtype) m, n, k = sch.get_loops(gemm_block) outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True) outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True) diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py index 0419d14201f0..3a8427e8154d 100644 --- a/tests/python/relay/strategy/arm_cpu/test_dense.py +++ b/tests/python/relay/strategy/arm_cpu/test_dense.py @@ -102,8 +102,8 @@ class TestDense(BasicDenseTests): "data_shape,weight_shape", [ ((32, 32), (32, 32)), - ((3, 3), (68, 3)), ((2, 35), (6, 35)), + ((3, 3), (68, 3)), ((79, 65), (152, 65)), ], )