Skip to content

Commit

Permalink
gemm int8 quantization (#5706)
Browse files Browse the repository at this point in the history
* quantize gemm

* write gemm quantize scales

* update doc

* less openmp args

* x86 riscv fallback

* skip gemm vulkan int8

* fix noint8 test, fix arm bf16 test

* enable vfpv4 on neon build only

* fix gemm vulkan without C

* fp16 pack8 output

* enable elempack=8 only for asimdhp+

* tiled gemm int8 test

* opt arm64 tiles, fix asimdhp dispatch
  • Loading branch information
nihui authored Oct 12, 2024
1 parent 9b5f6a3 commit 1c7af00
Show file tree
Hide file tree
Showing 24 changed files with 36,265 additions and 173 deletions.
28 changes: 16 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,25 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
endif()

if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON)

unset(CMAKE_REQUIRED_FLAGS)
else()
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
if(NCNN_COMPILER_SUPPORT_ARM_NEON)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
endif()
unset(CMAKE_REQUIRED_FLAGS)
else()
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

unset(CMAKE_REQUIRED_FLAGS)
if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
endif()

unset(CMAKE_REQUIRED_FLAGS)
endif()
endif()

if(NCNN_COMPILER_SUPPORT_ARM_VFPV4 OR NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
Expand Down
96 changes: 48 additions & 48 deletions cmake/ncnn_add_layer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,25 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__")
endif()
if(NCNN_AVX512VNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
endif()
if(NCNN_AVX512BF16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__")
endif()
if(NCNN_AVX512FP16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__")
endif()
if(NCNN_AVXVNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_AVX2)
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_XOP)
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__XOP__")
endif()
if(NCNN_F16C)
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__")
endif()
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
Expand All @@ -175,25 +175,25 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__")
endif()
if(NCNN_AVX512VNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
endif()
if(NCNN_AVX512BF16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__")
endif()
if(NCNN_AVX512FP16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__")
endif()
if(NCNN_AVXVNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_AVX2)
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_XOP)
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "/arch:AVX -mxop /D__SSSE3__ /D__SSE4_1__ /D__XOP__")
endif()
if(NCNN_F16C)
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX -mf16c /D__SSSE3__ /D__SSE4_1__ /D__F16C__")
endif()
else()
Expand All @@ -206,25 +206,25 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "-mavx")
endif()
if(NCNN_AVX512VNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni")
endif()
if(NCNN_AVX512BF16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
ncnn_add_arch_opt_source(${class} avx512bf16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16")
endif()
if(NCNN_AVX512FP16)
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
ncnn_add_arch_opt_source(${class} avx512fp16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16")
endif()
if(NCNN_AVXVNNI)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni")
endif()
if(NCNN_AVX2)
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c")
endif()
if(NCNN_XOP)
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "-mavx -mxop")
endif()
if(NCNN_F16C)
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "-mavx -mf16c")
endif()
endif()
Expand Down Expand Up @@ -254,28 +254,28 @@ macro(ncnn_add_layer class)
if(NCNN_ARM82)
ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM82DOT)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD")
endif()
if(NCNN_ARM82FP16FML)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML")
endif()
if(NCNN_ARM84BF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM84I8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8")
endif()
# TODO add support for sve family
if(NCNN_ARM86SVE)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
endif()
if(NCNN_ARM86SVE2)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
endif()
if(NCNN_ARM86SVEBF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
endif()
if(NCNN_ARM86SVEI8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
endif()
if(NCNN_ARM86SVEF32MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
endif()
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
if(NCNN_VFPV4)
Expand All @@ -284,28 +284,28 @@ macro(ncnn_add_layer class)
if(NCNN_ARM82)
ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 -march=armv8.2-a+fp16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM82DOT)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 -march=armv8.2-a+fp16+dotprod /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD")
endif()
if(NCNN_ARM82FP16FML)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 -march=armv8.2-a+fp16+fp16fml /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML")
endif()
if(NCNN_ARM84BF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+bf16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC")
endif()
if(NCNN_ARM84I8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+i8mm /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8")
endif()
# TODO add support for sve family
if(NCNN_ARM86SVE)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
endif()
if(NCNN_ARM86SVE2)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
endif()
if(NCNN_ARM86SVEBF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
endif()
if(NCNN_ARM86SVEI8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
endif()
if(NCNN_ARM86SVEF32MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
endif()
else()
if(NCNN_VFPV4)
Expand All @@ -314,31 +314,31 @@ macro(ncnn_add_layer class)
if(NCNN_ARM82)
ncnn_add_arch_opt_source(${class} asimdhp "-march=armv8.2-a+fp16")
endif()
if(NCNN_ARM82DOT)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
ncnn_add_arch_opt_source(${class} asimddp "-march=armv8.2-a+fp16+dotprod")
endif()
if(NCNN_ARM82FP16FML)
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
ncnn_add_arch_opt_source(${class} asimdfhm "-march=armv8.2-a+fp16+fp16fml")
endif()
if(NCNN_ARM84BF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+fp16+dotprod+bf16")
endif()
if(NCNN_ARM84I8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+fp16+dotprod+i8mm")
endif()
if(NCNN_ARM86SVE)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+fp16+dotprod+sve")
endif()
if(NCNN_ARM86SVE2)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+fp16+dotprod+sve2")
endif()
if(NCNN_ARM86SVEBF16)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+fp16+dotprod+sve+bf16")
endif()
if(NCNN_ARM86SVEI8MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+fp16+dotprod+sve+i8mm")
endif()
if(NCNN_ARM86SVEF32MM)
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+fp16+dotprod+sve+f32mm")
endif()
endif()
Expand Down
7 changes: 5 additions & 2 deletions docs/developer-guide/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -942,15 +942,18 @@ y = (gemm(a, b) + c * beta) * alpha
| 12 | output_elempack | int | 0 | |
| 13 | output_elemtype | int | 0 | |
| 14 | output_transpose | int | 0 | |
| 18 | int8_scale_term | int | 0 | |
| 20 | constant_TILE_M | int | 0 | |
| 21 | constant_TILE_N | int | 0 | |
| 22 | constant_TILE_K | int | 0 | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| A_data | float | [M, K] or [K, M] |
| B_data | float | [N, K] or [K, N] |
| A_data | float/fp16/int8 | [M, K] or [K, M] |
| B_data | float/fp16/int8 | [N, K] or [K, N] |
| C_data | float | [1], [M] or [N] or [1, M] or [N, 1] or [N, M] |
| A_data_int8_scales | float | [M] |
| B_data_int8_scales | float | [1] |

# GridSample
```
Expand Down
Loading

0 comments on commit 1c7af00

Please sign in to comment.