From fe0c224fb68170d688075902cf8266ead1e1cdcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20R=C3=B6sti?= Date: Mon, 9 Sep 2024 09:27:13 -0600 Subject: [PATCH] [matmul] int8 support (#1708) Co-authored-by: Joseph Melber Co-authored-by: Joseph Melber --- aie_kernels/aie2/mm.cc | 15 ++++++++++ .../basic/matrix_multiplication/common.h | 30 +++++++++++++++++++ .../matrix_multiplication/makefile-common | 8 ++++- .../matrix_multiplication/single_core/aie2.py | 17 +++++++++-- .../single_core/run_makefile.lit | 2 ++ .../single_core/run_makefile_i8.lit | 11 +++++++ .../matrix_multiplication/whole_array/aie2.py | 15 ++++++++-- .../whole_array/run_makefile_4_col_i8.lit | 11 +++++++ 8 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 programming_examples/basic/matrix_multiplication/single_core/run_makefile_i8.lit create mode 100644 programming_examples/basic/matrix_multiplication/whole_array/run_makefile_4_col_i8.lit diff --git a/aie_kernels/aie2/mm.cc b/aie_kernels/aie2/mm.cc index 51345596ff..fcef00e60d 100644 --- a/aie_kernels/aie2/mm.cc +++ b/aie_kernels/aie2/mm.cc @@ -538,6 +538,20 @@ void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA, pA, pB, pC); } +template +void matmul_vectorized_4x8x8_i8_i8(const int8 *__restrict pA, + const int8 *__restrict pB, + int8 *__restrict pC) { + constexpr int r = 4; + constexpr int s = 8; + constexpr int t = 8; + static_assert(m % (2 * r) == 0 && m / (2 * r) > 0); + static_assert(k % (2 * s) == 0 && k / (2 * s) > 0); + static_assert(n % (2 * t) == 0 && n / (2 * t) > 0); + return matmul_vectorized(pA, pB, + pC); +} + extern "C" { // If you want to compile microkernels with different inner tile sizes, @@ -558,6 +572,7 @@ extern "C" { #endif #define combos(X) \ + X(int8, i8, int8, i8, 4, 8, 8) \ X(int16, i16, int16, i16, 4, 4, 4) \ X(int16, i16, int32, i32, 4, 4, 4) \ X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4) \ diff --git a/programming_examples/basic/matrix_multiplication/common.h b/programming_examples/basic/matrix_multiplication/common.h index cba6ff6363..25334d7d32 100644 --- a/programming_examples/basic/matrix_multiplication/common.h +++ b/programming_examples/basic/matrix_multiplication/common.h @@ -117,6 +117,11 @@ std::int16_t get_random() { return (std::int16_t)rand() % 0x10000; } +template <> +int8_t get_random() { + return (int8_t)rand() % 0x100; +} + template <> std::bfloat16_t get_random() { // Random numbers should NOT be uniformly between 0 and 1, because that @@ -195,6 +200,11 @@ float get_abs_tol() { return 0.5; } +template <> +float get_abs_tol() { + return 0; +} + template <> float get_rel_tol() { return 0.0; @@ -215,6 +225,11 @@ float get_rel_tol() { return 0.05; } +template <> +float get_rel_tol() { + return 0; +} + template void print_matrix(const std::vector matrix, int n_cols, int n_printable_rows = 10, int n_printable_cols = 10, @@ -275,6 +290,21 @@ void print_matrix(const std::vector matrix, int n_cols, #undef print_row } +// int8_t aka char will not print as a number but as a character; specialize +// print_matrix to cast to int16_t first so everything prints as numbers +template <> +void print_matrix(const std::vector matrix, int n_cols, + int n_printable_rows, int n_printable_cols, + std::ostream &ostream, const char col_sep[], + const char elide_sym[], int w) { + std::vector cast_matrix(matrix.size()); + for (int i = 0; i < matrix.size(); i++) { + cast_matrix[i] = (int16_t)matrix[i]; + } + print_matrix(cast_matrix, n_cols, n_printable_rows, n_printable_cols, ostream, + col_sep, elide_sym, w); +} + constexpr int max_printable_errors = 32; template diff --git a/programming_examples/basic/matrix_multiplication/makefile-common b/programming_examples/basic/matrix_multiplication/makefile-common index ba21462442..82c2ab8970 100644 --- a/programming_examples/basic/matrix_multiplication/makefile-common +++ b/programming_examples/basic/matrix_multiplication/makefile-common @@ -47,7 +47,6 @@ ifeq ($(dtype_out),bf16) dtype_out_cpp=std::bfloat16_t dtype_acc_cpp=float endif - ifeq ($(dtype_in),i16) dtype_in_cpp=int16_t endif @@ -63,6 +62,13 @@ ifeq ($(dtype_out),f32) dtype_out_cpp=float dtype_acc_cpp=float endif +ifeq ($(dtype_in),i8) + dtype_in_cpp=int8_t +endif +ifeq ($(dtype_out),i8) + dtype_out_cpp=int8_t + dtype_acc_cpp=int8_t +endif trace_size?=65536 diff --git a/programming_examples/basic/matrix_multiplication/single_core/aie2.py b/programming_examples/basic/matrix_multiplication/single_core/aie2.py index bb84ad7fd7..9e29b99073 100644 --- a/programming_examples/basic/matrix_multiplication/single_core/aie2.py +++ b/programming_examples/basic/matrix_multiplication/single_core/aie2.py @@ -28,10 +28,13 @@ def main(): argparser.add_argument("-k", type=int, default=64) argparser.add_argument("-n", type=int, default=32) argparser.add_argument( - "--dtype_in", type=str, choices=["bf16", "i16"], default="i16" + "--dtype_in", type=str, choices=["bf16", "i8", "i16"], default="i16" ) argparser.add_argument( - "--dtype_out", type=str, choices=["bf16", "i16", "f32", "i32"], default="i32" + "--dtype_out", + type=str, + choices=["bf16", "i8", "i16", "f32", "i32"], + default="i32", ) args = argparser.parse_args() my_matmul( @@ -53,6 +56,10 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str): r = 4 s = 8 t = 4 + elif dtype_in_str == "i8": + r = 4 + s = 8 + t = 8 elif dtype_in_str == "i16": r = 4 s = 4 @@ -69,11 +76,15 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str): dtype_in = None if dtype_in_str == "bf16": dtype_in = T.bf16 + elif dtype_in_str == "i8": + dtype_in = T.i8 elif dtype_in_str == "i16": dtype_in = T.i16 dtype_out = None if dtype_out_str == "bf16": dtype_out = T.bf16 + elif dtype_out_str == "i8": + dtype_out = T.i8 elif dtype_out_str == "i16": dtype_out = T.i16 elif dtype_out_str == "f32": @@ -271,7 +282,7 @@ def sequence(A, B, C): ) # only do 4 tile rows at a time before synchronizing, so we can reuse BDs - rows_per_block = 6 + rows_per_block = 4 for tile_row_block in range(ceildiv(M_div_m, rows_per_block)): # we only sync on half the BDs before reusing them, so the other half can concurrently keep running # that's what this loop is for diff --git a/programming_examples/basic/matrix_multiplication/single_core/run_makefile.lit b/programming_examples/basic/matrix_multiplication/single_core/run_makefile.lit index 6875524001..3018ef7883 100644 --- a/programming_examples/basic/matrix_multiplication/single_core/run_makefile.lit +++ b/programming_examples/basic/matrix_multiplication/single_core/run_makefile.lit @@ -3,6 +3,8 @@ // // REQUIRES: ryzen_ai, chess // +// RUN: mkdir -p %S/test_i8 +// RUN: cd %S/test_i8 // RUN: make -f %S/Makefile clean // RUN: make -f %S/Makefile // RUN: %run_on_npu make -f %S/Makefile run | FileCheck %s diff --git a/programming_examples/basic/matrix_multiplication/single_core/run_makefile_i8.lit b/programming_examples/basic/matrix_multiplication/single_core/run_makefile_i8.lit new file mode 100644 index 0000000000..541a201c4e --- /dev/null +++ b/programming_examples/basic/matrix_multiplication/single_core/run_makefile_i8.lit @@ -0,0 +1,11 @@ +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// REQUIRES: ryzen_ai, chess +// +// RUN: mkdir -p %S/test_4_col_i8 +// RUN: cd %S/test_4_col_i8 +// RUN: make -f %S/Makefile clean +// RUN: env dtype_in=i8 dtype_out=i8 m=64 k=128 n=64 M=512 K=512 N=512 make -f %S/Makefile +// RUN: %run_on_npu env dtype_in=i8 dtype_out=i8 m=64 k=128 n=64 M=512 K=512 N=512 make -f %S/Makefile run | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/basic/matrix_multiplication/whole_array/aie2.py b/programming_examples/basic/matrix_multiplication/whole_array/aie2.py index 0ef9c12d4f..1a7f82ee2a 100644 --- a/programming_examples/basic/matrix_multiplication/whole_array/aie2.py +++ b/programming_examples/basic/matrix_multiplication/whole_array/aie2.py @@ -28,10 +28,13 @@ def main(): argparser.add_argument("-n", type=int, default=32) argparser.add_argument("--n-aie-cols", type=int, choices=[1, 2, 4], default=4) argparser.add_argument( - "--dtype_in", type=str, choices=["bf16", "i16"], default="i16" + "--dtype_in", type=str, choices=["bf16", "i8", "i16"], default="i16" ) argparser.add_argument( - "--dtype_out", type=str, choices=["bf16", "i16", "f32", "i32"], default="i16" + "--dtype_out", + type=str, + choices=["bf16", "i8", "i16", "f32", "i32"], + default="i16", ) args = argparser.parse_args() with mlir_mod_ctx() as ctx: @@ -62,11 +65,15 @@ def my_matmul(M, K, N, m, k, n, n_aie_cols, dtype_in_str, dtype_out_str): dtype_in = None if dtype_in_str == "bf16": dtype_in = T.bf16 + elif dtype_in_str == "i8": + dtype_in = T.i8 elif dtype_in_str == "i16": dtype_in = T.i16 dtype_out = None if dtype_out_str == "bf16": dtype_out = T.bf16 + elif dtype_out_str == "i8": + dtype_out = T.i8 elif dtype_out_str == "i16": dtype_out = T.i16 elif dtype_out_str == "f32": @@ -78,6 +85,10 @@ def my_matmul(M, K, N, m, k, n, n_aie_cols, dtype_in_str, dtype_out_str): r = 4 s = 8 t = 4 + elif dtype_in_str == "i8": + r = 4 + s = 8 + t = 8 elif dtype_in_str == "i16": r = 4 s = 4 diff --git a/programming_examples/basic/matrix_multiplication/whole_array/run_makefile_4_col_i8.lit b/programming_examples/basic/matrix_multiplication/whole_array/run_makefile_4_col_i8.lit new file mode 100644 index 0000000000..c9007d1065 --- /dev/null +++ b/programming_examples/basic/matrix_multiplication/whole_array/run_makefile_4_col_i8.lit @@ -0,0 +1,11 @@ +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// REQUIRES: ryzen_ai, chess +// +// RUN: mkdir -p %S/test_4_col_i8 +// RUN: cd %S/test_4_col_i8 +// RUN: make -f %S/Makefile clean +// RUN: env n_aie_cols=4 dtype_in=i8 dtype_out=i8 M=512 K=512 N=512 m=64 k=128 n=64 make -f %S/Makefile +// RUN: %run_on_npu env n_aie_cols=4 dtype_in=i8 dtype_out=i8 M=512 K=512 N=512 m=64 k=128 n=64 make -f %S/Makefile run | FileCheck %s +// CHECK: PASS!