Skip to content

Commit

Permalink
[matmul] int8 support (#1708)
Browse files Browse the repository at this point in the history
Co-authored-by: Joseph Melber <Joseph.melber@amd.com>
Co-authored-by: Joseph Melber <jgmelber@gmail.com>
  • Loading branch information
3 people authored Sep 9, 2024
1 parent 31b8618 commit fe0c224
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 6 deletions.
15 changes: 15 additions & 0 deletions aie_kernels/aie2/mm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,20 @@ void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA,
pA, pB, pC);
}

template <unsigned m, unsigned k, unsigned n>
void matmul_vectorized_4x8x8_i8_i8(const int8 *__restrict pA,
const int8 *__restrict pB,
int8 *__restrict pC) {
constexpr int r = 4;
constexpr int s = 8;
constexpr int t = 8;
static_assert(m % (2 * r) == 0 && m / (2 * r) > 0);
static_assert(k % (2 * s) == 0 && k / (2 * s) > 0);
static_assert(n % (2 * t) == 0 && n / (2 * t) > 0);
return matmul_vectorized<int8, int8, m / r, k / s, n / t, r, s, t>(pA, pB,
pC);
}

extern "C" {

// If you want to compile microkernels with different inner tile sizes,
Expand All @@ -558,6 +572,7 @@ extern "C" {
#endif

#define combos(X) \
X(int8, i8, int8, i8, 4, 8, 8) \
X(int16, i16, int16, i16, 4, 4, 4) \
X(int16, i16, int32, i32, 4, 4, 4) \
X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4) \
Expand Down
30 changes: 30 additions & 0 deletions programming_examples/basic/matrix_multiplication/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ std::int16_t get_random<std::int16_t>() {
return (std::int16_t)rand() % 0x10000;
}

template <>
int8_t get_random<int8_t>() {
return (int8_t)rand() % 0x100;
}

template <>
std::bfloat16_t get_random<std::bfloat16_t>() {
// Random numbers should NOT be uniformly between 0 and 1, because that
Expand Down Expand Up @@ -195,6 +200,11 @@ float get_abs_tol<float>() {
return 0.5;
}

template <>
float get_abs_tol<int8_t>() {
return 0;
}

template <>
float get_rel_tol<std::int16_t>() {
return 0.0;
Expand All @@ -215,6 +225,11 @@ float get_rel_tol<float>() {
return 0.05;
}

template <>
float get_rel_tol<int8_t>() {
return 0;
}

template <typename T>
void print_matrix(const std::vector<T> matrix, int n_cols,
int n_printable_rows = 10, int n_printable_cols = 10,
Expand Down Expand Up @@ -275,6 +290,21 @@ void print_matrix(const std::vector<T> matrix, int n_cols,
#undef print_row
}

// int8_t aka char will not print as a number but as a character; specialize
// print_matrix<int8_t> to cast to int16_t first so everything prints as numbers
template <>
void print_matrix(const std::vector<int8_t> matrix, int n_cols,
int n_printable_rows, int n_printable_cols,
std::ostream &ostream, const char col_sep[],
const char elide_sym[], int w) {
std::vector<int16_t> cast_matrix(matrix.size());
for (int i = 0; i < matrix.size(); i++) {
cast_matrix[i] = (int16_t)matrix[i];
}
print_matrix(cast_matrix, n_cols, n_printable_rows, n_printable_cols, ostream,
col_sep, elide_sym, w);
}

constexpr int max_printable_errors = 32;

template <typename Tout>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ifeq ($(dtype_out),bf16)
dtype_out_cpp=std::bfloat16_t
dtype_acc_cpp=float
endif

ifeq ($(dtype_in),i16)
dtype_in_cpp=int16_t
endif
Expand All @@ -63,6 +62,13 @@ ifeq ($(dtype_out),f32)
dtype_out_cpp=float
dtype_acc_cpp=float
endif
ifeq ($(dtype_in),i8)
dtype_in_cpp=int8_t
endif
ifeq ($(dtype_out),i8)
dtype_out_cpp=int8_t
dtype_acc_cpp=int8_t
endif

trace_size?=65536

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ def main():
argparser.add_argument("-k", type=int, default=64)
argparser.add_argument("-n", type=int, default=32)
argparser.add_argument(
"--dtype_in", type=str, choices=["bf16", "i16"], default="i16"
"--dtype_in", type=str, choices=["bf16", "i8", "i16"], default="i16"
)
argparser.add_argument(
"--dtype_out", type=str, choices=["bf16", "i16", "f32", "i32"], default="i32"
"--dtype_out",
type=str,
choices=["bf16", "i8", "i16", "f32", "i32"],
default="i32",
)
args = argparser.parse_args()
my_matmul(
Expand All @@ -53,6 +56,10 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str):
r = 4
s = 8
t = 4
elif dtype_in_str == "i8":
r = 4
s = 8
t = 8
elif dtype_in_str == "i16":
r = 4
s = 4
Expand All @@ -69,11 +76,15 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str):
dtype_in = None
if dtype_in_str == "bf16":
dtype_in = T.bf16
elif dtype_in_str == "i8":
dtype_in = T.i8
elif dtype_in_str == "i16":
dtype_in = T.i16
dtype_out = None
if dtype_out_str == "bf16":
dtype_out = T.bf16
elif dtype_out_str == "i8":
dtype_out = T.i8
elif dtype_out_str == "i16":
dtype_out = T.i16
elif dtype_out_str == "f32":
Expand Down Expand Up @@ -271,7 +282,7 @@ def sequence(A, B, C):
)

# only do 4 tile rows at a time before synchronizing, so we can reuse BDs
rows_per_block = 6
rows_per_block = 4
for tile_row_block in range(ceildiv(M_div_m, rows_per_block)):
# we only sync on half the BDs before reusing them, so the other half can concurrently keep running
# that's what this loop is for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//
// REQUIRES: ryzen_ai, chess
//
// RUN: mkdir -p %S/test_i8
// RUN: cd %S/test_i8
// RUN: make -f %S/Makefile clean
// RUN: make -f %S/Makefile
// RUN: %run_on_npu make -f %S/Makefile run | FileCheck %s
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// (c) Copyright 2024 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// REQUIRES: ryzen_ai, chess
//
// RUN: mkdir -p %S/test_4_col_i8
// RUN: cd %S/test_4_col_i8
// RUN: make -f %S/Makefile clean
// RUN: env dtype_in=i8 dtype_out=i8 m=64 k=128 n=64 M=512 K=512 N=512 make -f %S/Makefile
// RUN: %run_on_npu env dtype_in=i8 dtype_out=i8 m=64 k=128 n=64 M=512 K=512 N=512 make -f %S/Makefile run | FileCheck %s
// CHECK: PASS!
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ def main():
argparser.add_argument("-n", type=int, default=32)
argparser.add_argument("--n-aie-cols", type=int, choices=[1, 2, 4], default=4)
argparser.add_argument(
"--dtype_in", type=str, choices=["bf16", "i16"], default="i16"
"--dtype_in", type=str, choices=["bf16", "i8", "i16"], default="i16"
)
argparser.add_argument(
"--dtype_out", type=str, choices=["bf16", "i16", "f32", "i32"], default="i16"
"--dtype_out",
type=str,
choices=["bf16", "i8", "i16", "f32", "i32"],
default="i16",
)
args = argparser.parse_args()
with mlir_mod_ctx() as ctx:
Expand Down Expand Up @@ -62,11 +65,15 @@ def my_matmul(M, K, N, m, k, n, n_aie_cols, dtype_in_str, dtype_out_str):
dtype_in = None
if dtype_in_str == "bf16":
dtype_in = T.bf16
elif dtype_in_str == "i8":
dtype_in = T.i8
elif dtype_in_str == "i16":
dtype_in = T.i16
dtype_out = None
if dtype_out_str == "bf16":
dtype_out = T.bf16
elif dtype_out_str == "i8":
dtype_out = T.i8
elif dtype_out_str == "i16":
dtype_out = T.i16
elif dtype_out_str == "f32":
Expand All @@ -78,6 +85,10 @@ def my_matmul(M, K, N, m, k, n, n_aie_cols, dtype_in_str, dtype_out_str):
r = 4
s = 8
t = 4
elif dtype_in_str == "i8":
r = 4
s = 8
t = 8
elif dtype_in_str == "i16":
r = 4
s = 4
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// (c) Copyright 2024 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// REQUIRES: ryzen_ai, chess
//
// RUN: mkdir -p %S/test_4_col_i8
// RUN: cd %S/test_4_col_i8
// RUN: make -f %S/Makefile clean
// RUN: env n_aie_cols=4 dtype_in=i8 dtype_out=i8 M=512 K=512 N=512 m=64 k=128 n=64 make -f %S/Makefile
// RUN: %run_on_npu env n_aie_cols=4 dtype_in=i8 dtype_out=i8 M=512 K=512 N=512 m=64 k=128 n=64 make -f %S/Makefile run | FileCheck %s
// CHECK: PASS!

0 comments on commit fe0c224

Please sign in to comment.