From ea10dbcf7b1b72887227bd0d370b3a5941d47e96 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Thu, 3 Oct 2024 21:20:19 -0700 Subject: [PATCH] Remove dependency on fp16 This dependency is causing issues on some platforms. We have a workaround: https://github.com/Maratyszcza/fp16 that some users have implemented (e.g. https://github.com/microsoft/onnxruntime/pull/22294/files), but it doesn't make sense to hack this workaround into everything that depends on XNNPACK. This PR pulls the small part of the fp16 library we actually need in XNNPACK, and removes the dependency. Fixes #6989 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/google/XNNPACK/pull/7128 from imaginationtech:img_patch30_f32_f16_vcvt 2f40edd2509c5dce91c93df3b12d0bd6da59f051 PiperOrigin-RevId: 682156616 --- BUILD.bazel | 17 +- CMakeLists.txt | 84 +- WORKSPACE | 13 - bench/BUILD.bazel | 2 - bench/gemm-benchmark.cc | 1 - build_params.bzl | 2 - cmake/DownloadFP16.cmake | 28 - cmake/gen/avxvnniint8_microkernels.cmake | 3 +- cmake/gen/rvvfp16arith_microkernels.cmake | 6 +- cmake/gen/scalar_microkernels.cmake | 2 + gen/avxvnniint8_microkernels.bzl | 1 + gen/rvvfp16arith_microkernels.bzl | 4 + gen/scalar_microkernels.bzl | 2 + scripts/generate-f32-f16-vcvt.sh | 6 + scripts/generate-x8-packw.sh | 7 + scripts/genxnn | 73 + src/configs/unary-elementwise-config.c | 2 +- .../gen/f16-qs8-vcvt-scalar-fmagic-u1.c | 1 - .../gen/f16-qs8-vcvt-scalar-fmagic-u2.c | 1 - .../gen/f16-qs8-vcvt-scalar-fmagic-u3.c | 1 - .../gen/f16-qs8-vcvt-scalar-fmagic-u4.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u1.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u2.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u3.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u4.c | 1 - src/f32-f16-vcvt/f32-f16-vcvt.h | 7 + .../gen/f32-f16-vcvt-rvvfp16arith-u1v.c | 40 + .../gen/f32-f16-vcvt-rvvfp16arith-u2v.c | 40 + .../gen/f32-f16-vcvt-rvvfp16arith-u4v.c | 40 + .../gen/f32-f16-vcvt-rvvfp16arith-u8v.c | 40 + src/f32-f16-vcvt/rvvfp16arith.c.in | 38 + src/f32-qs8-vcvt/scalar-fmagic.c.in | 2 - src/f32-qs8-vcvt/scalar-imagic.c.in | 2 - src/f32-rminmax/scalar.c.in | 4 +- src/operators/average-pooling-nhwc.c | 1 - src/operators/convolution-nchw.c | 1 - src/operators/convolution-nhwc.c | 1 - src/operators/deconvolution-nhwc.c | 1 - src/operators/dynamic-fully-connected-nc.c | 1 - src/operators/global-average-pooling-ncw.c | 2 +- src/operators/global-average-pooling-nwc.c | 2 +- src/operators/max-pooling-nhwc.c | 1 - .../scaled-dot-product-attention-nhtc.c | 1 - src/operators/softmax-nc.c | 1 - src/operators/unary-elementwise-nc.c | 1 - src/packing.cc | 1 - .../gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c | 5 +- src/qs8-gemm/scalar.c.in | 6 +- .../gen/qs8-packw-x16c8-gemm-goi-scalar.c | 2319 +++++++++++++++++ .../gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c | 444 ++++ .../gen/qs8-packw-x8c8-gemm-goi-scalar.c | 1175 +++++++++ src/qs8-packw/qs8-packw.h | 8 + src/subgraph.c | 2 +- src/subgraph/static-constant-pad.c | 2 +- src/x8-packw/kr-avxvnniint8.c.in | 401 +++ src/xnnpack/fp16.h | 179 ++ src/xnnpack/math.h | 2 +- src/xnnpack/quantization.h | 1 - src/xnnpack/simd/f16-scalar.h | 2 +- test/BUILD.bazel | 26 +- test/abs.cc | 1 - test/average-pooling-2d.cc | 1 - test/avgpool-microkernel-tester.h | 1 - test/bankers-rounding.cc | 1 - test/batch-matrix-multiply.cc | 1 - test/ceiling.cc | 1 - test/clamp.cc | 1 - test/concatenate2.cc | 1 - test/concatenate3.cc | 1 - test/concatenate4.cc | 1 - test/concatenate5.cc | 1 - test/conv-hwc2chw-microkernel-tester.h | 1 - test/convert-operator-tester.h | 1 - test/convert.cc | 1 - test/convolution-2d.cc | 1 - test/convolution-operator-tester.h | 1 - test/copy.cc | 1 - test/deconvolution-2d.cc | 1 - test/deconvolution-operator-tester.h | 1 - test/depth-to-space-2d.cc | 1 - test/depthwise-convolution-2d.cc | 1 - test/dwconv-microkernel-tester.cc | 1 - test/dwconv2d-microkernel-tester.h | 1 - test/elu.cc | 1 - test/even-split2.cc | 1 - test/even-split3.cc | 1 - test/even-split4.cc | 1 - test/f16-simd-scalar.cc | 2 +- test/f16-simd.cc.in | 2 +- test/floor.cc | 1 - test/fully-connected-operator-tester.h | 1 - test/fully-connected.cc | 1 - test/gavgpool-cw-microkernel-tester.h | 2 +- test/gavgpool-microkernel-tester.h | 1 - test/gemm-microkernel-tester.cc | 2 - test/global-average-pooling-1d.cc | 1 - test/global-average-pooling-2d.cc | 1 - test/global-sum-pooling-1d.cc | 1 - test/global-sum-pooling-2d.cc | 1 - test/hardswish.cc | 1 - test/ibilinear-microkernel-tester.h | 1 - test/leaky-relu.cc | 1 - test/max-pooling-2d.cc | 1 - test/maxpool-microkernel-tester.h | 1 - test/negate.cc | 1 - test/packing.cc | 1 - test/prelu-microkernel-tester.h | 1 - test/prelu.cc | 1 - .../raddstoreexpminusmax-microkernel-tester.h | 1 - test/rdsum-microkernel-tester.h | 1 - test/reciprocal-square-root.cc | 1 - test/rsum-microkernel-tester.h | 1 - test/scaled-dot-product-attention.cc | 1 - test/sigmoid.cc | 1 - test/softmax.cc | 1 - test/space-to-depth-2d.cc | 1 - test/spmm-microkernel-tester.h | 1 - test/square-root.cc | 1 - test/square.cc | 1 - test/static-constant-pad.cc | 1 - test/static-expand-dims.cc | 1 - test/static-mean.cc | 1 - test/static-reshape.cc | 1 - test/static-resize-bilinear-2d.cc | 1 - test/static-slice.cc | 1 - test/static-transpose.cc | 1 - test/subgraph-fp16.cc | 1 - test/tanh-operator-tester.h | 1 - test/tanh.cc | 1 - test/unary-operator-tester.cc | 1 - test/vbinary-microkernel-tester.cc | 1 - test/vcmul-microkernel-tester.h | 1 - test/vcvt-microkernel-tester.cc | 1 - test/vmulcaddc-microkernel-tester.h | 1 - test/vunary-microkernel-tester.cc | 1 - test/vunary-microkernel-tester.h | 1 - 142 files changed, 4884 insertions(+), 286 deletions(-) delete mode 100644 cmake/DownloadFP16.cmake create mode 100755 scripts/genxnn create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c create mode 100644 src/f32-f16-vcvt/rvvfp16arith.c.in create mode 100644 src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c create mode 100644 src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c create mode 100644 src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c create mode 100644 src/x8-packw/kr-avxvnniint8.c.in create mode 100644 src/xnnpack/fp16.h diff --git a/BUILD.bazel b/BUILD.bazel index baa9cec726b..e54d106914b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -327,7 +327,6 @@ xnnpack_cc_library( ":microparams", ":mutex", ":xnnpack_h", - "@FP16", ], ) @@ -428,7 +427,6 @@ xnnpack_cc_library( ":packing", ":prod_microkernels", ":xnnpack_h", - "@FP16", ] + select({ ":cpuinfo_enabled": ["@cpuinfo"], "//conditions:default": [], @@ -449,13 +447,19 @@ xnnpack_cc_library( ], ) +xnnpack_cc_library( + name = "fp16", + hdrs = ["src/xnnpack/fp16.h"], + compatible_with = [], +) + xnnpack_cc_library( name = "math", hdrs = ["src/xnnpack/math.h"], deps = [ ":common", ":config_hdrs", - "@FP16", + ":fp16", ], ) @@ -598,7 +602,6 @@ xnnpack_cc_library( ":common", ":math", ":microparams", - "@FP16", ], ) @@ -777,7 +780,6 @@ xnnpack_cc_library( ":microparams", ":operator_h", ":xnnpack_h", - "@FP16", "@FXdiv", ], ) @@ -796,7 +798,6 @@ xnnpack_cxx_library( ":params", ":unaligned", ":xnnpack_h", - "@FP16", ] + xnnpack_if_kleidiai_enabled([ "@KleidiAI//kai/ukernels/matmul", "@KleidiAI//kai/ukernels/matmul:rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", @@ -931,6 +932,7 @@ xnnpack_cc_library( ":allocator", ":cache", ":common", + ":fp16", ":indirection", ":logging", ":math", @@ -946,7 +948,6 @@ xnnpack_cc_library( ":params", ":quantization", ":xnnpack_h", - "@FP16", "@pthreadpool", ] + select({ "//conditions:default": [], @@ -969,6 +970,7 @@ xnnpack_cc_library( ":cache", ":common", ":config_hdrs", + ":fp16", ":hardware_config", ":internal", ":logging", @@ -983,7 +985,6 @@ xnnpack_cc_library( ":params", ":requantization", ":xnnpack_h", - "@FP16", "@pthreadpool", ] + xnnpack_slinky_deps(), ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 948c69d8243..1fdaf28f02a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -304,16 +304,6 @@ IF(NOT XNNPACK_USE_SYSTEM_LIBS) SET(CPUINFO_SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo-source" CACHE STRING "cpuinfo source directory") ENDIF() - IF(NOT DEFINED FP16_SOURCE_DIR) - MESSAGE(STATUS "Downloading FP16 to ${CMAKE_BINARY_DIR}/FP16-source (define FP16_SOURCE_DIR to avoid it)") - CONFIGURE_FILE(cmake/DownloadFP16.cmake "${CMAKE_BINARY_DIR}/FP16-download/CMakeLists.txt") - EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") - EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") - SET(FP16_SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" CACHE STRING "FP16 source directory") - ENDIF() - IF(NOT DEFINED FXDIV_SOURCE_DIR) MESSAGE(STATUS "Downloading FXdiv to ${CMAKE_BINARY_DIR}/FXdiv-source (define FXDIV_SOURCE_DIR to avoid it)") CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CMAKE_BINARY_DIR}/FXdiv-download/CMakeLists.txt") @@ -1116,42 +1106,7 @@ IF(XNNPACK_BUILD_LIBRARY) TARGET_LINK_LIBRARIES(XNNPACK PRIVATE fxdiv) ENDIF() -# ---[ Configure FP16 -IF(NOT TARGET fp16) - IF(NOT XNNPACK_USE_SYSTEM_LIBS) - SET(FP16_BUILD_TESTS OFF CACHE BOOL "") - SET(FP16_BUILD_BENCHMARKS OFF CACHE BOOL "") - ADD_SUBDIRECTORY( - "${FP16_SOURCE_DIR}" - "${CMAKE_BINARY_DIR}/FP16") - ELSE() - FIND_FILE(FP16_HDR fp16.h PATH_SUFFIXES include PATHS "${FP16_SOURCE_DIR}") - IF(NOT FP16_HDR) - MESSAGE(FATAL_ERROR "Cannot find fp16") - ENDIF() - ADD_LIBRARY(fp16 STATIC "${FP16_HDR}") - TARGET_INCLUDE_DIRECTORIES(fp16 INTERFACE "${FP16_SOURCE_DIR}/include") - SET_PROPERTY(TARGET fp16 PROPERTY LINKER_LANGUAGE C) - ENDIF() -ENDIF() -IF(XNNPACK_BUILD_ALL_MICROKERNELS) - TARGET_LINK_LIBRARIES(microkernels-all PRIVATE fp16) -ENDIF() -TARGET_LINK_LIBRARIES(microkernels-prod PRIVATE fp16) -TARGET_LINK_LIBRARIES(microparams-init PRIVATE fp16) -TARGET_LINK_LIBRARIES(packing PRIVATE fp16) -TARGET_LINK_LIBRARIES(indirection PRIVATE fp16) -TARGET_LINK_LIBRARIES(memory PRIVATE fp16) -TARGET_LINK_LIBRARIES(normalization PRIVATE fp16) -TARGET_LINK_LIBRARIES(microkernel-utils PRIVATE fp16) -TARGET_LINK_LIBRARIES(cache PRIVATE fp16) -TARGET_LINK_LIBRARIES(operator-utils PRIVATE fp16) IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(subgraph PRIVATE fp16) - TARGET_LINK_LIBRARIES(operators PRIVATE fp16) - TARGET_LINK_LIBRARIES(operator-run PRIVATE fp16) - - TARGET_LINK_LIBRARIES(XNNPACK PRIVATE fp16) INSTALL(TARGETS XNNPACK LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -1212,7 +1167,7 @@ IF(XNNPACK_BUILD_TESTS) ADD_LIBRARY(gemm-microkernel-tester STATIC test/gemm-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(gemm-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE xnnpack-base fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE xnnpack-base pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE packing) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE kleidiai) @@ -1221,34 +1176,34 @@ IF(XNNPACK_BUILD_TESTS) ADD_LIBRARY(unary-operator-tester STATIC test/unary-operator-tester.cc) TARGET_INCLUDE_DIRECTORIES(unary-operator-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(unary-operator-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(unary-operator-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(dwconv-microkernel-tester STATIC test/dwconv-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(dwconv-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PUBLIC next-prime) ADD_LIBRARY(vbinary-microkernel-tester STATIC test/vbinary-microkernel-tester.cc) SET_TARGET_PROPERTIES(vbinary-microkernel-tester PROPERTIES CXX_EXTENSIONS YES) TARGET_INCLUDE_DIRECTORIES(vbinary-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vbinary-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vbinary-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(vcvt-microkernel-tester STATIC test/vcvt-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(vcvt-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vcvt-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vcvt-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(vunary-microkernel-tester STATIC test/vunary-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(vunary-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vunary-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vunary-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(vunary-microkernel-tester PUBLIC next-prime) ADD_LIBRARY(convolution-test-helpers OBJECT test/convolution-test-helpers.cc) TARGET_INCLUDE_DIRECTORIES(convolution-test-helpers PRIVATE include src) - TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE xnnpack-base fp16) + TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE xnnpack-base) ADD_LIBRARY(packq-microkernel-tester STATIC test/packq-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(packq-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE kleidiai) ENDIF() @@ -1268,7 +1223,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gtest GTest::gtest_main hardware-config @@ -1295,7 +1249,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1333,7 +1286,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gtest GTest::gtest_main unary-operator-tester @@ -1344,7 +1296,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(binary-elementwise-nd-test test/binary-elementwise-nd.cc) TARGET_INCLUDE_DIRECTORIES(binary-elementwise-nd-test PRIVATE src test) TARGET_LINK_LIBRARIES(binary-elementwise-nd-test PRIVATE - fp16 GTest::gtest GTest::gtest_main XNNPACK) @@ -1362,7 +1313,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1425,7 +1375,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1444,7 +1393,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE convolution-test-helpers - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1539,7 +1487,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1580,7 +1527,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE dwconv-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1633,7 +1579,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE gemm-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1654,7 +1599,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE packq-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1727,7 +1671,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vbinary-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1760,7 +1703,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vcvt-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1814,7 +1756,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vunary-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1849,7 +1790,7 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(operator-utils-test test/operator-utils.cc) TARGET_INCLUDE_DIRECTORIES(operator-utils-test PRIVATE include src) - TARGET_LINK_LIBRARIES(operator-utils-test PRIVATE XNNPACK GTest::gtest GTest::gtest_main pthreadpool fp16) + TARGET_LINK_LIBRARIES(operator-utils-test PRIVATE XNNPACK GTest::gtest GTest::gtest_main pthreadpool) ENDIF() @@ -1903,14 +1844,14 @@ IF(XNNPACK_BUILD_BENCHMARKS) # Helper libraries ADD_LIBRARY(packq-benchmark STATIC bench/packq-benchmark.cc) TARGET_INCLUDE_DIRECTORIES(packq-benchmark PRIVATE . include src bench) - TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils fp16) + TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE kleidiai) ENDIF() ADD_LIBRARY(gemm-benchmark STATIC bench/gemm-benchmark.cc) TARGET_INCLUDE_DIRECTORIES(gemm-benchmark PRIVATE . include src bench) - TARGET_LINK_LIBRARIES(gemm-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils fp16) + TARGET_LINK_LIBRARIES(gemm-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(gemm-benchmark PUBLIC kleidiai) ENDIF() @@ -1936,7 +1877,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(bench-models PRIVATE bench-utils benchmark::benchmark - fp16 models XNNPACK) @@ -1971,7 +1911,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(${BENCH}-bench PRIVATE bench-utils benchmark::benchmark - fp16 XNNPACK ) ENDFOREACH() @@ -2068,7 +2007,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(${BENCH}-bench PRIVATE bench-utils benchmark::benchmark - fp16 gemm-benchmark hardware-config im2col diff --git a/WORKSPACE b/WORKSPACE index 54d3841939e..b140e022e05 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -63,19 +63,6 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadGoogleBenchmark.cmake) -# LINT.IfChange -# FP16 library, used for half-precision conversions -http_archive( - name = "FP16", - build_file = "@//third_party:FP16.BUILD", - sha256 = "e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70", - strip_prefix = "FP16-0a92994d729ff76a58f692d3028ca1b64b145d91", - urls = [ - "https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip", - ], -) -# LINT.ThenChange(cmake/DownloadFP16.cmake) - # LINT.IfChange # FXdiv library, used for repeated integer division by the same factor http_archive( diff --git a/bench/BUILD.bazel b/bench/BUILD.bazel index a8758ea9539..fbe9444eadf 100644 --- a/bench/BUILD.bazel +++ b/bench/BUILD.bazel @@ -27,7 +27,6 @@ load( MICROKERNEL_BENCHMARK_DEPS = [ ":bench_utils", - "@FP16", "//:aligned_allocator", "//:all_microkernels", "//:common", @@ -48,7 +47,6 @@ OPERATOR_BENCHMARK_DEPS = [ "//:cache", "//:common", "//:math", - "@FP16", ] xnnpack_cxx_library( diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc index 2a9b6576529..116acbfdbd2 100644 --- a/bench/gemm-benchmark.cc +++ b/bench/gemm-benchmark.cc @@ -25,7 +25,6 @@ #include #include -#include #include "bench/utils.h" #include diff --git a/build_params.bzl b/build_params.bzl index 6aec1584971..92b4025364d 100644 --- a/build_params.bzl +++ b/build_params.bzl @@ -274,7 +274,6 @@ XNNPACK_PARAMS_FOR_ARCH = { ], extra_deps = [ "//:config_hdrs", - "@FP16", "@FXdiv", ], ), @@ -523,7 +522,6 @@ XNNPACK_PARAMS_FOR_ARCH = { "-mno-sse4.2", ], extra_deps = [ - "@FP16", ], msvc_x86_32_copts = ["/arch:SSE2"], msvc_x86_64_copts = ["/arch:SSE2"], diff --git a/cmake/DownloadFP16.cmake b/cmake/DownloadFP16.cmake deleted file mode 100644 index a3321d9d40b..00000000000 --- a/cmake/DownloadFP16.cmake +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# Copyright 2019 Google LLC -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) - -PROJECT(fp16-download NONE) - -# Set file timestamps to the time of extraction. -IF(POLICY CMP0135) - CMAKE_POLICY(SET CMP0135 NEW) -ENDIF() - -INCLUDE(ExternalProject) -ExternalProject_Add(fp16 - URL https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip - URL_HASH SHA256=e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70 - SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" - BINARY_DIR "${CMAKE_BINARY_DIR}/FP16" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - TEST_COMMAND "" -) diff --git a/cmake/gen/avxvnniint8_microkernels.cmake b/cmake/gen/avxvnniint8_microkernels.cmake index 60e0412af5f..ee2ac244a81 100644 --- a/cmake/gen/avxvnniint8_microkernels.cmake +++ b/cmake/gen/avxvnniint8_microkernels.cmake @@ -15,6 +15,7 @@ SET(PROD_AVXVNNIINT8_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8c8-minmax-fp32-avxvnniint8-prfm.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x8c8-minmax-fp32-avxvnniint8-prfm.c) -SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS) +SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c) SET(ALL_AVXVNNIINT8_MICROKERNEL_SRCS ${PROD_AVXVNNIINT8_MICROKERNEL_SRCS} + ${NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS}) diff --git a/cmake/gen/rvvfp16arith_microkernels.cmake b/cmake/gen/rvvfp16arith_microkernels.cmake index 0f32a4a46ae..03df4b84f53 100644 --- a/cmake/gen/rvvfp16arith_microkernels.cmake +++ b/cmake/gen/rvvfp16arith_microkernels.cmake @@ -15,6 +15,10 @@ SET(NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u1v.c src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u2v.c src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u4v.c - src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c) + src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c) SET(ALL_RVVFP16ARITH_MICROKERNEL_SRCS ${PROD_RVVFP16ARITH_MICROKERNEL_SRCS} + ${NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS}) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index d30bceb4257..8879af0b608 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -635,7 +635,9 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c diff --git a/gen/avxvnniint8_microkernels.bzl b/gen/avxvnniint8_microkernels.bzl index 97f1657ee4a..a1b149b017c 100644 --- a/gen/avxvnniint8_microkernels.bzl +++ b/gen/avxvnniint8_microkernels.bzl @@ -13,6 +13,7 @@ PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ ] NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c", ] ALL_AVXVNNIINT8_MICROKERNEL_SRCS = PROD_AVXVNNIINT8_MICROKERNEL_SRCS + NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS diff --git a/gen/rvvfp16arith_microkernels.bzl b/gen/rvvfp16arith_microkernels.bzl index 139aa872254..3d0d9afc559 100644 --- a/gen/rvvfp16arith_microkernels.bzl +++ b/gen/rvvfp16arith_microkernels.bzl @@ -13,6 +13,10 @@ NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS = [ "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u2v.c", "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u4v.c", "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c", ] ALL_RVVFP16ARITH_MICROKERNEL_SRCS = PROD_RVVFP16ARITH_MICROKERNEL_SRCS + NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index 7ef2fc7f766..21e3b9f9b0c 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -632,7 +632,9 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c", "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c", "src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c", diff --git a/scripts/generate-f32-f16-vcvt.sh b/scripts/generate-f32-f16-vcvt.sh index e34f227224f..d7dae496c53 100755 --- a/scripts/generate-f32-f16-vcvt.sh +++ b/scripts/generate-f32-f16-vcvt.sh @@ -13,6 +13,12 @@ tools/xngen src/f32-f16-vcvt/neon.c.in -D BATCH_TILE=32 -o src/f32-f16-vcvt/gen/ tools/xngen src/f32-f16-vcvt/neonfp16.c.in -D BATCH_TILE=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-neonfp16-u8.c & tools/xngen src/f32-f16-vcvt/neonfp16.c.in -D BATCH_TILE=16 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-neonfp16-u16.c & +################################ RISC-V Vector ################################ +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=1 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=2 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=4 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c & + ################################# x86 128-bit ################################# tools/xngen src/f32-f16-vcvt/sse.c.in -D SSE=2 -D AVX=0 -D BATCH_TILE=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-sse2-u8.c & tools/xngen src/f32-f16-vcvt/sse.c.in -D SSE=2 -D AVX=0 -D BATCH_TILE=16 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-sse2-u16.c & diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 281bba59237..912f9400796 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -23,4 +23,11 @@ tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=4 -D TYPE=int8_t -o src/q tools/xngen src/x8-packw/kr-scalar.c.in -D NR=32 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c & tools/xngen src/x8-packw/kr-scalar.c.in -D NR=64 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c & + +### C8 packing +tools/xngen src/x8-packw/kr-scalar.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c & +tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c & + +tools/xngen src/x8-packw/kr-avxvnniint8.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c & + wait diff --git a/scripts/genxnn b/scripts/genxnn new file mode 100755 index 00000000000..12a5888c8c9 --- /dev/null +++ b/scripts/genxnn @@ -0,0 +1,73 @@ +#!/bin/bash + +function my_generate1() { + files_changed=$(scripts/check_files_changed.py $1) + if [[ $files_changed ]] + then + if ${VERBOSE}; then + echo $1 + fi + if ${DEBUG}; then + echo $files_changed + fi + touch $files_changed + bash -c $1 & + fi +} + +function my_generatef() { + if ${VERBOSE}; then + echo $1 + fi + bash -c $1 & +} + +export -f my_generate1 +export -f my_generatef + +export FORCE='false' +export VERBOSE='false' +export DEBUG='false' + +while getopts ':vfd' 'OPTKEY'; do + case ${OPTKEY} in + 'v') + export VERBOSE='true' + ;; + 'd') + export DEBUG='true' + ;; + 'f') + export FORCE='true' + ;; + '?') + echo "INVALID OPTION -- ${OPTARG}" >&2 + exit 1 + ;; + ':') + echo "MISSING ARGUMENT for option -- ${OPTARG}" >&2 + exit 1 + ;; + *) + echo "UNIMPLEMENTED OPTION -- ${OPTKEY}" >&2 + exit 1 + ;; + esac +done + +# [optional] Remove all options processed by getopts. +shift $(( OPTIND - 1 )) +[[ "${1}" == "--" ]] && shift + +#pushd $(g4 g4d) +if ${FORCE}; then + find scripts/ -name 'generate-*.sh' -exec bash -c 'my_generatef {}' \; +else + find scripts/ -name 'generate-*.sh' -exec bash -c 'my_generate1 {}' \; +fi +wait +if ${VERBOSE}; then + echo ./tools/update-microkernels.py +fi +./tools/update-microkernels.py +#popd diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index b464b1af0ba..26a7d8250d6 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -1999,7 +1999,7 @@ static void init_qs8_lrelu_config(void) { } static void init_qs8_to_f16_cvt_config(void) { - #if XNN_ARCH_ARM || XNN_ARCH_ARM64 + #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); if (hardware_config->use_arm_neon_fp16_arith) { diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c index 20469b41d53..fb9e4dc17b6 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u1( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c index 017b78406bf..e3561a3ccb3 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u2( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c index c88ab86b633..f088ddf5684 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u3( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c index c9bd1b86128..c2cb7a3594e 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u4( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c index 7d9bda6d202..c995a3df11a 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u1( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c index 82074e6fcaf..00620ce80b3 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u2( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c index 3ccb4aaff00..c4f849e687e 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u3( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c index 979725bdce9..86e42c4ca57 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u4( size_t batch, diff --git a/src/f32-f16-vcvt/f32-f16-vcvt.h b/src/f32-f16-vcvt/f32-f16-vcvt.h index dfd6fb266f5..fbcaa08526d 100644 --- a/src/f32-f16-vcvt/f32-f16-vcvt.h +++ b/src/f32-f16-vcvt/f32-f16-vcvt.h @@ -58,6 +58,13 @@ XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__wasmrelaxedsimd_u24, 24 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__wasmrelaxedsimd_u32, 32, false, float, xnn_float16, void, NULL) #endif // XNN_ARCH_WASMRELAXEDSIMD +#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_FP16_VECTOR +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u1v, 1, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u2v, 2, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u4v, 4, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u8v, 8, true, float, xnn_float16, void, NULL) +#endif + XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u1, 1, false, float, xnn_float16, void, NULL) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u2, 2, false, float, xnn_float16, void, NULL) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u3, 3, false, float, xnn_float16, void, NULL) diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c new file mode 100644 index 00000000000..e84b6d04699 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u1v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m1(batch); batch -= n; + + vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n; + + vfloat16mf2_t y_f16v = __riscv_vfncvt_f_f_w_f16mf2(x_f32v, n); + + __riscv_vse16_v_f16mf2(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c new file mode 100644 index 00000000000..69dd22f841e --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u2v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m2(batch); batch -= n; + + vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n; + + vfloat16m1_t y_f16v = __riscv_vfncvt_f_f_w_f16m1(x_f32v, n); + + __riscv_vse16_v_f16m1(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c new file mode 100644 index 00000000000..116bb4d35bb --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u4v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m4(batch); batch -= n; + + vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n; + + vfloat16m2_t y_f16v = __riscv_vfncvt_f_f_w_f16m2(x_f32v, n); + + __riscv_vse16_v_f16m2(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c new file mode 100644 index 00000000000..dc0cadfb239 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u8v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m8(batch); batch -= n; + + vfloat32m8_t x_f32v = __riscv_vle32_v_f32m8(input, n); input += n; + + vfloat16m4_t y_f16v = __riscv_vfncvt_f_f_w_f16m4(x_f32v, n); + + __riscv_vse16_v_f16m4(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/rvvfp16arith.c.in b/src/f32-f16-vcvt/rvvfp16arith.c.in new file mode 100644 index 00000000000..da136047268 --- /dev/null +++ b/src/f32-f16-vcvt/rvvfp16arith.c.in @@ -0,0 +1,38 @@ +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert LMUL in [1, 2, 4, 8] +$LMUL_16 = {1: "f2", 2: "1", 4: "2", 8: "4"}[LMUL] +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u${LMUL}v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m${LMUL}(batch); batch -= n; + + vfloat32m${LMUL}_t x_f32v = __riscv_vle32_v_f32m${LMUL}(input, n); input += n; + + vfloat16m${LMUL_16}_t y_f16v = __riscv_vfncvt_f_f_w_f16m${LMUL_16}(x_f32v, n); + + __riscv_vse16_v_f16m${LMUL_16}(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-qs8-vcvt/scalar-fmagic.c.in b/src/f32-qs8-vcvt/scalar-fmagic.c.in index 7c3bfdfa0ba..94218c1745a 100644 --- a/src/f32-qs8-vcvt/scalar-fmagic.c.in +++ b/src/f32-qs8-vcvt/scalar-fmagic.c.in @@ -12,8 +12,6 @@ $assert IDATATYPE == "F16" and ODATATYPE == "QS8" or IDATATYPE == "F32" #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -$if IDATATYPE == "F16": - #include $INPUT_T = {"F16": "xnn_float16", "F32": "float"}[IDATATYPE] $XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[ODATATYPE] diff --git a/src/f32-qs8-vcvt/scalar-imagic.c.in b/src/f32-qs8-vcvt/scalar-imagic.c.in index ae1e474b3ec..2d83fb24ed0 100644 --- a/src/f32-qs8-vcvt/scalar-imagic.c.in +++ b/src/f32-qs8-vcvt/scalar-imagic.c.in @@ -12,8 +12,6 @@ $assert IDATATYPE == "F16" and ODATATYPE == "QS8" or IDATATYPE == "F32" #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -$if IDATATYPE == "F16": - #include $INPUT_T = {"F16": "xnn_float16", "F32": "float"}[IDATATYPE] $XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[ODATATYPE] diff --git a/src/f32-rminmax/scalar.c.in b/src/f32-rminmax/scalar.c.in index e26a6584b6c..1fbf58ff951 100644 --- a/src/f32-rminmax/scalar.c.in +++ b/src/f32-rminmax/scalar.c.in @@ -12,7 +12,7 @@ $if not WASM: #include "xnnpack/math.h" #include "xnnpack/reduce.h" $if DATATYPE == "F16": - #include + #include "xnnpack/fp16.h" $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS $MAX = "__builtin_wasm_max_f32" if WASM else "math_max_f32" @@ -30,7 +30,7 @@ void xnn_${DATATYPE.lower()}_r${OP.lower()}_ukernel__${ISA}_u${BATCH_TILE}${ACC_ $elif DATATYPE == "F16": const void* input, void* output, - const union xnn_${DATATYPE.lower()}_default_params params[restrict XNN_MIN_ELEMENTS(1)]) + const struct xnn_${DATATYPE.lower()}_default_params params[restrict XNN_MIN_ELEMENTS(1)]) { assert(batch != 0); assert(batch % sizeof(${ITYPE}) == 0); diff --git a/src/operators/average-pooling-nhwc.c b/src/operators/average-pooling-nhwc.c index 754da020202..7a217c37d85 100644 --- a/src/operators/average-pooling-nhwc.c +++ b/src/operators/average-pooling-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/convolution-nchw.c b/src/operators/convolution-nchw.c index fc4fe8c9d5b..6be0bf5a2b1 100644 --- a/src/operators/convolution-nchw.c +++ b/src/operators/convolution-nchw.c @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index 0125b6cb3cb..e0eb6265ccb 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/deconvolution-nhwc.c b/src/operators/deconvolution-nhwc.c index 60a44bc158b..ddfe9650723 100644 --- a/src/operators/deconvolution-nhwc.c +++ b/src/operators/deconvolution-nhwc.c @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c index ca5d69fdc6f..c3eb10fab03 100644 --- a/src/operators/dynamic-fully-connected-nc.c +++ b/src/operators/dynamic-fully-connected-nc.c @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/global-average-pooling-ncw.c b/src/operators/global-average-pooling-ncw.c index ef5ce1a7749..6e7c3700200 100644 --- a/src/operators/global-average-pooling-ncw.c +++ b/src/operators/global-average-pooling-ncw.c @@ -10,13 +10,13 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microparams-init.h" diff --git a/src/operators/global-average-pooling-nwc.c b/src/operators/global-average-pooling-nwc.c index fa50945dc5b..e04b7943258 100644 --- a/src/operators/global-average-pooling-nwc.c +++ b/src/operators/global-average-pooling-nwc.c @@ -14,13 +14,13 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microparams.h" diff --git a/src/operators/max-pooling-nhwc.c b/src/operators/max-pooling-nhwc.c index 43ad86552cd..e8dd7413d37 100644 --- a/src/operators/max-pooling-nhwc.c +++ b/src/operators/max-pooling-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/scaled-dot-product-attention-nhtc.c b/src/operators/scaled-dot-product-attention-nhtc.c index 1de22ee1d87..a278c7c044f 100644 --- a/src/operators/scaled-dot-product-attention-nhtc.c +++ b/src/operators/scaled-dot-product-attention-nhtc.c @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/softmax-nc.c b/src/operators/softmax-nc.c index 0d4db7f1443..75e57edc9c9 100644 --- a/src/operators/softmax-nc.c +++ b/src/operators/softmax-nc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c index 0cd6baa9955..2270934695a 100644 --- a/src/operators/unary-elementwise-nc.c +++ b/src/operators/unary-elementwise-nc.c @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/packing.cc b/src/packing.cc index 9c742ad6e8c..c37b06b5902 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -28,7 +28,6 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #endif // XNN_ENABLE_KLEIDIAI -#include extern "C" { diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c index 80836d13ca5..5be5406c57f 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c @@ -12,7 +12,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar( @@ -98,11 +97,11 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar( w = (const float*) w + 2; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c index b84b84c0f1a..215aa89d397 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar( @@ -127,13 +126,13 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout0x2 = math_max_f32(vout0x2, voutput_min); vout0x3 = math_max_f32(vout0x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); vout0x2 = math_min_f32(vout0x2, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c index 73bf59de5a0..d42717834c3 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( @@ -187,7 +186,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( w = (const float*) w + 8; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout0x2 = math_max_f32(vout0x2, voutput_min); @@ -197,7 +196,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( vout0x6 = math_max_f32(vout0x6, voutput_min); vout0x7 = math_max_f32(vout0x7, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); vout0x2 = math_min_f32(vout0x2, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c index 660ea399c00..60f32047731 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c @@ -12,7 +12,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar( @@ -127,13 +126,13 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar( w = (const float*) w + 2; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout1x1 = math_max_f32(vout1x1, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c index 580148c13b1..a75b95cc64f 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( @@ -174,7 +173,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); @@ -184,7 +183,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( vout0x3 = math_max_f32(vout0x3, voutput_min); vout1x3 = math_max_f32(vout1x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c index c5edd90fbb5..523d285b7e2 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( @@ -270,7 +269,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( w = (const float*) w + 8; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); @@ -288,7 +287,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( vout0x7 = math_max_f32(vout0x7, voutput_min); vout1x7 = math_max_f32(vout1x7, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c index 33b41fa79e5..e8855d10d4a 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( @@ -268,7 +267,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout2x0 = math_max_f32(vout2x0, voutput_min); @@ -286,7 +285,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( vout2x3 = math_max_f32(vout2x3, voutput_min); vout3x3 = math_max_f32(vout3x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout2x0 = math_min_f32(vout2x0, voutput_max); diff --git a/src/qs8-gemm/scalar.c.in b/src/qs8-gemm/scalar.c.in index fbe2c29b299..97285dd32cb 100644 --- a/src/qs8-gemm/scalar.c.in +++ b/src/qs8-gemm/scalar.c.in @@ -14,8 +14,6 @@ $if VARIANT == "LRINTF": #include "xnnpack/math.h" $if NR % 4 != 0: #include "xnnpack/unaligned.h" -$if DATATYPE in ["QB4_F16"]: - #include $# $INDENT = 0 @@ -279,7 +277,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}_ w = (const float*) w + ${NR * 2}; $if DATATYPE in ["QB4_F16"]: - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); $else: const float voutput_min = params->scalar.min; $for N in range(NR): @@ -287,7 +285,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}_ vout${M}x${N} = ${MAX_F32}(vout${M}x${N}, voutput_min); $if DATATYPE in ["QB4_F16"]: - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); $else: const float voutput_max = params->scalar.max; $for N in range(NR): diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c new file mode 100644 index 00000000000..0b7b42c16b3 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c @@ -0,0 +1,2319 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = params ? (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + ((int32_t*) out)[0] = b[0]; + ((int32_t*) out)[1] = b[1]; + ((int32_t*) out)[2] = b[2]; + ((int32_t*) out)[3] = b[3]; + ((int32_t*) out)[4] = b[4]; + ((int32_t*) out)[5] = b[5]; + ((int32_t*) out)[6] = b[6]; + ((int32_t*) out)[7] = b[7]; + ((int32_t*) out)[8] = b[8]; + ((int32_t*) out)[9] = b[9]; + ((int32_t*) out)[10] = b[10]; + ((int32_t*) out)[11] = b[11]; + ((int32_t*) out)[12] = b[12]; + ((int32_t*) out)[13] = b[13]; + ((int32_t*) out)[14] = b[14]; + ((int32_t*) out)[15] = b[15]; + b += 16; + } else { + ((int32_t*) out)[0] = 0; + ((int32_t*) out)[1] = 0; + ((int32_t*) out)[2] = 0; + ((int32_t*) out)[3] = 0; + ((int32_t*) out)[4] = 0; + ((int32_t*) out)[5] = 0; + ((int32_t*) out)[6] = 0; + ((int32_t*) out)[7] = 0; + ((int32_t*) out)[8] = 0; + ((int32_t*) out)[9] = 0; + ((int32_t*) out)[10] = 0; + ((int32_t*) out)[11] = 0; + ((int32_t*) out)[12] = 0; + ((int32_t*) out)[13] = 0; + ((int32_t*) out)[14] = 0; + ((int32_t*) out)[15] = 0; + } + out += 16 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + uint32_t ksum8 = 0; + uint32_t ksum9 = 0; + uint32_t ksum10 = 0; + uint32_t ksum11 = 0; + uint32_t ksum12 = 0; + uint32_t ksum13 = 0; + uint32_t ksum14 = 0; + uint32_t ksum15 = 0; + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + const int8_t v8x0 = w8[0]; + const int8_t v8x1 = w8[1]; + const int8_t v8x2 = w8[2]; + const int8_t v8x3 = w8[3]; + const int8_t v8x4 = w8[4]; + const int8_t v8x5 = w8[5]; + const int8_t v8x6 = w8[6]; + const int8_t v8x7 = w8[7]; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + out[64] = v8x0; + out[65] = v8x1; + out[66] = v8x2; + out[67] = v8x3; + out[68] = v8x4; + out[69] = v8x5; + out[70] = v8x6; + out[71] = v8x7; + w8 += 8; + const int8_t v9x0 = w9[0]; + const int8_t v9x1 = w9[1]; + const int8_t v9x2 = w9[2]; + const int8_t v9x3 = w9[3]; + const int8_t v9x4 = w9[4]; + const int8_t v9x5 = w9[5]; + const int8_t v9x6 = w9[6]; + const int8_t v9x7 = w9[7]; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + out[72] = v9x0; + out[73] = v9x1; + out[74] = v9x2; + out[75] = v9x3; + out[76] = v9x4; + out[77] = v9x5; + out[78] = v9x6; + out[79] = v9x7; + w9 += 8; + const int8_t v10x0 = w10[0]; + const int8_t v10x1 = w10[1]; + const int8_t v10x2 = w10[2]; + const int8_t v10x3 = w10[3]; + const int8_t v10x4 = w10[4]; + const int8_t v10x5 = w10[5]; + const int8_t v10x6 = w10[6]; + const int8_t v10x7 = w10[7]; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + out[80] = v10x0; + out[81] = v10x1; + out[82] = v10x2; + out[83] = v10x3; + out[84] = v10x4; + out[85] = v10x5; + out[86] = v10x6; + out[87] = v10x7; + w10 += 8; + const int8_t v11x0 = w11[0]; + const int8_t v11x1 = w11[1]; + const int8_t v11x2 = w11[2]; + const int8_t v11x3 = w11[3]; + const int8_t v11x4 = w11[4]; + const int8_t v11x5 = w11[5]; + const int8_t v11x6 = w11[6]; + const int8_t v11x7 = w11[7]; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + out[88] = v11x0; + out[89] = v11x1; + out[90] = v11x2; + out[91] = v11x3; + out[92] = v11x4; + out[93] = v11x5; + out[94] = v11x6; + out[95] = v11x7; + w11 += 8; + const int8_t v12x0 = w12[0]; + const int8_t v12x1 = w12[1]; + const int8_t v12x2 = w12[2]; + const int8_t v12x3 = w12[3]; + const int8_t v12x4 = w12[4]; + const int8_t v12x5 = w12[5]; + const int8_t v12x6 = w12[6]; + const int8_t v12x7 = w12[7]; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + out[96] = v12x0; + out[97] = v12x1; + out[98] = v12x2; + out[99] = v12x3; + out[100] = v12x4; + out[101] = v12x5; + out[102] = v12x6; + out[103] = v12x7; + w12 += 8; + const int8_t v13x0 = w13[0]; + const int8_t v13x1 = w13[1]; + const int8_t v13x2 = w13[2]; + const int8_t v13x3 = w13[3]; + const int8_t v13x4 = w13[4]; + const int8_t v13x5 = w13[5]; + const int8_t v13x6 = w13[6]; + const int8_t v13x7 = w13[7]; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + out[104] = v13x0; + out[105] = v13x1; + out[106] = v13x2; + out[107] = v13x3; + out[108] = v13x4; + out[109] = v13x5; + out[110] = v13x6; + out[111] = v13x7; + w13 += 8; + const int8_t v14x0 = w14[0]; + const int8_t v14x1 = w14[1]; + const int8_t v14x2 = w14[2]; + const int8_t v14x3 = w14[3]; + const int8_t v14x4 = w14[4]; + const int8_t v14x5 = w14[5]; + const int8_t v14x6 = w14[6]; + const int8_t v14x7 = w14[7]; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + out[112] = v14x0; + out[113] = v14x1; + out[114] = v14x2; + out[115] = v14x3; + out[116] = v14x4; + out[117] = v14x5; + out[118] = v14x6; + out[119] = v14x7; + w14 += 8; + const int8_t v15x0 = w15[0]; + const int8_t v15x1 = w15[1]; + const int8_t v15x2 = w15[2]; + const int8_t v15x3 = w15[3]; + const int8_t v15x4 = w15[4]; + const int8_t v15x5 = w15[5]; + const int8_t v15x6 = w15[6]; + const int8_t v15x7 = w15[7]; + ksum15 += (uint32_t) v15x0; + ksum15 += (uint32_t) v15x1; + ksum15 += (uint32_t) v15x2; + ksum15 += (uint32_t) v15x3; + ksum15 += (uint32_t) v15x4; + ksum15 += (uint32_t) v15x5; + ksum15 += (uint32_t) v15x6; + ksum15 += (uint32_t) v15x7; + out[120] = v15x0; + out[121] = v15x1; + out[122] = v15x2; + out[123] = v15x3; + out[124] = v15x4; + out[125] = v15x5; + out[126] = v15x6; + out[127] = v15x7; + w15 += 8; + out += 128; + } + + // KC remainder 1..KR-1 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + const int8_t v8x0 = 0 < k ? w8[0] : izp; + const int8_t v8x1 = 1 < k ? w8[1] : izp; + const int8_t v8x2 = 2 < k ? w8[2] : izp; + const int8_t v8x3 = 3 < k ? w8[3] : izp; + const int8_t v8x4 = 4 < k ? w8[4] : izp; + const int8_t v8x5 = 5 < k ? w8[5] : izp; + const int8_t v8x6 = 6 < k ? w8[6] : izp; + const int8_t v8x7 = 7 < k ? w8[7] : izp; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + if (0 < k) { + out[64] = v8x0; + } + if (1 < k) { + out[65] = v8x1; + } + if (2 < k) { + out[66] = v8x2; + } + if (3 < k) { + out[67] = v8x3; + } + if (4 < k) { + out[68] = v8x4; + } + if (5 < k) { + out[69] = v8x5; + } + if (6 < k) { + out[70] = v8x6; + } + if (7 < k) { + out[71] = v8x7; + } + w8 += 8; + const int8_t v9x0 = 0 < k ? w9[0] : izp; + const int8_t v9x1 = 1 < k ? w9[1] : izp; + const int8_t v9x2 = 2 < k ? w9[2] : izp; + const int8_t v9x3 = 3 < k ? w9[3] : izp; + const int8_t v9x4 = 4 < k ? w9[4] : izp; + const int8_t v9x5 = 5 < k ? w9[5] : izp; + const int8_t v9x6 = 6 < k ? w9[6] : izp; + const int8_t v9x7 = 7 < k ? w9[7] : izp; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + if (0 < k) { + out[72] = v9x0; + } + if (1 < k) { + out[73] = v9x1; + } + if (2 < k) { + out[74] = v9x2; + } + if (3 < k) { + out[75] = v9x3; + } + if (4 < k) { + out[76] = v9x4; + } + if (5 < k) { + out[77] = v9x5; + } + if (6 < k) { + out[78] = v9x6; + } + if (7 < k) { + out[79] = v9x7; + } + w9 += 8; + const int8_t v10x0 = 0 < k ? w10[0] : izp; + const int8_t v10x1 = 1 < k ? w10[1] : izp; + const int8_t v10x2 = 2 < k ? w10[2] : izp; + const int8_t v10x3 = 3 < k ? w10[3] : izp; + const int8_t v10x4 = 4 < k ? w10[4] : izp; + const int8_t v10x5 = 5 < k ? w10[5] : izp; + const int8_t v10x6 = 6 < k ? w10[6] : izp; + const int8_t v10x7 = 7 < k ? w10[7] : izp; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + if (0 < k) { + out[80] = v10x0; + } + if (1 < k) { + out[81] = v10x1; + } + if (2 < k) { + out[82] = v10x2; + } + if (3 < k) { + out[83] = v10x3; + } + if (4 < k) { + out[84] = v10x4; + } + if (5 < k) { + out[85] = v10x5; + } + if (6 < k) { + out[86] = v10x6; + } + if (7 < k) { + out[87] = v10x7; + } + w10 += 8; + const int8_t v11x0 = 0 < k ? w11[0] : izp; + const int8_t v11x1 = 1 < k ? w11[1] : izp; + const int8_t v11x2 = 2 < k ? w11[2] : izp; + const int8_t v11x3 = 3 < k ? w11[3] : izp; + const int8_t v11x4 = 4 < k ? w11[4] : izp; + const int8_t v11x5 = 5 < k ? w11[5] : izp; + const int8_t v11x6 = 6 < k ? w11[6] : izp; + const int8_t v11x7 = 7 < k ? w11[7] : izp; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + if (0 < k) { + out[88] = v11x0; + } + if (1 < k) { + out[89] = v11x1; + } + if (2 < k) { + out[90] = v11x2; + } + if (3 < k) { + out[91] = v11x3; + } + if (4 < k) { + out[92] = v11x4; + } + if (5 < k) { + out[93] = v11x5; + } + if (6 < k) { + out[94] = v11x6; + } + if (7 < k) { + out[95] = v11x7; + } + w11 += 8; + const int8_t v12x0 = 0 < k ? w12[0] : izp; + const int8_t v12x1 = 1 < k ? w12[1] : izp; + const int8_t v12x2 = 2 < k ? w12[2] : izp; + const int8_t v12x3 = 3 < k ? w12[3] : izp; + const int8_t v12x4 = 4 < k ? w12[4] : izp; + const int8_t v12x5 = 5 < k ? w12[5] : izp; + const int8_t v12x6 = 6 < k ? w12[6] : izp; + const int8_t v12x7 = 7 < k ? w12[7] : izp; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + if (0 < k) { + out[96] = v12x0; + } + if (1 < k) { + out[97] = v12x1; + } + if (2 < k) { + out[98] = v12x2; + } + if (3 < k) { + out[99] = v12x3; + } + if (4 < k) { + out[100] = v12x4; + } + if (5 < k) { + out[101] = v12x5; + } + if (6 < k) { + out[102] = v12x6; + } + if (7 < k) { + out[103] = v12x7; + } + w12 += 8; + const int8_t v13x0 = 0 < k ? w13[0] : izp; + const int8_t v13x1 = 1 < k ? w13[1] : izp; + const int8_t v13x2 = 2 < k ? w13[2] : izp; + const int8_t v13x3 = 3 < k ? w13[3] : izp; + const int8_t v13x4 = 4 < k ? w13[4] : izp; + const int8_t v13x5 = 5 < k ? w13[5] : izp; + const int8_t v13x6 = 6 < k ? w13[6] : izp; + const int8_t v13x7 = 7 < k ? w13[7] : izp; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + if (0 < k) { + out[104] = v13x0; + } + if (1 < k) { + out[105] = v13x1; + } + if (2 < k) { + out[106] = v13x2; + } + if (3 < k) { + out[107] = v13x3; + } + if (4 < k) { + out[108] = v13x4; + } + if (5 < k) { + out[109] = v13x5; + } + if (6 < k) { + out[110] = v13x6; + } + if (7 < k) { + out[111] = v13x7; + } + w13 += 8; + const int8_t v14x0 = 0 < k ? w14[0] : izp; + const int8_t v14x1 = 1 < k ? w14[1] : izp; + const int8_t v14x2 = 2 < k ? w14[2] : izp; + const int8_t v14x3 = 3 < k ? w14[3] : izp; + const int8_t v14x4 = 4 < k ? w14[4] : izp; + const int8_t v14x5 = 5 < k ? w14[5] : izp; + const int8_t v14x6 = 6 < k ? w14[6] : izp; + const int8_t v14x7 = 7 < k ? w14[7] : izp; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + if (0 < k) { + out[112] = v14x0; + } + if (1 < k) { + out[113] = v14x1; + } + if (2 < k) { + out[114] = v14x2; + } + if (3 < k) { + out[115] = v14x3; + } + if (4 < k) { + out[116] = v14x4; + } + if (5 < k) { + out[117] = v14x5; + } + if (6 < k) { + out[118] = v14x6; + } + if (7 < k) { + out[119] = v14x7; + } + w14 += 8; + const int8_t v15x0 = 0 < k ? w15[0] : izp; + const int8_t v15x1 = 1 < k ? w15[1] : izp; + const int8_t v15x2 = 2 < k ? w15[2] : izp; + const int8_t v15x3 = 3 < k ? w15[3] : izp; + const int8_t v15x4 = 4 < k ? w15[4] : izp; + const int8_t v15x5 = 5 < k ? w15[5] : izp; + const int8_t v15x6 = 6 < k ? w15[6] : izp; + const int8_t v15x7 = 7 < k ? w15[7] : izp; + ksum15 += (uint32_t) v15x0; + ksum15 += (uint32_t) v15x1; + ksum15 += (uint32_t) v15x2; + ksum15 += (uint32_t) v15x3; + ksum15 += (uint32_t) v15x4; + ksum15 += (uint32_t) v15x5; + ksum15 += (uint32_t) v15x6; + ksum15 += (uint32_t) v15x7; + if (0 < k) { + out[120] = v15x0; + } + if (1 < k) { + out[121] = v15x1; + } + if (2 < k) { + out[122] = v15x2; + } + if (3 < k) { + out[123] = v15x3; + } + if (4 < k) { + out[124] = v15x4; + } + if (5 < k) { + out[125] = v15x5; + } + if (6 < k) { + out[126] = v15x6; + } + if (7 < k) { + out[127] = v15x7; + } + w15 += 8; + out += 128; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + packed_b[8] -= ksum8 * izp; + packed_b[9] -= ksum9 * izp; + packed_b[10] -= ksum10 * izp; + packed_b[11] -= ksum11 * izp; + packed_b[12] -= ksum12 * izp; + packed_b[13] -= ksum13 * izp; + packed_b[14] -= ksum14 * izp; + packed_b[15] -= ksum15 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(int32_t); + + // NR remainder has less than 16 rows so last row is not loaded + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + uint32_t ksum8 = 0; + uint32_t ksum9 = 0; + uint32_t ksum10 = 0; + uint32_t ksum11 = 0; + uint32_t ksum12 = 0; + uint32_t ksum13 = 0; + uint32_t ksum14 = 0; + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + const int8_t v8x0 = w8[0]; + const int8_t v8x1 = w8[1]; + const int8_t v8x2 = w8[2]; + const int8_t v8x3 = w8[3]; + const int8_t v8x4 = w8[4]; + const int8_t v8x5 = w8[5]; + const int8_t v8x6 = w8[6]; + const int8_t v8x7 = w8[7]; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + out[64] = v8x0; + out[65] = v8x1; + out[66] = v8x2; + out[67] = v8x3; + out[68] = v8x4; + out[69] = v8x5; + out[70] = v8x6; + out[71] = v8x7; + w8 += 8; + const int8_t v9x0 = w9[0]; + const int8_t v9x1 = w9[1]; + const int8_t v9x2 = w9[2]; + const int8_t v9x3 = w9[3]; + const int8_t v9x4 = w9[4]; + const int8_t v9x5 = w9[5]; + const int8_t v9x6 = w9[6]; + const int8_t v9x7 = w9[7]; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + out[72] = v9x0; + out[73] = v9x1; + out[74] = v9x2; + out[75] = v9x3; + out[76] = v9x4; + out[77] = v9x5; + out[78] = v9x6; + out[79] = v9x7; + w9 += 8; + const int8_t v10x0 = w10[0]; + const int8_t v10x1 = w10[1]; + const int8_t v10x2 = w10[2]; + const int8_t v10x3 = w10[3]; + const int8_t v10x4 = w10[4]; + const int8_t v10x5 = w10[5]; + const int8_t v10x6 = w10[6]; + const int8_t v10x7 = w10[7]; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + out[80] = v10x0; + out[81] = v10x1; + out[82] = v10x2; + out[83] = v10x3; + out[84] = v10x4; + out[85] = v10x5; + out[86] = v10x6; + out[87] = v10x7; + w10 += 8; + const int8_t v11x0 = w11[0]; + const int8_t v11x1 = w11[1]; + const int8_t v11x2 = w11[2]; + const int8_t v11x3 = w11[3]; + const int8_t v11x4 = w11[4]; + const int8_t v11x5 = w11[5]; + const int8_t v11x6 = w11[6]; + const int8_t v11x7 = w11[7]; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + out[88] = v11x0; + out[89] = v11x1; + out[90] = v11x2; + out[91] = v11x3; + out[92] = v11x4; + out[93] = v11x5; + out[94] = v11x6; + out[95] = v11x7; + w11 += 8; + const int8_t v12x0 = w12[0]; + const int8_t v12x1 = w12[1]; + const int8_t v12x2 = w12[2]; + const int8_t v12x3 = w12[3]; + const int8_t v12x4 = w12[4]; + const int8_t v12x5 = w12[5]; + const int8_t v12x6 = w12[6]; + const int8_t v12x7 = w12[7]; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + out[96] = v12x0; + out[97] = v12x1; + out[98] = v12x2; + out[99] = v12x3; + out[100] = v12x4; + out[101] = v12x5; + out[102] = v12x6; + out[103] = v12x7; + w12 += 8; + const int8_t v13x0 = w13[0]; + const int8_t v13x1 = w13[1]; + const int8_t v13x2 = w13[2]; + const int8_t v13x3 = w13[3]; + const int8_t v13x4 = w13[4]; + const int8_t v13x5 = w13[5]; + const int8_t v13x6 = w13[6]; + const int8_t v13x7 = w13[7]; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + out[104] = v13x0; + out[105] = v13x1; + out[106] = v13x2; + out[107] = v13x3; + out[108] = v13x4; + out[109] = v13x5; + out[110] = v13x6; + out[111] = v13x7; + w13 += 8; + const int8_t v14x0 = w14[0]; + const int8_t v14x1 = w14[1]; + const int8_t v14x2 = w14[2]; + const int8_t v14x3 = w14[3]; + const int8_t v14x4 = w14[4]; + const int8_t v14x5 = w14[5]; + const int8_t v14x6 = w14[6]; + const int8_t v14x7 = w14[7]; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + out[112] = v14x0; + out[113] = v14x1; + out[114] = v14x2; + out[115] = v14x3; + out[116] = v14x4; + out[117] = v14x5; + out[118] = v14x6; + out[119] = v14x7; + w14 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + const int8_t v8x0 = 0 < k ? w8[0] : izp; + const int8_t v8x1 = 1 < k ? w8[1] : izp; + const int8_t v8x2 = 2 < k ? w8[2] : izp; + const int8_t v8x3 = 3 < k ? w8[3] : izp; + const int8_t v8x4 = 4 < k ? w8[4] : izp; + const int8_t v8x5 = 5 < k ? w8[5] : izp; + const int8_t v8x6 = 6 < k ? w8[6] : izp; + const int8_t v8x7 = 7 < k ? w8[7] : izp; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + if (0 < k) { + out[64] = v8x0; + } + if (1 < k) { + out[65] = v8x1; + } + if (2 < k) { + out[66] = v8x2; + } + if (3 < k) { + out[67] = v8x3; + } + if (4 < k) { + out[68] = v8x4; + } + if (5 < k) { + out[69] = v8x5; + } + if (6 < k) { + out[70] = v8x6; + } + if (7 < k) { + out[71] = v8x7; + } + w8 += 8; + const int8_t v9x0 = 0 < k ? w9[0] : izp; + const int8_t v9x1 = 1 < k ? w9[1] : izp; + const int8_t v9x2 = 2 < k ? w9[2] : izp; + const int8_t v9x3 = 3 < k ? w9[3] : izp; + const int8_t v9x4 = 4 < k ? w9[4] : izp; + const int8_t v9x5 = 5 < k ? w9[5] : izp; + const int8_t v9x6 = 6 < k ? w9[6] : izp; + const int8_t v9x7 = 7 < k ? w9[7] : izp; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + if (0 < k) { + out[72] = v9x0; + } + if (1 < k) { + out[73] = v9x1; + } + if (2 < k) { + out[74] = v9x2; + } + if (3 < k) { + out[75] = v9x3; + } + if (4 < k) { + out[76] = v9x4; + } + if (5 < k) { + out[77] = v9x5; + } + if (6 < k) { + out[78] = v9x6; + } + if (7 < k) { + out[79] = v9x7; + } + w9 += 8; + const int8_t v10x0 = 0 < k ? w10[0] : izp; + const int8_t v10x1 = 1 < k ? w10[1] : izp; + const int8_t v10x2 = 2 < k ? w10[2] : izp; + const int8_t v10x3 = 3 < k ? w10[3] : izp; + const int8_t v10x4 = 4 < k ? w10[4] : izp; + const int8_t v10x5 = 5 < k ? w10[5] : izp; + const int8_t v10x6 = 6 < k ? w10[6] : izp; + const int8_t v10x7 = 7 < k ? w10[7] : izp; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + if (0 < k) { + out[80] = v10x0; + } + if (1 < k) { + out[81] = v10x1; + } + if (2 < k) { + out[82] = v10x2; + } + if (3 < k) { + out[83] = v10x3; + } + if (4 < k) { + out[84] = v10x4; + } + if (5 < k) { + out[85] = v10x5; + } + if (6 < k) { + out[86] = v10x6; + } + if (7 < k) { + out[87] = v10x7; + } + w10 += 8; + const int8_t v11x0 = 0 < k ? w11[0] : izp; + const int8_t v11x1 = 1 < k ? w11[1] : izp; + const int8_t v11x2 = 2 < k ? w11[2] : izp; + const int8_t v11x3 = 3 < k ? w11[3] : izp; + const int8_t v11x4 = 4 < k ? w11[4] : izp; + const int8_t v11x5 = 5 < k ? w11[5] : izp; + const int8_t v11x6 = 6 < k ? w11[6] : izp; + const int8_t v11x7 = 7 < k ? w11[7] : izp; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + if (0 < k) { + out[88] = v11x0; + } + if (1 < k) { + out[89] = v11x1; + } + if (2 < k) { + out[90] = v11x2; + } + if (3 < k) { + out[91] = v11x3; + } + if (4 < k) { + out[92] = v11x4; + } + if (5 < k) { + out[93] = v11x5; + } + if (6 < k) { + out[94] = v11x6; + } + if (7 < k) { + out[95] = v11x7; + } + w11 += 8; + const int8_t v12x0 = 0 < k ? w12[0] : izp; + const int8_t v12x1 = 1 < k ? w12[1] : izp; + const int8_t v12x2 = 2 < k ? w12[2] : izp; + const int8_t v12x3 = 3 < k ? w12[3] : izp; + const int8_t v12x4 = 4 < k ? w12[4] : izp; + const int8_t v12x5 = 5 < k ? w12[5] : izp; + const int8_t v12x6 = 6 < k ? w12[6] : izp; + const int8_t v12x7 = 7 < k ? w12[7] : izp; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + if (0 < k) { + out[96] = v12x0; + } + if (1 < k) { + out[97] = v12x1; + } + if (2 < k) { + out[98] = v12x2; + } + if (3 < k) { + out[99] = v12x3; + } + if (4 < k) { + out[100] = v12x4; + } + if (5 < k) { + out[101] = v12x5; + } + if (6 < k) { + out[102] = v12x6; + } + if (7 < k) { + out[103] = v12x7; + } + w12 += 8; + const int8_t v13x0 = 0 < k ? w13[0] : izp; + const int8_t v13x1 = 1 < k ? w13[1] : izp; + const int8_t v13x2 = 2 < k ? w13[2] : izp; + const int8_t v13x3 = 3 < k ? w13[3] : izp; + const int8_t v13x4 = 4 < k ? w13[4] : izp; + const int8_t v13x5 = 5 < k ? w13[5] : izp; + const int8_t v13x6 = 6 < k ? w13[6] : izp; + const int8_t v13x7 = 7 < k ? w13[7] : izp; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + if (0 < k) { + out[104] = v13x0; + } + if (1 < k) { + out[105] = v13x1; + } + if (2 < k) { + out[106] = v13x2; + } + if (3 < k) { + out[107] = v13x3; + } + if (4 < k) { + out[108] = v13x4; + } + if (5 < k) { + out[109] = v13x5; + } + if (6 < k) { + out[110] = v13x6; + } + if (7 < k) { + out[111] = v13x7; + } + w13 += 8; + const int8_t v14x0 = 0 < k ? w14[0] : izp; + const int8_t v14x1 = 1 < k ? w14[1] : izp; + const int8_t v14x2 = 2 < k ? w14[2] : izp; + const int8_t v14x3 = 3 < k ? w14[3] : izp; + const int8_t v14x4 = 4 < k ? w14[4] : izp; + const int8_t v14x5 = 5 < k ? w14[5] : izp; + const int8_t v14x6 = 6 < k ? w14[6] : izp; + const int8_t v14x7 = 7 < k ? w14[7] : izp; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + if (0 < k) { + out[112] = v14x0; + } + if (1 < k) { + out[113] = v14x1; + } + if (2 < k) { + out[114] = v14x2; + } + if (3 < k) { + out[115] = v14x3; + } + if (4 < k) { + out[116] = v14x4; + } + if (5 < k) { + out[117] = v14x5; + } + if (6 < k) { + out[118] = v14x6; + } + if (7 < k) { + out[119] = v14x7; + } + w14 += 8; + out += 128; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + packed_b[8] -= ksum8 * izp; + packed_b[9] -= ksum9 * izp; + packed_b[10] -= ksum10 * izp; + packed_b[11] -= ksum11 * izp; + packed_b[12] -= ksum12 * izp; + packed_b[13] -= ksum13 * izp; + packed_b[14] -= ksum14 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c new file mode 100644 index 00000000000..45c63d2b268 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c @@ -0,0 +1,444 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnniint8.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if defined(__x86_64__) + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const int8_t izp = params ? ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +#endif // defined(__x86_64__) +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c new file mode 100644 index 00000000000..d954e4de632 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c @@ -0,0 +1,1175 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = params ? (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + ((int32_t*) out)[0] = b[0]; + ((int32_t*) out)[1] = b[1]; + ((int32_t*) out)[2] = b[2]; + ((int32_t*) out)[3] = b[3]; + ((int32_t*) out)[4] = b[4]; + ((int32_t*) out)[5] = b[5]; + ((int32_t*) out)[6] = b[6]; + ((int32_t*) out)[7] = b[7]; + b += 8; + } else { + ((int32_t*) out)[0] = 0; + ((int32_t*) out)[1] = 0; + ((int32_t*) out)[2] = 0; + ((int32_t*) out)[3] = 0; + ((int32_t*) out)[4] = 0; + ((int32_t*) out)[5] = 0; + ((int32_t*) out)[6] = 0; + ((int32_t*) out)[7] = 0; + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + out += 64; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(int32_t); + + // NR remainder has less than 8 rows so last row is not loaded + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + out += 64; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index d0d55e50e31..570273a66cd 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -21,6 +21,14 @@ XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c4__scalar, 16, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x32c4__scalar, 32, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar, 64, 4, 1, 4, 1) +XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar, 8, 8, 1, 8, 1) +XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 8, 1) + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if XNN_ENABLE_AVXVNNIINT8 && XNN_ARCH_X86_64 +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnniint8, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8, 8, 8, 1, 8, 1) +#endif + #ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_QS8_UKERNEL_WITH_PARAMS diff --git a/src/subgraph.c b/src/subgraph.c index d227add1624..4bc18c1b15b 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -13,11 +13,11 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" +#include "xnnpack/fp16.h" #include "xnnpack/hardware-config.h" #include "xnnpack/log.h" #include "xnnpack/math.h" diff --git a/src/subgraph/static-constant-pad.c b/src/subgraph/static-constant-pad.c index 5bad5e90fb4..ab75daff4e0 100644 --- a/src/subgraph/static-constant-pad.c +++ b/src/subgraph/static-constant-pad.c @@ -9,9 +9,9 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/common.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" diff --git a/src/x8-packw/kr-avxvnniint8.c.in b/src/x8-packw/kr-avxvnniint8.c.in new file mode 100644 index 00000000000..fc6fb0a2f77 --- /dev/null +++ b/src/x8-packw/kr-avxvnniint8.c.in @@ -0,0 +1,401 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR == 8 +$assert KR == 8 +$assert TYPE in ["int8_t"] + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +$BITS = {"int8_t": 8}[TYPE] +$BTYPE = {"int8_t": "uint32_t"}[TYPE] +$WTYPE = {"int8_t": "int8_t"}[TYPE] +void xnn_qs${BITS}_packw_gemm_goi_ukernel_x${NR}c${KR}__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const ${WTYPE}* weights, + $if BITS == 8: + const int32_t* bias, + $else: + const ${WTYPE}* bias, + const void* scale, + ${WTYPE}* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if defined(__x86_64__) + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + $if BITS == 8: + const int8_t izp = params ? ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += ${NR}; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + $if BTYPE == TYPE: + out += ${NR}; + $else: + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w${NR-1}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = *b++; + $else: + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } else { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = 0; + $else: + *((${BTYPE}*) out) = 0; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } + $if BTYPE == TYPE: + out += (${NR} - n); + $else: + out += (${NR} - n) * sizeof(${BTYPE}); + + $if NR > 2: + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + } + + $if BITS == 8: + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +#endif // defined(__x86_64__) +} diff --git a/src/xnnpack/fp16.h b/src/xnnpack/fp16.h new file mode 100644 index 00000000000..dc3b47aa8dc --- /dev/null +++ b/src/xnnpack/fp16.h @@ -0,0 +1,179 @@ +#ifndef THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ +#define THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ + +#include + +// This file is an excerpt from https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h, +// including only the minimal functionality we need in XNNPACK. This works around some issues +// that we haven't been able to fix upstream (https://github.com/Maratyszcza/FP16/pull/32). See also: +// - https://github.com/microsoft/onnxruntime/pull/22294/files +// - https://github.com/google/XNNPACK/issues/6989 +// We also don't need a lot of the functionality in the upstream library. + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32 = { w }; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32 = { f }; + return fp32.as_bits; +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + const uint32_t w = fp32_to_bits(f); + const float abs_f = fp32_from_bits(w & UINT32_C(0x7FFFFFFF)); + float base = (abs_f * scale_to_inf) * scale_to_zero; + + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + + +#endif // THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ diff --git a/src/xnnpack/math.h b/src/xnnpack/math.h index dab6e037853..893e038a312 100644 --- a/src/xnnpack/math.h +++ b/src/xnnpack/math.h @@ -18,8 +18,8 @@ #include // For _rotl. #endif -#include #include "xnnpack/common.h" +#include "xnnpack/fp16.h" // stdlib.h from Windows 10 SDK defines min & max macros. // Undefine them before defining the corresponding functions. diff --git a/src/xnnpack/quantization.h b/src/xnnpack/quantization.h index 95f832bada5..0884ba41a2b 100644 --- a/src/xnnpack/quantization.h +++ b/src/xnnpack/quantization.h @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack/math.h" #include "xnnpack/microparams.h" diff --git a/src/xnnpack/simd/f16-scalar.h b/src/xnnpack/simd/f16-scalar.h index c24076b282b..e95425c5b9b 100644 --- a/src/xnnpack/simd/f16-scalar.h +++ b/src/xnnpack/simd/f16-scalar.h @@ -11,8 +11,8 @@ #include #include -#include #include "xnnpack/common.h" +#include "xnnpack/fp16.h" // SIMD vector type for f16 using SCALAR. typedef uint16_t xnn_simd_f16_t; diff --git a/test/BUILD.bazel b/test/BUILD.bazel index b36d2b0c0f8..603357df6d6 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -26,11 +26,11 @@ load( MICROKERNEL_TEST_DEPS = [ ":next_prime", ":replicable_random_device", - "@FP16", "//:aligned_allocator", "//:all_microkernels", "//:allocator", "//:common", + "//:fp16", "//:isa_checks", "//:math", "//:memory", @@ -46,12 +46,12 @@ MICROKERNEL_TEST_DEPS = [ OPERATOR_TEST_DEPS = [ ":replicable_random_device", - "@FP16", "@pthreadpool", "//:aligned_allocator", "//:allocator", "//:cache", "//:common", + "//:fp16", "//:internal", "//:math", "//:microkernel_configs", @@ -189,8 +189,8 @@ sh_test( copts = xnnpack_simd_copts_for_arch(arch), deps = [ ":replicable_random_device", - "@FP16", "//:common", + "//:fp16", "//:isa_checks", "//:microkernels_h", ], @@ -1769,7 +1769,6 @@ xnnpack_cxx_library( deps = [ ":replicable_random_device", ":subgraph_unary_tester", - "@FP16", "//:XNNPACK", "//:math", "//:node_type", @@ -1885,7 +1884,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1915,7 +1913,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1932,7 +1929,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -1953,7 +1949,6 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1974,7 +1969,6 @@ xnnpack_unit_test( shard_count = 5, deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operator_utils", @@ -1991,7 +1985,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2007,7 +2000,6 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2026,7 +2018,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2046,7 +2037,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2068,7 +2058,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2086,7 +2075,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2104,7 +2092,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2121,7 +2108,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2138,7 +2124,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operator_utils", @@ -2155,7 +2140,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2185,7 +2169,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2201,7 +2184,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2230,7 +2212,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2348,7 +2329,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:allocation_type", "//:allocator", diff --git a/test/abs.cc b/test/abs.cc index ad1e8401293..7aaabe56139 100644 --- a/test/abs.cc +++ b/test/abs.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/average-pooling-2d.cc b/test/average-pooling-2d.cc index 4602f2aded5..8e9d2b58c06 100644 --- a/test/average-pooling-2d.cc +++ b/test/average-pooling-2d.cc @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/avgpool-microkernel-tester.h b/test/avgpool-microkernel-tester.h index 35db8fe1bb4..68e4bd4fb39 100644 --- a/test/avgpool-microkernel-tester.h +++ b/test/avgpool-microkernel-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/bankers-rounding.cc b/test/bankers-rounding.cc index 0021bdad455..fe4ab88f153 100644 --- a/test/bankers-rounding.cc +++ b/test/bankers-rounding.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/batch-matrix-multiply.cc b/test/batch-matrix-multiply.cc index 55fbe9eeb3e..53bb93df183 100644 --- a/test/batch-matrix-multiply.cc +++ b/test/batch-matrix-multiply.cc @@ -20,7 +20,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/ceiling.cc b/test/ceiling.cc index 01168b81e60..16c1f9c4ce6 100644 --- a/test/ceiling.cc +++ b/test/ceiling.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/clamp.cc b/test/clamp.cc index f9c87f57769..e5db94311f0 100644 --- a/test/clamp.cc +++ b/test/clamp.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate2.cc b/test/concatenate2.cc index 06c6c351860..245f114fa3a 100644 --- a/test/concatenate2.cc +++ b/test/concatenate2.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate3.cc b/test/concatenate3.cc index 9a7888effff..cb4f56a4b33 100644 --- a/test/concatenate3.cc +++ b/test/concatenate3.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate4.cc b/test/concatenate4.cc index 0d6ac91a838..992ccf9668b 100644 --- a/test/concatenate4.cc +++ b/test/concatenate4.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate5.cc b/test/concatenate5.cc index 8dfea193ce7..2934812315c 100644 --- a/test/concatenate5.cc +++ b/test/concatenate5.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/conv-hwc2chw-microkernel-tester.h b/test/conv-hwc2chw-microkernel-tester.h index d354992bab4..eaca4b86116 100644 --- a/test/conv-hwc2chw-microkernel-tester.h +++ b/test/conv-hwc2chw-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/convert-operator-tester.h b/test/convert-operator-tester.h index 78724161ed6..9a9641a32d9 100644 --- a/test/convert-operator-tester.h +++ b/test/convert-operator-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" diff --git a/test/convert.cc b/test/convert.cc index 8b7933b1b99..5c418403e18 100644 --- a/test/convert.cc +++ b/test/convert.cc @@ -12,7 +12,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/convolution-2d.cc b/test/convolution-2d.cc index a029865d309..8e85f6cb916 100644 --- a/test/convolution-2d.cc +++ b/test/convolution-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h index ad155b26759..0a59b20b5b8 100644 --- a/test/convolution-operator-tester.h +++ b/test/convolution-operator-tester.h @@ -21,7 +21,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/cache.h" diff --git a/test/copy.cc b/test/copy.cc index 8683530503b..a19cb3d34ba 100644 --- a/test/copy.cc +++ b/test/copy.cc @@ -11,7 +11,6 @@ #include // For std::unique_ptr. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/deconvolution-2d.cc b/test/deconvolution-2d.cc index 1d7d53bc314..6e4f6cda19c 100644 --- a/test/deconvolution-2d.cc +++ b/test/deconvolution-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" diff --git a/test/deconvolution-operator-tester.h b/test/deconvolution-operator-tester.h index 21d108ed9a4..c74e75b6650 100644 --- a/test/deconvolution-operator-tester.h +++ b/test/deconvolution-operator-tester.h @@ -23,7 +23,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" diff --git a/test/depth-to-space-2d.cc b/test/depth-to-space-2d.cc index 924f2e18999..f20ee1de3a0 100644 --- a/test/depth-to-space-2d.cc +++ b/test/depth-to-space-2d.cc @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/depthwise-convolution-2d.cc b/test/depthwise-convolution-2d.cc index 2f008078792..923c0e67762 100644 --- a/test/depthwise-convolution-2d.cc +++ b/test/depthwise-convolution-2d.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/dwconv-microkernel-tester.cc b/test/dwconv-microkernel-tester.cc index d29e22e7572..942e5d284d2 100644 --- a/test/dwconv-microkernel-tester.cc +++ b/test/dwconv-microkernel-tester.cc @@ -21,7 +21,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/dwconv2d-microkernel-tester.h b/test/dwconv2d-microkernel-tester.h index 2af4605c7bc..15a8438306c 100644 --- a/test/dwconv2d-microkernel-tester.h +++ b/test/dwconv2d-microkernel-tester.h @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/elu.cc b/test/elu.cc index 1ad5b23fbd6..ccb49d61d68 100644 --- a/test/elu.cc +++ b/test/elu.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split2.cc b/test/even-split2.cc index 221b5f00f09..ab00d5f777d 100644 --- a/test/even-split2.cc +++ b/test/even-split2.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split3.cc b/test/even-split3.cc index 99979aa0965..cd697a325ab 100644 --- a/test/even-split3.cc +++ b/test/even-split3.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split4.cc b/test/even-split4.cc index 13f872aaa0b..22d8ab52f82 100644 --- a/test/even-split4.cc +++ b/test/even-split4.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/f16-simd-scalar.cc b/test/f16-simd-scalar.cc index d5d100ede7a..3fa9cda6100 100644 --- a/test/f16-simd-scalar.cc +++ b/test/f16-simd-scalar.cc @@ -17,9 +17,9 @@ #include #include -#include #include "xnnpack/isa-checks.h" #include "xnnpack/simd/f16-scalar.h" +#include "xnnpack/fp16.h" #include "replicable_random_device.h" namespace xnnpack { diff --git a/test/f16-simd.cc.in b/test/f16-simd.cc.in index c6de215ce7f..84ebd2d78fb 100644 --- a/test/f16-simd.cc.in +++ b/test/f16-simd.cc.in @@ -19,9 +19,9 @@ $if ARCH_MACRO: #include #include -#include #include "xnnpack/isa-checks.h" #include "xnnpack/simd/f16-${ARCH}.h" +#include "xnnpack/fp16.h" #include "replicable_random_device.h" namespace xnnpack { diff --git a/test/floor.cc b/test/floor.cc index 0f62c175ffd..790b668d726 100644 --- a/test/floor.cc +++ b/test/floor.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h index a84bc876da7..c7e5695ca42 100644 --- a/test/fully-connected-operator-tester.h +++ b/test/fully-connected-operator-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" diff --git a/test/fully-connected.cc b/test/fully-connected.cc index 54624914ce0..a8a11097c10 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/gavgpool-cw-microkernel-tester.h b/test/gavgpool-cw-microkernel-tester.h index 077c7c0775b..fae2cb8e9d2 100644 --- a/test/gavgpool-cw-microkernel-tester.h +++ b/test/gavgpool-cw-microkernel-tester.h @@ -15,8 +15,8 @@ #include #include -#include #include "xnnpack.h" +#include "xnnpack/fp16.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" #include "replicable_random_device.h" diff --git a/test/gavgpool-microkernel-tester.h b/test/gavgpool-microkernel-tester.h index 78a6db9604c..723bf879e94 100644 --- a/test/gavgpool-microkernel-tester.h +++ b/test/gavgpool-microkernel-tester.h @@ -19,7 +19,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 2e5bf67566e..8d3d90b157d 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -14,8 +14,6 @@ #include #include -#include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/aligned-allocator.h" diff --git a/test/global-average-pooling-1d.cc b/test/global-average-pooling-1d.cc index 7cd48566976..c7a472835e4 100644 --- a/test/global-average-pooling-1d.cc +++ b/test/global-average-pooling-1d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-average-pooling-2d.cc b/test/global-average-pooling-2d.cc index 9a51ac9aa16..7f83beb8fb3 100644 --- a/test/global-average-pooling-2d.cc +++ b/test/global-average-pooling-2d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-sum-pooling-1d.cc b/test/global-sum-pooling-1d.cc index fcf429581e3..45785d8fa37 100644 --- a/test/global-sum-pooling-1d.cc +++ b/test/global-sum-pooling-1d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-sum-pooling-2d.cc b/test/global-sum-pooling-2d.cc index cbe47848380..8d618494db5 100644 --- a/test/global-sum-pooling-2d.cc +++ b/test/global-sum-pooling-2d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/hardswish.cc b/test/hardswish.cc index 515e9e62faa..728947595f4 100644 --- a/test/hardswish.cc +++ b/test/hardswish.cc @@ -13,7 +13,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/ibilinear-microkernel-tester.h b/test/ibilinear-microkernel-tester.h index f9fdecdbf8f..c13e2a7f409 100644 --- a/test/ibilinear-microkernel-tester.h +++ b/test/ibilinear-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/math.h" diff --git a/test/leaky-relu.cc b/test/leaky-relu.cc index dfba1ccce39..0b9d2180442 100644 --- a/test/leaky-relu.cc +++ b/test/leaky-relu.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/max-pooling-2d.cc b/test/max-pooling-2d.cc index 241ec93ab46..2c89bb58efc 100644 --- a/test/max-pooling-2d.cc +++ b/test/max-pooling-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" diff --git a/test/maxpool-microkernel-tester.h b/test/maxpool-microkernel-tester.h index faedb72b52b..7587b033cd5 100644 --- a/test/maxpool-microkernel-tester.h +++ b/test/maxpool-microkernel-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/negate.cc b/test/negate.cc index 9ae0008b3f9..a36ef815844 100644 --- a/test/negate.cc +++ b/test/negate.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/packing.cc b/test/packing.cc index eeebc0946d6..4c6bd1910bd 100644 --- a/test/packing.cc +++ b/test/packing.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack/math.h" #include "xnnpack/microkernel-utils.h" #include "xnnpack/microparams-init.h" diff --git a/test/prelu-microkernel-tester.h b/test/prelu-microkernel-tester.h index 10ec6e4e1f5..b50bab36775 100644 --- a/test/prelu-microkernel-tester.h +++ b/test/prelu-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/prelu.cc b/test/prelu.cc index 5f4a65daf61..698eee99a34 100644 --- a/test/prelu.cc +++ b/test/prelu.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/raddstoreexpminusmax-microkernel-tester.h b/test/raddstoreexpminusmax-microkernel-tester.h index 793e161a8ed..4ea512f71a3 100644 --- a/test/raddstoreexpminusmax-microkernel-tester.h +++ b/test/raddstoreexpminusmax-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/rdsum-microkernel-tester.h b/test/rdsum-microkernel-tester.h index dc08fbf8700..0dc57804777 100644 --- a/test/rdsum-microkernel-tester.h +++ b/test/rdsum-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/reciprocal-square-root.cc b/test/reciprocal-square-root.cc index 8f7c6e1baae..b5734767615 100644 --- a/test/reciprocal-square-root.cc +++ b/test/reciprocal-square-root.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/rsum-microkernel-tester.h b/test/rsum-microkernel-tester.h index dee79bb9f0c..9f2c94b9419 100644 --- a/test/rsum-microkernel-tester.h +++ b/test/rsum-microkernel-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/scaled-dot-product-attention.cc b/test/scaled-dot-product-attention.cc index 0d6fd61e381..c8782608783 100644 --- a/test/scaled-dot-product-attention.cc +++ b/test/scaled-dot-product-attention.cc @@ -17,7 +17,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/sigmoid.cc b/test/sigmoid.cc index 82ffa4db623..df1935063ef 100644 --- a/test/sigmoid.cc +++ b/test/sigmoid.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/softmax.cc b/test/softmax.cc index b17134654a2..458638b77a9 100644 --- a/test/softmax.cc +++ b/test/softmax.cc @@ -12,7 +12,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/space-to-depth-2d.cc b/test/space-to-depth-2d.cc index 9fc996e3777..66869423ad6 100644 --- a/test/space-to-depth-2d.cc +++ b/test/space-to-depth-2d.cc @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/spmm-microkernel-tester.h b/test/spmm-microkernel-tester.h index 4ffa7cf6f89..bd8a115cd2f 100644 --- a/test/spmm-microkernel-tester.h +++ b/test/spmm-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/square-root.cc b/test/square-root.cc index 1ed486e0db3..23ef94f803a 100644 --- a/test/square-root.cc +++ b/test/square-root.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/square.cc b/test/square.cc index f183bae5603..4fbc26b979d 100644 --- a/test/square.cc +++ b/test/square.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-constant-pad.cc b/test/static-constant-pad.cc index c9ca1080777..77f0d813d7f 100644 --- a/test/static-constant-pad.cc +++ b/test/static-constant-pad.cc @@ -13,7 +13,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" diff --git a/test/static-expand-dims.cc b/test/static-expand-dims.cc index d608b5dddcd..5fb68440678 100644 --- a/test/static-expand-dims.cc +++ b/test/static-expand-dims.cc @@ -15,7 +15,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-mean.cc b/test/static-mean.cc index 92b7a9d309d..a55f63c4acf 100644 --- a/test/static-mean.cc +++ b/test/static-mean.cc @@ -16,7 +16,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/static-reshape.cc b/test/static-reshape.cc index 0745cbd2273..6107acbbc04 100644 --- a/test/static-reshape.cc +++ b/test/static-reshape.cc @@ -15,7 +15,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-resize-bilinear-2d.cc b/test/static-resize-bilinear-2d.cc index 250c048995b..6ea0dabddd2 100644 --- a/test/static-resize-bilinear-2d.cc +++ b/test/static-resize-bilinear-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-slice.cc b/test/static-slice.cc index 81c25ea0d0a..e56e239a2a4 100644 --- a/test/static-slice.cc +++ b/test/static-slice.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-transpose.cc b/test/static-transpose.cc index 52dc729c22e..d5a92cd7cbe 100644 --- a/test/static-transpose.cc +++ b/test/static-transpose.cc @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/subgraph-fp16.cc b/test/subgraph-fp16.cc index 7653908077b..ae25e6e6d86 100644 --- a/test/subgraph-fp16.cc +++ b/test/subgraph-fp16.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/node-type.h" diff --git a/test/tanh-operator-tester.h b/test/tanh-operator-tester.h index f43e5d7e997..755a6def164 100644 --- a/test/tanh-operator-tester.h +++ b/test/tanh-operator-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "replicable_random_device.h" diff --git a/test/tanh.cc b/test/tanh.cc index 2d01af0c534..bd0cf712cb3 100644 --- a/test/tanh.cc +++ b/test/tanh.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/unary-operator-tester.cc b/test/unary-operator-tester.cc index 74912831187..f6d6545eb45 100644 --- a/test/unary-operator-tester.cc +++ b/test/unary-operator-tester.cc @@ -19,7 +19,6 @@ #include #include -#include #include "xnnpack.h" #include "replicable_random_device.h" diff --git a/test/vbinary-microkernel-tester.cc b/test/vbinary-microkernel-tester.cc index 00515ad5713..f4a48916660 100644 --- a/test/vbinary-microkernel-tester.cc +++ b/test/vbinary-microkernel-tester.cc @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams-init.h" diff --git a/test/vcmul-microkernel-tester.h b/test/vcmul-microkernel-tester.h index 4c7461580b9..4d875874c38 100644 --- a/test/vcmul-microkernel-tester.h +++ b/test/vcmul-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/isa-checks.h" #include "xnnpack/microfnptr.h" diff --git a/test/vcvt-microkernel-tester.cc b/test/vcvt-microkernel-tester.cc index e17d555edc5..44129ffac4c 100644 --- a/test/vcvt-microkernel-tester.cc +++ b/test/vcvt-microkernel-tester.cc @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" diff --git a/test/vmulcaddc-microkernel-tester.h b/test/vmulcaddc-microkernel-tester.h index c5f15d12aaf..13afb0cb58e 100644 --- a/test/vmulcaddc-microkernel-tester.h +++ b/test/vmulcaddc-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/vunary-microkernel-tester.cc b/test/vunary-microkernel-tester.cc index ffa6e29a5fe..07a7ad6e91b 100644 --- a/test/vunary-microkernel-tester.cc +++ b/test/vunary-microkernel-tester.cc @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/common.h" #include "xnnpack/microfnptr.h" diff --git a/test/vunary-microkernel-tester.h b/test/vunary-microkernel-tester.h index 8293d2769c9..a65b6646570 100644 --- a/test/vunary-microkernel-tester.h +++ b/test/vunary-microkernel-tester.h @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "replicable_random_device.h"