diff --git a/aie_kernels/aie2/mm.cc b/aie_kernels/aie2/mm.cc
index 0444fa6018..e78bab49b3 100644
--- a/aie_kernels/aie2/mm.cc
+++ b/aie_kernels/aie2/mm.cc
@@ -366,6 +366,23 @@ void matmul_vectorized_4x4x4_i16_i16(const int16 *__restrict pA,
                                                            pC);
 }
 
+template <unsigned m, unsigned k, unsigned n>
+void matmul_vectorized_4x4x4_i16_i32(const int16 *__restrict pA,
+                                     const int16 *__restrict pB,
+                                     int32 *__restrict pC) {
+  // matmul_vectorized operates on two 4x4 input blocks of A, and two 4x4 input
+  // blocks of B in each iteration. Make sure we have at least 2 blocks in each
+  // dimension, and that our input matrix is evenly divisible.
+  constexpr int r = 4;
+  constexpr int s = 4;
+  constexpr int t = 4;
+  static_assert(m % (2 * r) == 0 && m / (2 * r) > 0);
+  static_assert(k % (2 * s) == 0 && k / (2 * s) > 0);
+  static_assert(n % (2 * t) == 0 && n / (2 * t) > 0);
+  return matmul_vectorized<int16, int32, m, k, n, r, s, t>(pA, pB,
+                                                           pC);
+}
+
 template <unsigned m, unsigned k, unsigned n>
 void matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA,
                                        const bfloat16 *__restrict pB,
@@ -416,6 +433,7 @@ extern "C" {
 
 #define combos(X)                                                              \
   X(int16, i16, int16, i16, 4, 4, 4)                                           \
+  X(int16, i16, int32, i32, 4, 4, 4)                                           \
   X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4)                                   \
   X(bfloat16, bf16, float, f32, 4, 8, 4)
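The single new `combos` row above is all that is needed to expose the i16-in/i32-out kernels, because `combos` is an X-macro table: wrapper-generating macros elsewhere in `mm.cc` (outside this hunk) take the table as an argument and stamp out one C-linkage function per row. The following sketch is a generic, self-contained illustration of that pattern using made-up names (`describe_*`); it is not the actual expansion in `mm.cc`:

```cpp
#include <cstdio>

// Hypothetical X-macro table: one row per type combination.
#define combos(X)                                                              \
  X(int16_t, i16, int32_t, i32)                                                \
  X(float, f32, float, f32)

// Stamp out one function per row; the row's short names become part of the
// generated symbol, so adding a combination is a one-line table change.
#define declare_describe(ctype_in, in, ctype_out, out)                         \
  void describe_##in##_##out() { std::printf(#in " -> " #out "\n"); }

combos(declare_describe)

int main() {
  describe_i16_i32(); // prints "i16 -> i32"
  describe_f32_f32(); // prints "f32 -> f32"
}
```

This is why the generated symbol names (`matmul_i16_i32`, `zero_i32`, and so on) stay consistent between the kernel library and the Python generators below, which bind them via `external_func`.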
diff --git a/programming_examples/basic/matrix_multiplication/README.md b/programming_examples/basic/matrix_multiplication/README.md
index 88b701ffa2..7b001744b5 100644
--- a/programming_examples/basic/matrix_multiplication/README.md
+++ b/programming_examples/basic/matrix_multiplication/README.md
@@ -16,4 +16,22 @@ Subdirectories in this directory contain example designs that implement matrix m
 
 * [`single_core`](single_core) - This design performs matrix-matrix multiplication on a single AI Engine core.
 * [`whole_array`](whole_array) - This design evolves `single_core`, by splitting the computation and parallelizing it. It utilizes all available AI Engine cores simultaneously.
-* [`matrix_vector`](matrix_vector) - This design is a specialization to the matrix-vector-multiplication case, which poses unique challenges due to lower computation density. *Work in progress.*
\ No newline at end of file
+* [`matrix_vector`](matrix_vector) - This design is a specialization to the matrix-vector-multiplication case, which poses unique challenges due to lower computation density. *Work in progress.*
+
+## Note on Numerical Tolerances
+
+This directory contains verification code that ensures the designs in the subdirectories produce the correct output.
+
+The designs can be configured to work on different input and output data types, based on the Makefile variables `dtype_in` and `dtype_out`.
+In the default configuration, all designs consume integer inputs and produce integer outputs.
+For this case, the verification checks for strict equivalence between the reference output computed on the host CPU and the output calculated on the AI Engine.
+That is, verification will only pass for integer data types if the output is bit-for-bit identical.
+
+For floating point data types, the verification code allows the AI Engine output to deviate from the reference calculated on the host CPU by a limited maximal relative and absolute tolerance (defined in `common.h`).
+This is standard practice and is necessary for the following reasons:
+
+ - Operations on IEEE 754 floating point values are not associative. That is, the order of operations can affect the results. All designs in the subdirectories perform tiling of the input matrices, multiplying and accumulating sub-matrices in chunks. The reference calculation code on the CPU, on the other hand, does not perform tiling. As such, some differences due to non-associativity are expected.
+ - The reference on the host CPU is always computed in `float32`, even if the input data type is `bfloat16`, since the host CPU does not support native `bfloat16` multiplication. This means results are calculated with higher precision on the CPU and subsequently truncated, whereas the AI Engine calculates results natively in the lower-precision data type, which is more performant.
+ - If the output data type is lower-precision than the accumulation data type, the tiling in the `K` dimension affects the results. For example, when multiplying `bfloat16` numbers, the AI Engine accumulates results in higher-precision `float32`. Our designs perform this accumulation `k` times (the tile size in the `K` dimension) before writing the results back into the output buffer. If the output buffer is lower-precision, results are truncated at that point, so a larger `k` means fewer such truncations take place. The AI Engine also provides a higher-precision "cascade" data path, which can be used to accumulate results between cores, although none of the designs in this directory currently make use of it.
+
+In summary, different choices of data types, tiling strategies, and usage of AI Engine components can all subtly affect floating point results. These choices present trade-offs that must be weighed on a case-by-case basis for the application at hand.
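The first bullet of the note above can be reproduced on any host CPU with three ordinary `float` values; a minimal, self-contained demonstration (independent of this repository):

```cpp
#include <cstdio>

int main() {
  float big = 1e20f;
  // Same three summands, two association orders:
  float left = (big + -big) + 1.0f;  // exact cancellation first -> 1.0
  float right = big + (-big + 1.0f); // 1.0f is absorbed into -1e20f -> 0.0
  std::printf("left = %g, right = %g\n", left, right);
}
```

Tiled and untiled matrix multiplications associate their partial sums differently, which is exactly why the verification allows a tolerance for floating point outputs.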
diff --git a/programming_examples/basic/matrix_multiplication/common.h b/programming_examples/basic/matrix_multiplication/common.h
index b2c6c14b53..cba6ff6363 100644
--- a/programming_examples/basic/matrix_multiplication/common.h
+++ b/programming_examples/basic/matrix_multiplication/common.h
@@ -109,11 +109,16 @@ std::vector<uint32_t> load_instr_sequence(std::string instr_path) {
 // Matrix / Float / Math
 // --------------------------------------------------------------------------
 
-static inline std::int16_t random_int16_t() {
+template <typename T>
+static inline T get_random();
+
+template <>
+std::int16_t get_random<std::int16_t>() {
   return (std::int16_t)rand() % 0x10000;
 }
 
-static inline std::bfloat16_t random_bfloat16_t() {
+template <>
+std::bfloat16_t get_random<std::bfloat16_t>() {
   // Random numbers should NOT be uniformly between 0 and 1, because that
   // would make the matrix product AB always close to 1.
   return std::bfloat16_t(4.0 * (float)rand() / (float)(RAND_MAX));
@@ -165,6 +170,51 @@ bool nearly_equal(float a, float b, float epsilon = 128 * FLT_EPSILON,
   return diff < std::max(abs_th, epsilon * norm);
 }
 
+template <typename Tout>
+static inline float get_abs_tol();
+template <typename Tout>
+static inline float get_rel_tol();
+
+template <>
+float get_abs_tol<std::int16_t>() {
+  return 0.0;
+}
+
+template <>
+float get_abs_tol<std::int32_t>() {
+  return 0.0;
+}
+
+template <>
+float get_abs_tol<std::bfloat16_t>() {
+  return 0.5;
+}
+
+template <>
+float get_abs_tol<float>() {
+  return 0.5;
+}
+
+template <>
+float get_rel_tol<std::int16_t>() {
+  return 0.0;
+}
+
+template <>
+float get_rel_tol<std::int32_t>() {
+  return 0.0;
+}
+
+template <>
+float get_rel_tol<std::bfloat16_t>() {
+  return 0.05;
+}
+
+template <>
+float get_rel_tol<float>() {
+  return 0.05;
+}
+
 template <typename T>
 void print_matrix(const std::vector<T> matrix, int n_cols,
                   int n_printable_rows = 10, int n_printable_cols = 10,
@@ -237,10 +287,14 @@ struct error {
 
 template <typename Tout>
 std::optional<struct error<Tout>>
-verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual) {
-  const float absTol = 0.5;
-  const float relTol = 0.05;
-  if (!nearly_equal(expected, actual, relTol, absTol)) {
+verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual,
+              float abs_tol, float rel_tol) {
+  bool match = expected == actual;
+  if (abs_tol > 0 || rel_tol > 0) {
+    // Allow for some tolerance for float data types
+    match = nearly_equal(expected, actual, rel_tol, abs_tol);
+  }
+  if (!match) {
     return (struct error<Tout>){row, col, expected, actual};
   }
   return std::nullopt;
@@ -275,7 +329,8 @@ void print_progress_bar(std::ostream &os, double progress, int len = 75) {
 
 template <typename Tin, typename Tout, typename Tacc>
 int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
-           std::vector<Tout> C, int verbosity = 0) {
+           std::vector<Tout> C, int verbosity = 0, float abs_tol = 0.5,
+           float rel_tol = 0.05) {
   int n_errors = 0;
   std::vector<struct error<Tout>> errors;
   Tout max_rel_error = (Tout)0.0f;
@@ -285,8 +340,9 @@ int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
 
   for (int row = 0; row < M; row++) {
     for (int col = 0; col < N; col++) {
-      std::optional<struct error<Tout>> error = verify_single(
-          std::cout, row, col, CRef[row * N + col], C[row * N + col]);
+      std::optional<struct error<Tout>> error =
+          verify_single(std::cout, row, col, CRef[row * N + col],
+                        C[row * N + col], abs_tol, rel_tol);
       if (error.has_value()) {
         if (n_errors < max_printable_errors) {
           errors.push_back(*error);
@@ -316,7 +372,8 @@ int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
 template <typename Tin, typename Tout, typename Tacc>
 int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
                       std::vector<Tin> B, std::vector<Tout> C, int n_samples,
-                      int verbosity = 0) {
+                      int verbosity = 0, float abs_tol = 0.5,
+                      float rel_tol = 0.05) {
   std::mt19937 rng;
   auto rows = std::views::iota(0, M);
   auto cols = std::views::iota(0, N);
@@ -342,8 +399,8 @@ int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
       print_progress_bar(std::cerr, progress);
     }
     Tout ref = mul_acc<Tin, Tout, Tacc>(M, N, K, row, col, A, B);
-    std::optional<struct error<Tout>> error =
-        verify_single(std::cout, row, col, ref, C[row * N + col]);
+    std::optional<struct error<Tout>> error = verify_single(
+        std::cout, row, col, ref, C[row * N + col], abs_tol, rel_tol);
     if (error.has_value()) {
       if (n_errors < max_printable_errors) {
         errors.push_back(*error);
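For reference, the check that `verify_single` now applies when a nonzero tolerance is configured is the combined absolute/relative criterion of `nearly_equal`. A self-contained restatement for illustration; the `norm` computation is an assumption, since that part of `common.h` lies outside this hunk:

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// Passes if the error is under an absolute floor (for values near zero) or
// under rel_tol relative to the combined magnitude of the operands.
bool within_tolerance(float expected, float actual, float rel_tol,
                      float abs_tol) {
  float diff = std::fabs(expected - actual);
  float norm = std::min(std::fabs(expected) + std::fabs(actual),
                        std::numeric_limits<float>::max());
  return diff < std::max(abs_tol, rel_tol * norm);
}
// E.g., with the bf16 tolerances above (abs 0.5, rel 0.05), an expected value
// of 100.0 may deviate by up to ~10 (0.05 * 200) and still pass.
```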
diff --git a/programming_examples/basic/matrix_multiplication/makefile-common b/programming_examples/basic/matrix_multiplication/makefile-common
index 9f336f1099..ba21462442 100644
--- a/programming_examples/basic/matrix_multiplication/makefile-common
+++ b/programming_examples/basic/matrix_multiplication/makefile-common
@@ -37,6 +37,32 @@ include ${current_dir}../../makefile-common
 
 M?=512
 K?=512
 N?=512
+dtype_in?=i16
+dtype_out?=i32
+
+ifeq ($(dtype_in),bf16)
+  dtype_in_cpp=std::bfloat16_t
+endif
+ifeq ($(dtype_out),bf16)
+  dtype_out_cpp=std::bfloat16_t
+  dtype_acc_cpp=float
+endif
+
+ifeq ($(dtype_in),i16)
+  dtype_in_cpp=int16_t
+endif
+ifeq ($(dtype_out),i16)
+  dtype_out_cpp=int16_t
+  dtype_acc_cpp=int16_t
+endif
+ifeq ($(dtype_out),i32)
+  dtype_out_cpp=int32_t
+  dtype_acc_cpp=int32_t
+endif
+ifeq ($(dtype_out),f32)
+  dtype_out_cpp=float
+  dtype_acc_cpp=float
+endif
 
 trace_size?=65536
@@ -46,7 +72,7 @@ xclbin_target?=build/final_${target_suffix}.xclbin
 insts_target?=build/insts_${target_suffix}.txt
 
 runargs?=-v 2 --warmup 1 --iters 1
-aieargs+=-M $M -K $K -N $N
+aieargs+=-M $M -K $K -N $N --dtype_in ${dtype_in} --dtype_out ${dtype_out}
 
 kernels_dir=${srcdir}/../../../../aie_kernels/aie2
 
@@ -69,7 +95,8 @@ ${xclbin_target}: ${mlir_target} ${kernels:%=build/%.o}
 ${targetname}.exe: ${srcdir}/test.cpp ${srcdir}/../test.cpp ${srcdir}/../common.h
 	rm -rf _build
 	mkdir -p _build
-	cd _build && ${powershell} cmake -E env CXXFLAGS="-std=c++23 -ggdb" cmake ${srcdir}/.. -D CMAKE_C_COMPILER=gcc-13 -D CMAKE_CXX_COMPILER=g++-13 -DTARGET_NAME=${targetname} -Dsubdir=${subdir}
+	cd _build && ${powershell} cmake -E env CXXFLAGS="-std=c++23 -ggdb -DDTYPE_IN=${dtype_in_cpp} -DDTYPE_OUT=${dtype_out_cpp} -DDTYPE_ACC=${dtype_acc_cpp}" \
+		cmake ${srcdir}/.. -D CMAKE_C_COMPILER=gcc-13 -D CMAKE_CXX_COMPILER=g++-13 -DTARGET_NAME=${targetname} -Dsubdir=${subdir}
 	cd _build && ${powershell} cmake --build . --config Release
 ifeq "${powershell}" "powershell.exe"
 	cp _build/${targetname}.exe $@
diff --git a/programming_examples/basic/matrix_multiplication/single_core/Makefile b/programming_examples/basic/matrix_multiplication/single_core/Makefile
index a1da00108f..3fcab3f24d 100644
--- a/programming_examples/basic/matrix_multiplication/single_core/Makefile
+++ b/programming_examples/basic/matrix_multiplication/single_core/Makefile
@@ -18,7 +18,7 @@ K?=256
 N?=256
 m?=64
 k?=64
-n?=64
+n?=32
 
 kernels=mm_${m}x${k}x${n}
 aieargs+=-m $m -k $k -n $n
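With this plumbing, one pair of make variables selects matching kernels, MLIR types, and host-side C++ types end to end; for example, `make dtype_in=bf16 dtype_out=f32` (an illustrative invocation) builds the bf16-input, f32-output configuration. Note that the `ifeq` chains have no fallback: a `dtype_out` outside the four handled values leaves `dtype_out_cpp` unset, so the resulting empty `-DDTYPE_OUT=` flag would break the host build.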
None + if dtype_in_str == "bf16": + dtype_in = T.bf16 + elif dtype_in_str == "i16": + dtype_in = T.i16 + dtype_out = None + if dtype_out_str == "bf16": + dtype_out = T.bf16 + elif dtype_out_str == "i16": + dtype_out = T.i16 + elif dtype_out_str == "f32": + dtype_out = T.f32 + elif dtype_out_str == "i32": + dtype_out = T.i32 + A_sz = M * K B_sz = K * N C_sz = M * N - C_sz_in_bytes = C_sz * 2 M_div_m = M // m K_div_k = K // k @@ -66,25 +97,30 @@ def my_matmul(M, K, N, m, k, n): with mlir_mod_ctx() as ctx: + C_sz_in_bytes = C_sz * dtype_out().width // 8 + @device(AIEDevice.npu1_1col) def device_body(): - memref_a_ty = T.memref(m, k, T.bf16()) - memref_b_ty = T.memref(k, n, T.bf16()) - memref_c_ty = T.memref(m, n, T.bf16()) + memref_a_ty = T.memref(m, k, dtype_in()) + memref_b_ty = T.memref(k, n, dtype_in()) + memref_c_ty = T.memref(m, n, dtype_out()) ofifo_memref_a_ty = TypeAttr.get(ObjectFifoType.get(memref_a_ty)) ofifo_memref_b_ty = TypeAttr.get(ObjectFifoType.get(memref_b_ty)) ofifo_memref_c_ty = TypeAttr.get(ObjectFifoType.get(memref_c_ty)) # AIE Core Function declarations - zero_scalar = external_func("zero_scalar_bf16", inputs=[memref_c_ty]) - zero = external_func("zero_bf16", inputs=[memref_c_ty]) + zero_scalar = external_func( + f"zero_scalar_{dtype_out_str}", inputs=[memref_c_ty] + ) + zero = external_func(f"zero_{dtype_out_str}", inputs=[memref_c_ty]) matmul_scalar = external_func( - "matmul_scalar_bf16_bf16", + f"matmul_scalar_{dtype_in_str}_{dtype_out_str}", inputs=[memref_a_ty, memref_b_ty, memref_c_ty], ) matmul = external_func( - "matmul_bf16_bf16", inputs=[memref_a_ty, memref_b_ty, memref_c_ty] + f"matmul_{dtype_in_str}_{dtype_out_str}", + inputs=[memref_a_ty, memref_b_ty, memref_c_ty], ) # Tile declarations @@ -196,9 +232,9 @@ def core_body(): # To/from AIE-array data movement @FuncOp.from_py_func( - T.memref(A_sz, T.bf16()), - T.memref(B_sz, T.bf16()), - T.memref(C_sz, T.bf16()), + T.memref(A_sz, dtype_in()), + T.memref(B_sz, dtype_in()), + T.memref(C_sz, dtype_out()), ) def sequence(A, B, C): @@ -213,9 +249,7 @@ def sequence(A, B, C): # only do 5 tile rows at a time before synchronizing, so we can reuse BDs rows_per_block = 5 - for tile_row_block in range( - (M_div_m + rows_per_block - 1) // rows_per_block - ): + for tile_row_block in range(ceildiv(M_div_m, rows_per_block)): C_row_offset = tile_row_block * rows_per_block * m * N num_tile_rows = min( [rows_per_block, M_div_m - tile_row_block * rows_per_block] diff --git a/programming_examples/basic/matrix_multiplication/test.cpp b/programming_examples/basic/matrix_multiplication/test.cpp index c838f30aeb..378f81a407 100644 --- a/programming_examples/basic/matrix_multiplication/test.cpp +++ b/programming_examples/basic/matrix_multiplication/test.cpp @@ -28,15 +28,32 @@ #ifndef DATATYPES_USING_DEFINED #define DATATYPES_USING_DEFINED -using A_DATATYPE = std::bfloat16_t; -using B_DATATYPE = std::bfloat16_t; -using C_DATATYPE = std::bfloat16_t; -using ACC_DATATYPE = float; +#ifndef DTYPE_IN +#define DTYPE_IN std::bfloat16_t +#endif +#ifndef DTYPE_OUT +#define DTYPE_OUT std::bfloat16_t +#endif +#ifndef DTYPE_ACC +#define DTYPE_ACC float +#endif +using A_DATATYPE = DTYPE_IN; +using B_DATATYPE = DTYPE_IN; +using C_DATATYPE = DTYPE_OUT; +using ACC_DATATYPE = DTYPE_ACC; #endif +#define XSTR(X) STR(X) +#define STR(X) #X + constexpr long long verify_stochastic_threshold = 1024 * 1024 * 1024; constexpr int verify_stochastic_n_samples = 1000; +// Verification tolerance +// See "Note on Numerical Tolerances" in README.md +float abs_tol = 
diff --git a/programming_examples/basic/matrix_multiplication/test.cpp b/programming_examples/basic/matrix_multiplication/test.cpp
index c838f30aeb..378f81a407 100644
--- a/programming_examples/basic/matrix_multiplication/test.cpp
+++ b/programming_examples/basic/matrix_multiplication/test.cpp
@@ -28,15 +28,32 @@
 
 #ifndef DATATYPES_USING_DEFINED
 #define DATATYPES_USING_DEFINED
-using A_DATATYPE = std::bfloat16_t;
-using B_DATATYPE = std::bfloat16_t;
-using C_DATATYPE = std::bfloat16_t;
-using ACC_DATATYPE = float;
+#ifndef DTYPE_IN
+#define DTYPE_IN std::bfloat16_t
+#endif
+#ifndef DTYPE_OUT
+#define DTYPE_OUT std::bfloat16_t
+#endif
+#ifndef DTYPE_ACC
+#define DTYPE_ACC float
+#endif
+using A_DATATYPE = DTYPE_IN;
+using B_DATATYPE = DTYPE_IN;
+using C_DATATYPE = DTYPE_OUT;
+using ACC_DATATYPE = DTYPE_ACC;
 #endif
 
+#define XSTR(X) STR(X)
+#define STR(X) #X
+
 constexpr long long verify_stochastic_threshold = 1024 * 1024 * 1024;
 constexpr int verify_stochastic_n_samples = 1000;
 
+// Verification tolerance
+// See "Note on Numerical Tolerances" in README.md
+float abs_tol = matmul_common::get_abs_tol<C_DATATYPE>();
+float rel_tol = matmul_common::get_rel_tol<C_DATATYPE>();
+
 namespace po = boost::program_options;
 
 int main(int argc, const char *argv[]) {
@@ -139,14 +156,14 @@ int main(int argc, const char *argv[]) {
   A_DATATYPE *bufA = bo_a.map<A_DATATYPE *>();
   std::vector<A_DATATYPE> AVec(A_VOLUME);
   for (int i = 0; i < A_VOLUME; i++) {
-    AVec[i] = matmul_common::random_bfloat16_t();
+    AVec[i] = matmul_common::get_random<A_DATATYPE>();
     // AVec[i] = i;
   }
   memcpy(bufA, AVec.data(), (AVec.size() * sizeof(A_DATATYPE)));
   B_DATATYPE *bufB = bo_b.map<B_DATATYPE *>();
   std::vector<B_DATATYPE> BVec(B_VOLUME);
   for (int i = 0; i < B_VOLUME; i++) {
-    BVec[i] = matmul_common::random_bfloat16_t();
+    BVec[i] = matmul_common::get_random<B_DATATYPE>();
     // Diagonal:
     // if(i % N == i / N) {
     //   BVec[i] = 1.0;
@@ -162,6 +179,10 @@ int main(int argc, const char *argv[]) {
   memset(bufOut, 0, OUT_SIZE);
 
   if (verbosity >= 2) {
+    std::cout << "DTYPE_IN = " XSTR(DTYPE_IN) "\n";
+    std::cout << "DTYPE_OUT = " XSTR(DTYPE_OUT) "\n";
+    std::cout << "Verification tolerance " << abs_tol << " absolute, "
+              << rel_tol << " relative.\n";
     std::cout << "A = \n";
     matmul_common::print_matrix(AVec, K);
     std::cout << "B = \n";
@@ -221,10 +242,11 @@ int main(int argc, const char *argv[]) {
 
     if (do_verify_stochastic) {
      errors = matmul_common::verify_stochastic<A_DATATYPE, C_DATATYPE, ACC_DATATYPE>(
-          M, N, K, AVec, BVec, CVec, verify_stochastic_n_samples, verbosity);
+          M, N, K, AVec, BVec, CVec, verify_stochastic_n_samples, verbosity,
+          abs_tol, rel_tol);
     } else {
       errors = matmul_common::verify<A_DATATYPE, C_DATATYPE, ACC_DATATYPE>(
-          M, N, K, AVec, BVec, CVec);
+          M, N, K, AVec, BVec, CVec, verbosity, abs_tol, rel_tol);
     }
     auto vstop = std::chrono::system_clock::now();
     float vtime =
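The `XSTR`/`STR` pair added near the top of `test.cpp` is the standard two-level stringification idiom: the outer macro forces expansion of its argument before the inner `#` stringizes it. A minimal demonstration:

```cpp
#include <cstdio>

#define XSTR(X) STR(X)
#define STR(X) #X
#define DTYPE_IN std::bfloat16_t

int main() {
  std::printf("%s\n", STR(DTYPE_IN));  // prints "DTYPE_IN" (no expansion)
  std::printf("%s\n", XSTR(DTYPE_IN)); // prints "std::bfloat16_t"
}
```

Without the indirection, the verbose-mode log would print the literal token `DTYPE_IN` instead of the type it was defined to.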
dtype_in_str == "bf16": + dtype_in = T.bf16 + elif dtype_in_str == "i16": + dtype_in = T.i16 + dtype_out = None + if dtype_out_str == "bf16": + dtype_out = T.bf16 + elif dtype_out_str == "i16": + dtype_out = T.i16 + elif dtype_out_str == "f32": + dtype_out = T.f32 + elif dtype_out_str == "i32": + dtype_out = T.i32 + + if dtype_in_str == "bf16": + r = 4 + s = 8 + t = 4 + elif dtype_in_str == "i16": + r = 4 + s = 4 + t = 4 + # Input matrix A: # Conceptually, we divide input A into (m * n_rows, k)-sized blocks. These # blocks are _broadcast_ across AIE core columns, then _distributed_ across @@ -90,22 +127,30 @@ def my_matmul(M, K, N, m, k, n, n_aie_cols): @device(dev) def device_body(): - A_l2_memref_ty = T.memref(m * k * n_A_tiles_per_shim, T.bf16()) - B_l2_memref_ty = T.memref(k * n, T.bf16()) - C_l2_memref_ty = T.memref(m * n * n_aie_rows, T.bf16()) - A_l1_memref_ty = T.memref(m, k, T.bf16()) - B_l1_memref_ty = T.memref(k, n, T.bf16()) - C_l1_memref_ty = T.memref(m, n, T.bf16()) + A_l2_memref_ty = T.memref(m * k * n_A_tiles_per_shim, dtype_in()) + B_l2_memref_ty = T.memref(k * n, dtype_in()) + C_l2_memref_ty = T.memref(m * n * n_aie_rows, dtype_out()) + A_l1_memref_ty = T.memref(m, k, dtype_in()) + B_l1_memref_ty = T.memref(k, n, dtype_in()) + C_l1_memref_ty = T.memref(m, n, dtype_out()) # AIE Core Function declarations zero_scalar = external_func("zero_scalar_bf16", inputs=[C_l1_memref_ty]) - zero = external_func("zero_bf16", inputs=[C_l1_memref_ty]) + zero_scalar = external_func( + f"zero_scalar_{dtype_out_str}", inputs=[C_l1_memref_ty] + ) + zero = external_func(f"zero_{dtype_out_str}", inputs=[C_l1_memref_ty]) matmul_scalar = external_func( "matmul_scalar_bf16_bf16", inputs=[A_l1_memref_ty, B_l1_memref_ty, C_l1_memref_ty], ) + matmul_scalar = external_func( + f"matmul_scalar_{dtype_in_str}_{dtype_out_str}", + inputs=[A_l1_memref_ty, B_l1_memref_ty, C_l1_memref_ty], + ) matmul = external_func( - "matmul_bf16_bf16", inputs=[A_l1_memref_ty, B_l1_memref_ty, C_l1_memref_ty] + f"matmul_{dtype_in_str}_{dtype_out_str}", + inputs=[A_l1_memref_ty, B_l1_memref_ty, C_l1_memref_ty], ) # Tile declarations as tile[row][col] @@ -250,9 +295,9 @@ def core_body(): # To/from AIE-array data movement @FuncOp.from_py_func( - T.memref(M * K, T.bf16()), - T.memref(K * N, T.bf16()), - T.memref(M * N, T.bf16()), + T.memref(M * K, dtype_in()), + T.memref(K * N, dtype_in()), + T.memref(M * N, dtype_out()), ) def sequence(A, B, C): # We are limited in the number of BDs. After synchronizing, we can reuse BDs.