From d79f803ea5c15fae18f5929f767bee24b4d9ada6 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Thu, 3 Oct 2024 21:20:19 -0700 Subject: [PATCH] Remove dependency on fp16 This dependency is causing issues on some platforms. We have a workaround: https://github.com/Maratyszcza/fp16 that some users have implemented (e.g. https://github.com/microsoft/onnxruntime/pull/22294/files), but it doesn't make sense to hack this workaround into everything that depends on XNNPACK. This PR pulls the small part of the fp16 library we actually need in XNNPACK, and removes the dependency. FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/google/XNNPACK/pull/7128 from imaginationtech:img_patch30_f32_f16_vcvt 2f40edd2509c5dce91c93df3b12d0bd6da59f051 PiperOrigin-RevId: 682156616 --- BUILD.bazel | 17 +- CMakeLists.txt | 84 +- WORKSPACE | 13 - bench/BUILD.bazel | 2 - bench/gemm-benchmark.cc | 1 - build_params.bzl | 2 - cmake/DownloadFP16.cmake | 28 - cmake/gen/avxvnniint8_microkernels.cmake | 3 +- cmake/gen/rvvfp16arith_microkernels.cmake | 6 +- cmake/gen/scalar_microkernels.cmake | 2 + gen/avxvnniint8_microkernels.bzl | 1 + gen/rvvfp16arith_microkernels.bzl | 4 + gen/scalar_microkernels.bzl | 2 + scripts/generate-f32-f16-vcvt.sh | 6 + scripts/generate-x8-packw.sh | 7 + scripts/genxnn | 73 + src/configs/unary-elementwise-config.c | 2 +- .../gen/f16-qs8-vcvt-scalar-fmagic-u1.c | 1 - .../gen/f16-qs8-vcvt-scalar-fmagic-u2.c | 1 - .../gen/f16-qs8-vcvt-scalar-fmagic-u3.c | 1 - .../gen/f16-qs8-vcvt-scalar-fmagic-u4.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u1.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u2.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u3.c | 1 - .../gen/f16-qs8-vcvt-scalar-imagic-u4.c | 1 - src/f32-f16-vcvt/f32-f16-vcvt.h | 7 + .../gen/f32-f16-vcvt-rvvfp16arith-u1v.c | 40 + .../gen/f32-f16-vcvt-rvvfp16arith-u2v.c | 40 + .../gen/f32-f16-vcvt-rvvfp16arith-u4v.c | 40 + .../gen/f32-f16-vcvt-rvvfp16arith-u8v.c | 40 + src/f32-f16-vcvt/rvvfp16arith.c.in | 38 + src/f32-qs8-vcvt/scalar-fmagic.c.in | 2 - src/f32-qs8-vcvt/scalar-imagic.c.in | 2 - src/f32-rminmax/scalar.c.in | 4 +- src/operators/average-pooling-nhwc.c | 1 - src/operators/convolution-nchw.c | 1 - src/operators/convolution-nhwc.c | 1 - src/operators/deconvolution-nhwc.c | 1 - src/operators/dynamic-fully-connected-nc.c | 1 - src/operators/global-average-pooling-ncw.c | 2 +- src/operators/global-average-pooling-nwc.c | 2 +- src/operators/max-pooling-nhwc.c | 1 - .../scaled-dot-product-attention-nhtc.c | 1 - src/operators/softmax-nc.c | 1 - src/operators/unary-elementwise-nc.c | 1 - src/packing.cc | 1 - .../gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c | 5 +- .../gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c | 5 +- src/qs8-gemm/scalar.c.in | 6 +- .../gen/qs8-packw-x16c8-gemm-goi-scalar.c | 2319 +++++++++++++++++ .../gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c | 444 ++++ .../gen/qs8-packw-x8c8-gemm-goi-scalar.c | 1175 +++++++++ src/qs8-packw/qs8-packw.h | 8 + src/subgraph.c | 2 +- src/subgraph/static-constant-pad.c | 2 +- src/x8-packw/kr-avxvnniint8.c.in | 401 +++ src/xnnpack/fp16.h | 179 ++ src/xnnpack/math.h | 2 +- src/xnnpack/quantization.h | 1 - src/xnnpack/simd/f16-scalar.h | 2 +- test/BUILD.bazel | 26 +- test/abs.cc | 1 - test/average-pooling-2d.cc | 1 - test/avgpool-microkernel-tester.h | 1 - test/bankers-rounding.cc | 1 - test/batch-matrix-multiply.cc | 1 - test/ceiling.cc | 1 - test/clamp.cc | 1 - test/concatenate2.cc | 1 - test/concatenate3.cc | 1 - test/concatenate4.cc | 1 - test/concatenate5.cc | 1 - test/conv-hwc2chw-microkernel-tester.h | 1 - test/convert-operator-tester.h | 1 - test/convert.cc | 1 - test/convolution-2d.cc | 1 - test/convolution-operator-tester.h | 1 - test/copy.cc | 1 - test/deconvolution-2d.cc | 1 - test/deconvolution-operator-tester.h | 1 - test/depth-to-space-2d.cc | 1 - test/depthwise-convolution-2d.cc | 1 - test/dwconv-microkernel-tester.cc | 1 - test/dwconv2d-microkernel-tester.h | 1 - test/elu.cc | 1 - test/even-split2.cc | 1 - test/even-split3.cc | 1 - test/even-split4.cc | 1 - test/f16-simd-scalar.cc | 2 +- test/f16-simd.cc.in | 2 +- test/floor.cc | 1 - test/fully-connected-operator-tester.h | 1 - test/fully-connected.cc | 1 - test/gavgpool-cw-microkernel-tester.h | 2 +- test/gavgpool-microkernel-tester.h | 1 - test/gemm-microkernel-tester.cc | 2 - test/global-average-pooling-1d.cc | 1 - test/global-average-pooling-2d.cc | 1 - test/global-sum-pooling-1d.cc | 1 - test/global-sum-pooling-2d.cc | 1 - test/hardswish.cc | 1 - test/ibilinear-microkernel-tester.h | 1 - test/leaky-relu.cc | 1 - test/max-pooling-2d.cc | 1 - test/maxpool-microkernel-tester.h | 1 - test/negate.cc | 1 - test/packing.cc | 1 - test/prelu-microkernel-tester.h | 1 - test/prelu.cc | 1 - .../raddstoreexpminusmax-microkernel-tester.h | 1 - test/rdsum-microkernel-tester.h | 1 - test/reciprocal-square-root.cc | 1 - test/rsum-microkernel-tester.h | 1 - test/scaled-dot-product-attention.cc | 1 - test/sigmoid.cc | 1 - test/softmax.cc | 1 - test/space-to-depth-2d.cc | 1 - test/spmm-microkernel-tester.h | 1 - test/square-root.cc | 1 - test/square.cc | 1 - test/static-constant-pad.cc | 1 - test/static-expand-dims.cc | 1 - test/static-mean.cc | 1 - test/static-reshape.cc | 1 - test/static-resize-bilinear-2d.cc | 1 - test/static-slice.cc | 1 - test/static-transpose.cc | 1 - test/subgraph-fp16.cc | 1 - test/tanh-operator-tester.h | 1 - test/tanh.cc | 1 - test/unary-operator-tester.cc | 1 - test/vbinary-microkernel-tester.cc | 1 - test/vcmul-microkernel-tester.h | 1 - test/vcvt-microkernel-tester.cc | 1 - test/vmulcaddc-microkernel-tester.h | 1 - test/vunary-microkernel-tester.cc | 1 - test/vunary-microkernel-tester.h | 1 - 142 files changed, 4884 insertions(+), 286 deletions(-) delete mode 100644 cmake/DownloadFP16.cmake create mode 100755 scripts/genxnn create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c create mode 100644 src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c create mode 100644 src/f32-f16-vcvt/rvvfp16arith.c.in create mode 100644 src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c create mode 100644 src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c create mode 100644 src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c create mode 100644 src/x8-packw/kr-avxvnniint8.c.in create mode 100644 src/xnnpack/fp16.h diff --git a/BUILD.bazel b/BUILD.bazel index baa9cec726b2..e54d106914ba 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -327,7 +327,6 @@ xnnpack_cc_library( ":microparams", ":mutex", ":xnnpack_h", - "@FP16", ], ) @@ -428,7 +427,6 @@ xnnpack_cc_library( ":packing", ":prod_microkernels", ":xnnpack_h", - "@FP16", ] + select({ ":cpuinfo_enabled": ["@cpuinfo"], "//conditions:default": [], @@ -449,13 +447,19 @@ xnnpack_cc_library( ], ) +xnnpack_cc_library( + name = "fp16", + hdrs = ["src/xnnpack/fp16.h"], + compatible_with = [], +) + xnnpack_cc_library( name = "math", hdrs = ["src/xnnpack/math.h"], deps = [ ":common", ":config_hdrs", - "@FP16", + ":fp16", ], ) @@ -598,7 +602,6 @@ xnnpack_cc_library( ":common", ":math", ":microparams", - "@FP16", ], ) @@ -777,7 +780,6 @@ xnnpack_cc_library( ":microparams", ":operator_h", ":xnnpack_h", - "@FP16", "@FXdiv", ], ) @@ -796,7 +798,6 @@ xnnpack_cxx_library( ":params", ":unaligned", ":xnnpack_h", - "@FP16", ] + xnnpack_if_kleidiai_enabled([ "@KleidiAI//kai/ukernels/matmul", "@KleidiAI//kai/ukernels/matmul:rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", @@ -931,6 +932,7 @@ xnnpack_cc_library( ":allocator", ":cache", ":common", + ":fp16", ":indirection", ":logging", ":math", @@ -946,7 +948,6 @@ xnnpack_cc_library( ":params", ":quantization", ":xnnpack_h", - "@FP16", "@pthreadpool", ] + select({ "//conditions:default": [], @@ -969,6 +970,7 @@ xnnpack_cc_library( ":cache", ":common", ":config_hdrs", + ":fp16", ":hardware_config", ":internal", ":logging", @@ -983,7 +985,6 @@ xnnpack_cc_library( ":params", ":requantization", ":xnnpack_h", - "@FP16", "@pthreadpool", ] + xnnpack_slinky_deps(), ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 948c69d8243e..1fdaf28f02a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -304,16 +304,6 @@ IF(NOT XNNPACK_USE_SYSTEM_LIBS) SET(CPUINFO_SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo-source" CACHE STRING "cpuinfo source directory") ENDIF() - IF(NOT DEFINED FP16_SOURCE_DIR) - MESSAGE(STATUS "Downloading FP16 to ${CMAKE_BINARY_DIR}/FP16-source (define FP16_SOURCE_DIR to avoid it)") - CONFIGURE_FILE(cmake/DownloadFP16.cmake "${CMAKE_BINARY_DIR}/FP16-download/CMakeLists.txt") - EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") - EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") - SET(FP16_SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" CACHE STRING "FP16 source directory") - ENDIF() - IF(NOT DEFINED FXDIV_SOURCE_DIR) MESSAGE(STATUS "Downloading FXdiv to ${CMAKE_BINARY_DIR}/FXdiv-source (define FXDIV_SOURCE_DIR to avoid it)") CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CMAKE_BINARY_DIR}/FXdiv-download/CMakeLists.txt") @@ -1116,42 +1106,7 @@ IF(XNNPACK_BUILD_LIBRARY) TARGET_LINK_LIBRARIES(XNNPACK PRIVATE fxdiv) ENDIF() -# ---[ Configure FP16 -IF(NOT TARGET fp16) - IF(NOT XNNPACK_USE_SYSTEM_LIBS) - SET(FP16_BUILD_TESTS OFF CACHE BOOL "") - SET(FP16_BUILD_BENCHMARKS OFF CACHE BOOL "") - ADD_SUBDIRECTORY( - "${FP16_SOURCE_DIR}" - "${CMAKE_BINARY_DIR}/FP16") - ELSE() - FIND_FILE(FP16_HDR fp16.h PATH_SUFFIXES include PATHS "${FP16_SOURCE_DIR}") - IF(NOT FP16_HDR) - MESSAGE(FATAL_ERROR "Cannot find fp16") - ENDIF() - ADD_LIBRARY(fp16 STATIC "${FP16_HDR}") - TARGET_INCLUDE_DIRECTORIES(fp16 INTERFACE "${FP16_SOURCE_DIR}/include") - SET_PROPERTY(TARGET fp16 PROPERTY LINKER_LANGUAGE C) - ENDIF() -ENDIF() -IF(XNNPACK_BUILD_ALL_MICROKERNELS) - TARGET_LINK_LIBRARIES(microkernels-all PRIVATE fp16) -ENDIF() -TARGET_LINK_LIBRARIES(microkernels-prod PRIVATE fp16) -TARGET_LINK_LIBRARIES(microparams-init PRIVATE fp16) -TARGET_LINK_LIBRARIES(packing PRIVATE fp16) -TARGET_LINK_LIBRARIES(indirection PRIVATE fp16) -TARGET_LINK_LIBRARIES(memory PRIVATE fp16) -TARGET_LINK_LIBRARIES(normalization PRIVATE fp16) -TARGET_LINK_LIBRARIES(microkernel-utils PRIVATE fp16) -TARGET_LINK_LIBRARIES(cache PRIVATE fp16) -TARGET_LINK_LIBRARIES(operator-utils PRIVATE fp16) IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(subgraph PRIVATE fp16) - TARGET_LINK_LIBRARIES(operators PRIVATE fp16) - TARGET_LINK_LIBRARIES(operator-run PRIVATE fp16) - - TARGET_LINK_LIBRARIES(XNNPACK PRIVATE fp16) INSTALL(TARGETS XNNPACK LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -1212,7 +1167,7 @@ IF(XNNPACK_BUILD_TESTS) ADD_LIBRARY(gemm-microkernel-tester STATIC test/gemm-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(gemm-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE xnnpack-base fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE xnnpack-base pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE packing) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(gemm-microkernel-tester PRIVATE kleidiai) @@ -1221,34 +1176,34 @@ IF(XNNPACK_BUILD_TESTS) ADD_LIBRARY(unary-operator-tester STATIC test/unary-operator-tester.cc) TARGET_INCLUDE_DIRECTORIES(unary-operator-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(unary-operator-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(unary-operator-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(dwconv-microkernel-tester STATIC test/dwconv-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(dwconv-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(dwconv-microkernel-tester PUBLIC next-prime) ADD_LIBRARY(vbinary-microkernel-tester STATIC test/vbinary-microkernel-tester.cc) SET_TARGET_PROPERTIES(vbinary-microkernel-tester PROPERTIES CXX_EXTENSIONS YES) TARGET_INCLUDE_DIRECTORIES(vbinary-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vbinary-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vbinary-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(vcvt-microkernel-tester STATIC test/vcvt-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(vcvt-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vcvt-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vcvt-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) ADD_LIBRARY(vunary-microkernel-tester STATIC test/vunary-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(vunary-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(vunary-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(vunary-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) TARGET_LINK_LIBRARIES(vunary-microkernel-tester PUBLIC next-prime) ADD_LIBRARY(convolution-test-helpers OBJECT test/convolution-test-helpers.cc) TARGET_INCLUDE_DIRECTORIES(convolution-test-helpers PRIVATE include src) - TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE xnnpack-base fp16) + TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE xnnpack-base) ADD_LIBRARY(packq-microkernel-tester STATIC test/packq-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(packq-microkernel-tester PRIVATE . include src test) - TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK pthreadpool GTest::gtest) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE kleidiai) ENDIF() @@ -1268,7 +1223,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gtest GTest::gtest_main hardware-config @@ -1295,7 +1249,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1333,7 +1286,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gtest GTest::gtest_main unary-operator-tester @@ -1344,7 +1296,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(binary-elementwise-nd-test test/binary-elementwise-nd.cc) TARGET_INCLUDE_DIRECTORIES(binary-elementwise-nd-test PRIVATE src test) TARGET_LINK_LIBRARIES(binary-elementwise-nd-test PRIVATE - fp16 GTest::gtest GTest::gtest_main XNNPACK) @@ -1362,7 +1313,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1425,7 +1375,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1444,7 +1393,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE convolution-test-helpers - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1539,7 +1487,6 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1580,7 +1527,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE dwconv-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1633,7 +1579,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE gemm-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1654,7 +1599,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE packq-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1727,7 +1671,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vbinary-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1760,7 +1703,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vcvt-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1814,7 +1756,6 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE . include src test) TARGET_LINK_LIBRARIES(${TEST}-test PRIVATE vunary-microkernel-tester - fp16 GTest::gmock GTest::gtest GTest::gtest_main @@ -1849,7 +1790,7 @@ IF(XNNPACK_BUILD_TESTS) ADD_EXECUTABLE(operator-utils-test test/operator-utils.cc) TARGET_INCLUDE_DIRECTORIES(operator-utils-test PRIVATE include src) - TARGET_LINK_LIBRARIES(operator-utils-test PRIVATE XNNPACK GTest::gtest GTest::gtest_main pthreadpool fp16) + TARGET_LINK_LIBRARIES(operator-utils-test PRIVATE XNNPACK GTest::gtest GTest::gtest_main pthreadpool) ENDIF() @@ -1903,14 +1844,14 @@ IF(XNNPACK_BUILD_BENCHMARKS) # Helper libraries ADD_LIBRARY(packq-benchmark STATIC bench/packq-benchmark.cc) TARGET_INCLUDE_DIRECTORIES(packq-benchmark PRIVATE . include src bench) - TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils fp16) + TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE kleidiai) ENDIF() ADD_LIBRARY(gemm-benchmark STATIC bench/gemm-benchmark.cc) TARGET_INCLUDE_DIRECTORIES(gemm-benchmark PRIVATE . include src bench) - TARGET_LINK_LIBRARIES(gemm-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils fp16) + TARGET_LINK_LIBRARIES(gemm-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) IF(XNNPACK_ENABLE_KLEIDIAI) TARGET_LINK_LIBRARIES(gemm-benchmark PUBLIC kleidiai) ENDIF() @@ -1936,7 +1877,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(bench-models PRIVATE bench-utils benchmark::benchmark - fp16 models XNNPACK) @@ -1971,7 +1911,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(${BENCH}-bench PRIVATE bench-utils benchmark::benchmark - fp16 XNNPACK ) ENDFOREACH() @@ -2068,7 +2007,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(${BENCH}-bench PRIVATE bench-utils benchmark::benchmark - fp16 gemm-benchmark hardware-config im2col diff --git a/WORKSPACE b/WORKSPACE index 54d3841939ee..b140e022e055 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -63,19 +63,6 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadGoogleBenchmark.cmake) -# LINT.IfChange -# FP16 library, used for half-precision conversions -http_archive( - name = "FP16", - build_file = "@//third_party:FP16.BUILD", - sha256 = "e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70", - strip_prefix = "FP16-0a92994d729ff76a58f692d3028ca1b64b145d91", - urls = [ - "https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip", - ], -) -# LINT.ThenChange(cmake/DownloadFP16.cmake) - # LINT.IfChange # FXdiv library, used for repeated integer division by the same factor http_archive( diff --git a/bench/BUILD.bazel b/bench/BUILD.bazel index a8758ea95393..fbe9444eadf0 100644 --- a/bench/BUILD.bazel +++ b/bench/BUILD.bazel @@ -27,7 +27,6 @@ load( MICROKERNEL_BENCHMARK_DEPS = [ ":bench_utils", - "@FP16", "//:aligned_allocator", "//:all_microkernels", "//:common", @@ -48,7 +47,6 @@ OPERATOR_BENCHMARK_DEPS = [ "//:cache", "//:common", "//:math", - "@FP16", ] xnnpack_cxx_library( diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc index 2a9b65765294..116acbfdbd29 100644 --- a/bench/gemm-benchmark.cc +++ b/bench/gemm-benchmark.cc @@ -25,7 +25,6 @@ #include #include -#include #include "bench/utils.h" #include diff --git a/build_params.bzl b/build_params.bzl index 6aec15849712..92b4025364d0 100644 --- a/build_params.bzl +++ b/build_params.bzl @@ -274,7 +274,6 @@ XNNPACK_PARAMS_FOR_ARCH = { ], extra_deps = [ "//:config_hdrs", - "@FP16", "@FXdiv", ], ), @@ -523,7 +522,6 @@ XNNPACK_PARAMS_FOR_ARCH = { "-mno-sse4.2", ], extra_deps = [ - "@FP16", ], msvc_x86_32_copts = ["/arch:SSE2"], msvc_x86_64_copts = ["/arch:SSE2"], diff --git a/cmake/DownloadFP16.cmake b/cmake/DownloadFP16.cmake deleted file mode 100644 index a3321d9d40b5..000000000000 --- a/cmake/DownloadFP16.cmake +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# Copyright 2019 Google LLC -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) - -PROJECT(fp16-download NONE) - -# Set file timestamps to the time of extraction. -IF(POLICY CMP0135) - CMAKE_POLICY(SET CMP0135 NEW) -ENDIF() - -INCLUDE(ExternalProject) -ExternalProject_Add(fp16 - URL https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip - URL_HASH SHA256=e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70 - SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" - BINARY_DIR "${CMAKE_BINARY_DIR}/FP16" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - TEST_COMMAND "" -) diff --git a/cmake/gen/avxvnniint8_microkernels.cmake b/cmake/gen/avxvnniint8_microkernels.cmake index 60e0412af5f1..ee2ac244a817 100644 --- a/cmake/gen/avxvnniint8_microkernels.cmake +++ b/cmake/gen/avxvnniint8_microkernels.cmake @@ -15,6 +15,7 @@ SET(PROD_AVXVNNIINT8_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8c8-minmax-fp32-avxvnniint8-prfm.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x8c8-minmax-fp32-avxvnniint8-prfm.c) -SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS) +SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c) SET(ALL_AVXVNNIINT8_MICROKERNEL_SRCS ${PROD_AVXVNNIINT8_MICROKERNEL_SRCS} + ${NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS}) diff --git a/cmake/gen/rvvfp16arith_microkernels.cmake b/cmake/gen/rvvfp16arith_microkernels.cmake index 0f32a4a46ae3..03df4b84f53b 100644 --- a/cmake/gen/rvvfp16arith_microkernels.cmake +++ b/cmake/gen/rvvfp16arith_microkernels.cmake @@ -15,6 +15,10 @@ SET(NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u1v.c src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u2v.c src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u4v.c - src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c) + src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c + src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c) SET(ALL_RVVFP16ARITH_MICROKERNEL_SRCS ${PROD_RVVFP16ARITH_MICROKERNEL_SRCS} + ${NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS}) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index d30bceb42576..8879af0b6083 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -635,7 +635,9 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c diff --git a/gen/avxvnniint8_microkernels.bzl b/gen/avxvnniint8_microkernels.bzl index 97f1657ee4a2..a1b149b017c2 100644 --- a/gen/avxvnniint8_microkernels.bzl +++ b/gen/avxvnniint8_microkernels.bzl @@ -13,6 +13,7 @@ PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ ] NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c", ] ALL_AVXVNNIINT8_MICROKERNEL_SRCS = PROD_AVXVNNIINT8_MICROKERNEL_SRCS + NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS diff --git a/gen/rvvfp16arith_microkernels.bzl b/gen/rvvfp16arith_microkernels.bzl index 139aa8722540..3d0d9afc5597 100644 --- a/gen/rvvfp16arith_microkernels.bzl +++ b/gen/rvvfp16arith_microkernels.bzl @@ -13,6 +13,10 @@ NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS = [ "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u2v.c", "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u4v.c", "src/f16-vclamp/gen/f16-vclamp-rvvfp16arith-u8v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c", + "src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c", ] ALL_RVVFP16ARITH_MICROKERNEL_SRCS = PROD_RVVFP16ARITH_MICROKERNEL_SRCS + NON_PROD_RVVFP16ARITH_MICROKERNEL_SRCS diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index 7ef2fc7f766c..21e3b9f9b0cc 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -632,7 +632,9 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c", "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c", "src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c", diff --git a/scripts/generate-f32-f16-vcvt.sh b/scripts/generate-f32-f16-vcvt.sh index e34f227224f6..d7dae496c538 100755 --- a/scripts/generate-f32-f16-vcvt.sh +++ b/scripts/generate-f32-f16-vcvt.sh @@ -13,6 +13,12 @@ tools/xngen src/f32-f16-vcvt/neon.c.in -D BATCH_TILE=32 -o src/f32-f16-vcvt/gen/ tools/xngen src/f32-f16-vcvt/neonfp16.c.in -D BATCH_TILE=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-neonfp16-u8.c & tools/xngen src/f32-f16-vcvt/neonfp16.c.in -D BATCH_TILE=16 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-neonfp16-u16.c & +################################ RISC-V Vector ################################ +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=1 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=2 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=4 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c & +tools/xngen src/f32-f16-vcvt/rvvfp16arith.c.in -D LMUL=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c & + ################################# x86 128-bit ################################# tools/xngen src/f32-f16-vcvt/sse.c.in -D SSE=2 -D AVX=0 -D BATCH_TILE=8 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-sse2-u8.c & tools/xngen src/f32-f16-vcvt/sse.c.in -D SSE=2 -D AVX=0 -D BATCH_TILE=16 -o src/f32-f16-vcvt/gen/f32-f16-vcvt-sse2-u16.c & diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 281bba59237a..912f9400796b 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -23,4 +23,11 @@ tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=4 -D TYPE=int8_t -o src/q tools/xngen src/x8-packw/kr-scalar.c.in -D NR=32 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c & tools/xngen src/x8-packw/kr-scalar.c.in -D NR=64 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c & + +### C8 packing +tools/xngen src/x8-packw/kr-scalar.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c & +tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c & + +tools/xngen src/x8-packw/kr-avxvnniint8.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c & + wait diff --git a/scripts/genxnn b/scripts/genxnn new file mode 100755 index 000000000000..12a5888c8c9c --- /dev/null +++ b/scripts/genxnn @@ -0,0 +1,73 @@ +#!/bin/bash + +function my_generate1() { + files_changed=$(scripts/check_files_changed.py $1) + if [[ $files_changed ]] + then + if ${VERBOSE}; then + echo $1 + fi + if ${DEBUG}; then + echo $files_changed + fi + touch $files_changed + bash -c $1 & + fi +} + +function my_generatef() { + if ${VERBOSE}; then + echo $1 + fi + bash -c $1 & +} + +export -f my_generate1 +export -f my_generatef + +export FORCE='false' +export VERBOSE='false' +export DEBUG='false' + +while getopts ':vfd' 'OPTKEY'; do + case ${OPTKEY} in + 'v') + export VERBOSE='true' + ;; + 'd') + export DEBUG='true' + ;; + 'f') + export FORCE='true' + ;; + '?') + echo "INVALID OPTION -- ${OPTARG}" >&2 + exit 1 + ;; + ':') + echo "MISSING ARGUMENT for option -- ${OPTARG}" >&2 + exit 1 + ;; + *) + echo "UNIMPLEMENTED OPTION -- ${OPTKEY}" >&2 + exit 1 + ;; + esac +done + +# [optional] Remove all options processed by getopts. +shift $(( OPTIND - 1 )) +[[ "${1}" == "--" ]] && shift + +#pushd $(g4 g4d) +if ${FORCE}; then + find scripts/ -name 'generate-*.sh' -exec bash -c 'my_generatef {}' \; +else + find scripts/ -name 'generate-*.sh' -exec bash -c 'my_generate1 {}' \; +fi +wait +if ${VERBOSE}; then + echo ./tools/update-microkernels.py +fi +./tools/update-microkernels.py +#popd diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index b464b1af0ba0..26a7d8250d60 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -1999,7 +1999,7 @@ static void init_qs8_lrelu_config(void) { } static void init_qs8_to_f16_cvt_config(void) { - #if XNN_ARCH_ARM || XNN_ARCH_ARM64 + #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); if (hardware_config->use_arm_neon_fp16_arith) { diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c index 20469b41d53c..fb9e4dc17b68 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u1( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c index 017b78406bfc..e3561a3ccb31 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u2.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u2( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c index c88ab86b633d..f088ddf56847 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u3.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u3( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c index c9bd1b86128c..c2cb7a3594ec 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u4.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_fmagic_u4( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c index 7d9bda6d2020..c995a3df11a7 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u1.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u1( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c index 82074e6fcafe..00620ce80b3f 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u2.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u2( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c index 3ccb4aaff006..c4f849e687e8 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u3.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u3( size_t batch, diff --git a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c index 979725bdce94..86e42c4ca578 100644 --- a/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c +++ b/src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-imagic-u4.c @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -#include void xnn_f16_qs8_vcvt_ukernel__scalar_imagic_u4( size_t batch, diff --git a/src/f32-f16-vcvt/f32-f16-vcvt.h b/src/f32-f16-vcvt/f32-f16-vcvt.h index dfd6fb266f54..fbcaa08526d8 100644 --- a/src/f32-f16-vcvt/f32-f16-vcvt.h +++ b/src/f32-f16-vcvt/f32-f16-vcvt.h @@ -58,6 +58,13 @@ XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__wasmrelaxedsimd_u24, 24 XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__wasmrelaxedsimd_u32, 32, false, float, xnn_float16, void, NULL) #endif // XNN_ARCH_WASMRELAXEDSIMD +#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_FP16_VECTOR +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u1v, 1, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u2v, 2, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u4v, 4, true, float, xnn_float16, void, NULL) +XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__rvv_u8v, 8, true, float, xnn_float16, void, NULL) +#endif + XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u1, 1, false, float, xnn_float16, void, NULL) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u2, 2, false, float, xnn_float16, void, NULL) XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_f32_f16_vcvt_ukernel__scalar_bitcast_u3, 3, false, float, xnn_float16, void, NULL) diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c new file mode 100644 index 000000000000..e84b6d04699d --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u1v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u1v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m1(batch); batch -= n; + + vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n; + + vfloat16mf2_t y_f16v = __riscv_vfncvt_f_f_w_f16mf2(x_f32v, n); + + __riscv_vse16_v_f16mf2(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c new file mode 100644 index 000000000000..69dd22f841e1 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u2v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u2v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m2(batch); batch -= n; + + vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n; + + vfloat16m1_t y_f16v = __riscv_vfncvt_f_f_w_f16m1(x_f32v, n); + + __riscv_vse16_v_f16m1(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c new file mode 100644 index 000000000000..116bb4d35bb7 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u4v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u4v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m4(batch); batch -= n; + + vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n; + + vfloat16m2_t y_f16v = __riscv_vfncvt_f_f_w_f16m2(x_f32v, n); + + __riscv_vse16_v_f16m2(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c new file mode 100644 index 000000000000..dc0cadfb2397 --- /dev/null +++ b/src/f32-f16-vcvt/gen/f32-f16-vcvt-rvvfp16arith-u8v.c @@ -0,0 +1,40 @@ +// Auto-generated file. Do not edit! +// Template: src/f32-f16-vcvt/rvvfp16arith.c.in +// Generator: tools/xngen +// +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u8v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m8(batch); batch -= n; + + vfloat32m8_t x_f32v = __riscv_vle32_v_f32m8(input, n); input += n; + + vfloat16m4_t y_f16v = __riscv_vfncvt_f_f_w_f16m4(x_f32v, n); + + __riscv_vse16_v_f16m4(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-f16-vcvt/rvvfp16arith.c.in b/src/f32-f16-vcvt/rvvfp16arith.c.in new file mode 100644 index 000000000000..da1360472683 --- /dev/null +++ b/src/f32-f16-vcvt/rvvfp16arith.c.in @@ -0,0 +1,38 @@ +// Copyright 2024 Imagination Technologies, Inc. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert LMUL in [1, 2, 4, 8] +$LMUL_16 = {1: "f2", 2: "1", 4: "2", 8: "4"}[LMUL] +#include + +#include + +#include "xnnpack/vcvt.h" + + +void xnn_f32_f16_vcvt_ukernel__rvvfp16arith_u${LMUL}v( + size_t batch, + const float* input, + void* output, + const void* params) +{ + assert(batch != 0); + assert(batch % sizeof(float) == 0); + assert(input != NULL); + assert(output != NULL); + + batch >>= XNN_LOG2_SIZEOF_FLOAT; + + _Float16* o = (_Float16*) output; + for (; batch > 0;) { + const int32_t n = __riscv_vsetvl_e32m${LMUL}(batch); batch -= n; + + vfloat32m${LMUL}_t x_f32v = __riscv_vle32_v_f32m${LMUL}(input, n); input += n; + + vfloat16m${LMUL_16}_t y_f16v = __riscv_vfncvt_f_f_w_f16m${LMUL_16}(x_f32v, n); + + __riscv_vse16_v_f16m${LMUL_16}(o, y_f16v, n); o += n; + } +} diff --git a/src/f32-qs8-vcvt/scalar-fmagic.c.in b/src/f32-qs8-vcvt/scalar-fmagic.c.in index 7c3bfdfa0ba0..94218c1745a7 100644 --- a/src/f32-qs8-vcvt/scalar-fmagic.c.in +++ b/src/f32-qs8-vcvt/scalar-fmagic.c.in @@ -12,8 +12,6 @@ $assert IDATATYPE == "F16" and ODATATYPE == "QS8" or IDATATYPE == "F32" #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -$if IDATATYPE == "F16": - #include $INPUT_T = {"F16": "xnn_float16", "F32": "float"}[IDATATYPE] $XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[ODATATYPE] diff --git a/src/f32-qs8-vcvt/scalar-imagic.c.in b/src/f32-qs8-vcvt/scalar-imagic.c.in index ae1e474b3ec3..2d83fb24ed04 100644 --- a/src/f32-qs8-vcvt/scalar-imagic.c.in +++ b/src/f32-qs8-vcvt/scalar-imagic.c.in @@ -12,8 +12,6 @@ $assert IDATATYPE == "F16" and ODATATYPE == "QS8" or IDATATYPE == "F32" #include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/vcvt.h" -$if IDATATYPE == "F16": - #include $INPUT_T = {"F16": "xnn_float16", "F32": "float"}[IDATATYPE] $XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[ODATATYPE] diff --git a/src/f32-rminmax/scalar.c.in b/src/f32-rminmax/scalar.c.in index e26a6584b6ca..1fbf58ff9517 100644 --- a/src/f32-rminmax/scalar.c.in +++ b/src/f32-rminmax/scalar.c.in @@ -12,7 +12,7 @@ $if not WASM: #include "xnnpack/math.h" #include "xnnpack/reduce.h" $if DATATYPE == "F16": - #include + #include "xnnpack/fp16.h" $ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS $MAX = "__builtin_wasm_max_f32" if WASM else "math_max_f32" @@ -30,7 +30,7 @@ void xnn_${DATATYPE.lower()}_r${OP.lower()}_ukernel__${ISA}_u${BATCH_TILE}${ACC_ $elif DATATYPE == "F16": const void* input, void* output, - const union xnn_${DATATYPE.lower()}_default_params params[restrict XNN_MIN_ELEMENTS(1)]) + const struct xnn_${DATATYPE.lower()}_default_params params[restrict XNN_MIN_ELEMENTS(1)]) { assert(batch != 0); assert(batch % sizeof(${ITYPE}) == 0); diff --git a/src/operators/average-pooling-nhwc.c b/src/operators/average-pooling-nhwc.c index 754da0202024..7a217c37d856 100644 --- a/src/operators/average-pooling-nhwc.c +++ b/src/operators/average-pooling-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/convolution-nchw.c b/src/operators/convolution-nchw.c index fc4fe8c9d5bc..6be0bf5a2b1e 100644 --- a/src/operators/convolution-nchw.c +++ b/src/operators/convolution-nchw.c @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index 0125b6cb3cb6..e0eb6265ccb6 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/deconvolution-nhwc.c b/src/operators/deconvolution-nhwc.c index 60a44bc158bb..ddfe96507232 100644 --- a/src/operators/deconvolution-nhwc.c +++ b/src/operators/deconvolution-nhwc.c @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/cache.h" diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c index ca5d69fdc6f5..c3eb10fab038 100644 --- a/src/operators/dynamic-fully-connected-nc.c +++ b/src/operators/dynamic-fully-connected-nc.c @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/global-average-pooling-ncw.c b/src/operators/global-average-pooling-ncw.c index ef5ce1a7749a..6e7c37002005 100644 --- a/src/operators/global-average-pooling-ncw.c +++ b/src/operators/global-average-pooling-ncw.c @@ -10,13 +10,13 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microparams-init.h" diff --git a/src/operators/global-average-pooling-nwc.c b/src/operators/global-average-pooling-nwc.c index fa50945dc5b7..e04b7943258a 100644 --- a/src/operators/global-average-pooling-nwc.c +++ b/src/operators/global-average-pooling-nwc.c @@ -14,13 +14,13 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microparams.h" diff --git a/src/operators/max-pooling-nhwc.c b/src/operators/max-pooling-nhwc.c index 43ad86552cdc..e8dd7413d371 100644 --- a/src/operators/max-pooling-nhwc.c +++ b/src/operators/max-pooling-nhwc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/scaled-dot-product-attention-nhtc.c b/src/operators/scaled-dot-product-attention-nhtc.c index 1de22ee1d87b..a278c7c044f5 100644 --- a/src/operators/scaled-dot-product-attention-nhtc.c +++ b/src/operators/scaled-dot-product-attention-nhtc.c @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/softmax-nc.c b/src/operators/softmax-nc.c index 0d4db7f1443b..75e57edc9c99 100644 --- a/src/operators/softmax-nc.c +++ b/src/operators/softmax-nc.c @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c index 0cd6baa99553..2270934695aa 100644 --- a/src/operators/unary-elementwise-nc.c +++ b/src/operators/unary-elementwise-nc.c @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" diff --git a/src/packing.cc b/src/packing.cc index 9c742ad6e8c9..c37b06b59029 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -28,7 +28,6 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #endif // XNN_ENABLE_KLEIDIAI -#include extern "C" { diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c index 80836d13ca57..5be5406c57f9 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c @@ -12,7 +12,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar( @@ -98,11 +97,11 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar( w = (const float*) w + 2; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c index b84b84c0f1a1..215aa89d3971 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar( @@ -127,13 +126,13 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout0x2 = math_max_f32(vout0x2, voutput_min); vout0x3 = math_max_f32(vout0x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); vout0x2 = math_min_f32(vout0x2, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c index 73bf59de5a0f..d42717834c33 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( @@ -187,7 +186,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( w = (const float*) w + 8; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout0x2 = math_max_f32(vout0x2, voutput_min); @@ -197,7 +196,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar( vout0x6 = math_max_f32(vout0x6, voutput_min); vout0x7 = math_max_f32(vout0x7, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); vout0x2 = math_min_f32(vout0x2, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c index 660ea399c00e..60f320477311 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x2-minmax-scalar.c @@ -12,7 +12,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar( @@ -127,13 +126,13 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar( w = (const float*) w + 2; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); vout1x1 = math_max_f32(vout1x1, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c index 580148c13b16..a75b95cc64f1 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( @@ -174,7 +173,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); @@ -184,7 +183,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar( vout0x3 = math_max_f32(vout0x3, voutput_min); vout1x3 = math_max_f32(vout1x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c index c5edd90fbb5f..523d285b7e22 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( @@ -270,7 +269,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( w = (const float*) w + 8; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout0x1 = math_max_f32(vout0x1, voutput_min); @@ -288,7 +287,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar( vout0x7 = math_max_f32(vout0x7, voutput_min); vout1x7 = math_max_f32(vout1x7, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout0x1 = math_min_f32(vout0x1, voutput_max); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c index 33b41fa79e56..e8855d10d4a7 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x4-minmax-scalar.c @@ -11,7 +11,6 @@ #include "xnnpack/gemm.h" #include "xnnpack/math.h" -#include void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( @@ -268,7 +267,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( w = (const float*) w + 4; - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); vout0x0 = math_max_f32(vout0x0, voutput_min); vout1x0 = math_max_f32(vout1x0, voutput_min); vout2x0 = math_max_f32(vout2x0, voutput_min); @@ -286,7 +285,7 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar( vout2x3 = math_max_f32(vout2x3, voutput_min); vout3x3 = math_max_f32(vout3x3, voutput_min); - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); vout0x0 = math_min_f32(vout0x0, voutput_max); vout1x0 = math_min_f32(vout1x0, voutput_max); vout2x0 = math_min_f32(vout2x0, voutput_max); diff --git a/src/qs8-gemm/scalar.c.in b/src/qs8-gemm/scalar.c.in index fbe2c29b299e..97285dd32cb8 100644 --- a/src/qs8-gemm/scalar.c.in +++ b/src/qs8-gemm/scalar.c.in @@ -14,8 +14,6 @@ $if VARIANT == "LRINTF": #include "xnnpack/math.h" $if NR % 4 != 0: #include "xnnpack/unaligned.h" -$if DATATYPE in ["QB4_F16"]: - #include $# $INDENT = 0 @@ -279,7 +277,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}_ w = (const float*) w + ${NR * 2}; $if DATATYPE in ["QB4_F16"]: - const float voutput_min = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.min); + const float voutput_min = xnn_float16_to_float(params->scalar.min); $else: const float voutput_min = params->scalar.min; $for N in range(NR): @@ -287,7 +285,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}_ vout${M}x${N} = ${MAX_F32}(vout${M}x${N}, voutput_min); $if DATATYPE in ["QB4_F16"]: - const float voutput_max = fp16_ieee_to_fp32_value(*(const uint16_t*) ¶ms->scalar.max); + const float voutput_max = xnn_float16_to_float(params->scalar.max); $else: const float voutput_max = params->scalar.max; $for N in range(NR): diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c new file mode 100644 index 000000000000..0b7b42c16b36 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c @@ -0,0 +1,2319 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = params ? (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + ((int32_t*) out)[0] = b[0]; + ((int32_t*) out)[1] = b[1]; + ((int32_t*) out)[2] = b[2]; + ((int32_t*) out)[3] = b[3]; + ((int32_t*) out)[4] = b[4]; + ((int32_t*) out)[5] = b[5]; + ((int32_t*) out)[6] = b[6]; + ((int32_t*) out)[7] = b[7]; + ((int32_t*) out)[8] = b[8]; + ((int32_t*) out)[9] = b[9]; + ((int32_t*) out)[10] = b[10]; + ((int32_t*) out)[11] = b[11]; + ((int32_t*) out)[12] = b[12]; + ((int32_t*) out)[13] = b[13]; + ((int32_t*) out)[14] = b[14]; + ((int32_t*) out)[15] = b[15]; + b += 16; + } else { + ((int32_t*) out)[0] = 0; + ((int32_t*) out)[1] = 0; + ((int32_t*) out)[2] = 0; + ((int32_t*) out)[3] = 0; + ((int32_t*) out)[4] = 0; + ((int32_t*) out)[5] = 0; + ((int32_t*) out)[6] = 0; + ((int32_t*) out)[7] = 0; + ((int32_t*) out)[8] = 0; + ((int32_t*) out)[9] = 0; + ((int32_t*) out)[10] = 0; + ((int32_t*) out)[11] = 0; + ((int32_t*) out)[12] = 0; + ((int32_t*) out)[13] = 0; + ((int32_t*) out)[14] = 0; + ((int32_t*) out)[15] = 0; + } + out += 16 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + uint32_t ksum8 = 0; + uint32_t ksum9 = 0; + uint32_t ksum10 = 0; + uint32_t ksum11 = 0; + uint32_t ksum12 = 0; + uint32_t ksum13 = 0; + uint32_t ksum14 = 0; + uint32_t ksum15 = 0; + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + const int8_t v8x0 = w8[0]; + const int8_t v8x1 = w8[1]; + const int8_t v8x2 = w8[2]; + const int8_t v8x3 = w8[3]; + const int8_t v8x4 = w8[4]; + const int8_t v8x5 = w8[5]; + const int8_t v8x6 = w8[6]; + const int8_t v8x7 = w8[7]; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + out[64] = v8x0; + out[65] = v8x1; + out[66] = v8x2; + out[67] = v8x3; + out[68] = v8x4; + out[69] = v8x5; + out[70] = v8x6; + out[71] = v8x7; + w8 += 8; + const int8_t v9x0 = w9[0]; + const int8_t v9x1 = w9[1]; + const int8_t v9x2 = w9[2]; + const int8_t v9x3 = w9[3]; + const int8_t v9x4 = w9[4]; + const int8_t v9x5 = w9[5]; + const int8_t v9x6 = w9[6]; + const int8_t v9x7 = w9[7]; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + out[72] = v9x0; + out[73] = v9x1; + out[74] = v9x2; + out[75] = v9x3; + out[76] = v9x4; + out[77] = v9x5; + out[78] = v9x6; + out[79] = v9x7; + w9 += 8; + const int8_t v10x0 = w10[0]; + const int8_t v10x1 = w10[1]; + const int8_t v10x2 = w10[2]; + const int8_t v10x3 = w10[3]; + const int8_t v10x4 = w10[4]; + const int8_t v10x5 = w10[5]; + const int8_t v10x6 = w10[6]; + const int8_t v10x7 = w10[7]; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + out[80] = v10x0; + out[81] = v10x1; + out[82] = v10x2; + out[83] = v10x3; + out[84] = v10x4; + out[85] = v10x5; + out[86] = v10x6; + out[87] = v10x7; + w10 += 8; + const int8_t v11x0 = w11[0]; + const int8_t v11x1 = w11[1]; + const int8_t v11x2 = w11[2]; + const int8_t v11x3 = w11[3]; + const int8_t v11x4 = w11[4]; + const int8_t v11x5 = w11[5]; + const int8_t v11x6 = w11[6]; + const int8_t v11x7 = w11[7]; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + out[88] = v11x0; + out[89] = v11x1; + out[90] = v11x2; + out[91] = v11x3; + out[92] = v11x4; + out[93] = v11x5; + out[94] = v11x6; + out[95] = v11x7; + w11 += 8; + const int8_t v12x0 = w12[0]; + const int8_t v12x1 = w12[1]; + const int8_t v12x2 = w12[2]; + const int8_t v12x3 = w12[3]; + const int8_t v12x4 = w12[4]; + const int8_t v12x5 = w12[5]; + const int8_t v12x6 = w12[6]; + const int8_t v12x7 = w12[7]; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + out[96] = v12x0; + out[97] = v12x1; + out[98] = v12x2; + out[99] = v12x3; + out[100] = v12x4; + out[101] = v12x5; + out[102] = v12x6; + out[103] = v12x7; + w12 += 8; + const int8_t v13x0 = w13[0]; + const int8_t v13x1 = w13[1]; + const int8_t v13x2 = w13[2]; + const int8_t v13x3 = w13[3]; + const int8_t v13x4 = w13[4]; + const int8_t v13x5 = w13[5]; + const int8_t v13x6 = w13[6]; + const int8_t v13x7 = w13[7]; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + out[104] = v13x0; + out[105] = v13x1; + out[106] = v13x2; + out[107] = v13x3; + out[108] = v13x4; + out[109] = v13x5; + out[110] = v13x6; + out[111] = v13x7; + w13 += 8; + const int8_t v14x0 = w14[0]; + const int8_t v14x1 = w14[1]; + const int8_t v14x2 = w14[2]; + const int8_t v14x3 = w14[3]; + const int8_t v14x4 = w14[4]; + const int8_t v14x5 = w14[5]; + const int8_t v14x6 = w14[6]; + const int8_t v14x7 = w14[7]; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + out[112] = v14x0; + out[113] = v14x1; + out[114] = v14x2; + out[115] = v14x3; + out[116] = v14x4; + out[117] = v14x5; + out[118] = v14x6; + out[119] = v14x7; + w14 += 8; + const int8_t v15x0 = w15[0]; + const int8_t v15x1 = w15[1]; + const int8_t v15x2 = w15[2]; + const int8_t v15x3 = w15[3]; + const int8_t v15x4 = w15[4]; + const int8_t v15x5 = w15[5]; + const int8_t v15x6 = w15[6]; + const int8_t v15x7 = w15[7]; + ksum15 += (uint32_t) v15x0; + ksum15 += (uint32_t) v15x1; + ksum15 += (uint32_t) v15x2; + ksum15 += (uint32_t) v15x3; + ksum15 += (uint32_t) v15x4; + ksum15 += (uint32_t) v15x5; + ksum15 += (uint32_t) v15x6; + ksum15 += (uint32_t) v15x7; + out[120] = v15x0; + out[121] = v15x1; + out[122] = v15x2; + out[123] = v15x3; + out[124] = v15x4; + out[125] = v15x5; + out[126] = v15x6; + out[127] = v15x7; + w15 += 8; + out += 128; + } + + // KC remainder 1..KR-1 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + const int8_t v8x0 = 0 < k ? w8[0] : izp; + const int8_t v8x1 = 1 < k ? w8[1] : izp; + const int8_t v8x2 = 2 < k ? w8[2] : izp; + const int8_t v8x3 = 3 < k ? w8[3] : izp; + const int8_t v8x4 = 4 < k ? w8[4] : izp; + const int8_t v8x5 = 5 < k ? w8[5] : izp; + const int8_t v8x6 = 6 < k ? w8[6] : izp; + const int8_t v8x7 = 7 < k ? w8[7] : izp; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + if (0 < k) { + out[64] = v8x0; + } + if (1 < k) { + out[65] = v8x1; + } + if (2 < k) { + out[66] = v8x2; + } + if (3 < k) { + out[67] = v8x3; + } + if (4 < k) { + out[68] = v8x4; + } + if (5 < k) { + out[69] = v8x5; + } + if (6 < k) { + out[70] = v8x6; + } + if (7 < k) { + out[71] = v8x7; + } + w8 += 8; + const int8_t v9x0 = 0 < k ? w9[0] : izp; + const int8_t v9x1 = 1 < k ? w9[1] : izp; + const int8_t v9x2 = 2 < k ? w9[2] : izp; + const int8_t v9x3 = 3 < k ? w9[3] : izp; + const int8_t v9x4 = 4 < k ? w9[4] : izp; + const int8_t v9x5 = 5 < k ? w9[5] : izp; + const int8_t v9x6 = 6 < k ? w9[6] : izp; + const int8_t v9x7 = 7 < k ? w9[7] : izp; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + if (0 < k) { + out[72] = v9x0; + } + if (1 < k) { + out[73] = v9x1; + } + if (2 < k) { + out[74] = v9x2; + } + if (3 < k) { + out[75] = v9x3; + } + if (4 < k) { + out[76] = v9x4; + } + if (5 < k) { + out[77] = v9x5; + } + if (6 < k) { + out[78] = v9x6; + } + if (7 < k) { + out[79] = v9x7; + } + w9 += 8; + const int8_t v10x0 = 0 < k ? w10[0] : izp; + const int8_t v10x1 = 1 < k ? w10[1] : izp; + const int8_t v10x2 = 2 < k ? w10[2] : izp; + const int8_t v10x3 = 3 < k ? w10[3] : izp; + const int8_t v10x4 = 4 < k ? w10[4] : izp; + const int8_t v10x5 = 5 < k ? w10[5] : izp; + const int8_t v10x6 = 6 < k ? w10[6] : izp; + const int8_t v10x7 = 7 < k ? w10[7] : izp; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + if (0 < k) { + out[80] = v10x0; + } + if (1 < k) { + out[81] = v10x1; + } + if (2 < k) { + out[82] = v10x2; + } + if (3 < k) { + out[83] = v10x3; + } + if (4 < k) { + out[84] = v10x4; + } + if (5 < k) { + out[85] = v10x5; + } + if (6 < k) { + out[86] = v10x6; + } + if (7 < k) { + out[87] = v10x7; + } + w10 += 8; + const int8_t v11x0 = 0 < k ? w11[0] : izp; + const int8_t v11x1 = 1 < k ? w11[1] : izp; + const int8_t v11x2 = 2 < k ? w11[2] : izp; + const int8_t v11x3 = 3 < k ? w11[3] : izp; + const int8_t v11x4 = 4 < k ? w11[4] : izp; + const int8_t v11x5 = 5 < k ? w11[5] : izp; + const int8_t v11x6 = 6 < k ? w11[6] : izp; + const int8_t v11x7 = 7 < k ? w11[7] : izp; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + if (0 < k) { + out[88] = v11x0; + } + if (1 < k) { + out[89] = v11x1; + } + if (2 < k) { + out[90] = v11x2; + } + if (3 < k) { + out[91] = v11x3; + } + if (4 < k) { + out[92] = v11x4; + } + if (5 < k) { + out[93] = v11x5; + } + if (6 < k) { + out[94] = v11x6; + } + if (7 < k) { + out[95] = v11x7; + } + w11 += 8; + const int8_t v12x0 = 0 < k ? w12[0] : izp; + const int8_t v12x1 = 1 < k ? w12[1] : izp; + const int8_t v12x2 = 2 < k ? w12[2] : izp; + const int8_t v12x3 = 3 < k ? w12[3] : izp; + const int8_t v12x4 = 4 < k ? w12[4] : izp; + const int8_t v12x5 = 5 < k ? w12[5] : izp; + const int8_t v12x6 = 6 < k ? w12[6] : izp; + const int8_t v12x7 = 7 < k ? w12[7] : izp; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + if (0 < k) { + out[96] = v12x0; + } + if (1 < k) { + out[97] = v12x1; + } + if (2 < k) { + out[98] = v12x2; + } + if (3 < k) { + out[99] = v12x3; + } + if (4 < k) { + out[100] = v12x4; + } + if (5 < k) { + out[101] = v12x5; + } + if (6 < k) { + out[102] = v12x6; + } + if (7 < k) { + out[103] = v12x7; + } + w12 += 8; + const int8_t v13x0 = 0 < k ? w13[0] : izp; + const int8_t v13x1 = 1 < k ? w13[1] : izp; + const int8_t v13x2 = 2 < k ? w13[2] : izp; + const int8_t v13x3 = 3 < k ? w13[3] : izp; + const int8_t v13x4 = 4 < k ? w13[4] : izp; + const int8_t v13x5 = 5 < k ? w13[5] : izp; + const int8_t v13x6 = 6 < k ? w13[6] : izp; + const int8_t v13x7 = 7 < k ? w13[7] : izp; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + if (0 < k) { + out[104] = v13x0; + } + if (1 < k) { + out[105] = v13x1; + } + if (2 < k) { + out[106] = v13x2; + } + if (3 < k) { + out[107] = v13x3; + } + if (4 < k) { + out[108] = v13x4; + } + if (5 < k) { + out[109] = v13x5; + } + if (6 < k) { + out[110] = v13x6; + } + if (7 < k) { + out[111] = v13x7; + } + w13 += 8; + const int8_t v14x0 = 0 < k ? w14[0] : izp; + const int8_t v14x1 = 1 < k ? w14[1] : izp; + const int8_t v14x2 = 2 < k ? w14[2] : izp; + const int8_t v14x3 = 3 < k ? w14[3] : izp; + const int8_t v14x4 = 4 < k ? w14[4] : izp; + const int8_t v14x5 = 5 < k ? w14[5] : izp; + const int8_t v14x6 = 6 < k ? w14[6] : izp; + const int8_t v14x7 = 7 < k ? w14[7] : izp; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + if (0 < k) { + out[112] = v14x0; + } + if (1 < k) { + out[113] = v14x1; + } + if (2 < k) { + out[114] = v14x2; + } + if (3 < k) { + out[115] = v14x3; + } + if (4 < k) { + out[116] = v14x4; + } + if (5 < k) { + out[117] = v14x5; + } + if (6 < k) { + out[118] = v14x6; + } + if (7 < k) { + out[119] = v14x7; + } + w14 += 8; + const int8_t v15x0 = 0 < k ? w15[0] : izp; + const int8_t v15x1 = 1 < k ? w15[1] : izp; + const int8_t v15x2 = 2 < k ? w15[2] : izp; + const int8_t v15x3 = 3 < k ? w15[3] : izp; + const int8_t v15x4 = 4 < k ? w15[4] : izp; + const int8_t v15x5 = 5 < k ? w15[5] : izp; + const int8_t v15x6 = 6 < k ? w15[6] : izp; + const int8_t v15x7 = 7 < k ? w15[7] : izp; + ksum15 += (uint32_t) v15x0; + ksum15 += (uint32_t) v15x1; + ksum15 += (uint32_t) v15x2; + ksum15 += (uint32_t) v15x3; + ksum15 += (uint32_t) v15x4; + ksum15 += (uint32_t) v15x5; + ksum15 += (uint32_t) v15x6; + ksum15 += (uint32_t) v15x7; + if (0 < k) { + out[120] = v15x0; + } + if (1 < k) { + out[121] = v15x1; + } + if (2 < k) { + out[122] = v15x2; + } + if (3 < k) { + out[123] = v15x3; + } + if (4 < k) { + out[124] = v15x4; + } + if (5 < k) { + out[125] = v15x5; + } + if (6 < k) { + out[126] = v15x6; + } + if (7 < k) { + out[127] = v15x7; + } + w15 += 8; + out += 128; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + packed_b[8] -= ksum8 * izp; + packed_b[9] -= ksum9 * izp; + packed_b[10] -= ksum10 * izp; + packed_b[11] -= ksum11 * izp; + packed_b[12] -= ksum12 * izp; + packed_b[13] -= ksum13 * izp; + packed_b[14] -= ksum14 * izp; + packed_b[15] -= ksum15 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(int32_t); + + // NR remainder has less than 16 rows so last row is not loaded + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + uint32_t ksum8 = 0; + uint32_t ksum9 = 0; + uint32_t ksum10 = 0; + uint32_t ksum11 = 0; + uint32_t ksum12 = 0; + uint32_t ksum13 = 0; + uint32_t ksum14 = 0; + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + const int8_t v8x0 = w8[0]; + const int8_t v8x1 = w8[1]; + const int8_t v8x2 = w8[2]; + const int8_t v8x3 = w8[3]; + const int8_t v8x4 = w8[4]; + const int8_t v8x5 = w8[5]; + const int8_t v8x6 = w8[6]; + const int8_t v8x7 = w8[7]; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + out[64] = v8x0; + out[65] = v8x1; + out[66] = v8x2; + out[67] = v8x3; + out[68] = v8x4; + out[69] = v8x5; + out[70] = v8x6; + out[71] = v8x7; + w8 += 8; + const int8_t v9x0 = w9[0]; + const int8_t v9x1 = w9[1]; + const int8_t v9x2 = w9[2]; + const int8_t v9x3 = w9[3]; + const int8_t v9x4 = w9[4]; + const int8_t v9x5 = w9[5]; + const int8_t v9x6 = w9[6]; + const int8_t v9x7 = w9[7]; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + out[72] = v9x0; + out[73] = v9x1; + out[74] = v9x2; + out[75] = v9x3; + out[76] = v9x4; + out[77] = v9x5; + out[78] = v9x6; + out[79] = v9x7; + w9 += 8; + const int8_t v10x0 = w10[0]; + const int8_t v10x1 = w10[1]; + const int8_t v10x2 = w10[2]; + const int8_t v10x3 = w10[3]; + const int8_t v10x4 = w10[4]; + const int8_t v10x5 = w10[5]; + const int8_t v10x6 = w10[6]; + const int8_t v10x7 = w10[7]; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + out[80] = v10x0; + out[81] = v10x1; + out[82] = v10x2; + out[83] = v10x3; + out[84] = v10x4; + out[85] = v10x5; + out[86] = v10x6; + out[87] = v10x7; + w10 += 8; + const int8_t v11x0 = w11[0]; + const int8_t v11x1 = w11[1]; + const int8_t v11x2 = w11[2]; + const int8_t v11x3 = w11[3]; + const int8_t v11x4 = w11[4]; + const int8_t v11x5 = w11[5]; + const int8_t v11x6 = w11[6]; + const int8_t v11x7 = w11[7]; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + out[88] = v11x0; + out[89] = v11x1; + out[90] = v11x2; + out[91] = v11x3; + out[92] = v11x4; + out[93] = v11x5; + out[94] = v11x6; + out[95] = v11x7; + w11 += 8; + const int8_t v12x0 = w12[0]; + const int8_t v12x1 = w12[1]; + const int8_t v12x2 = w12[2]; + const int8_t v12x3 = w12[3]; + const int8_t v12x4 = w12[4]; + const int8_t v12x5 = w12[5]; + const int8_t v12x6 = w12[6]; + const int8_t v12x7 = w12[7]; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + out[96] = v12x0; + out[97] = v12x1; + out[98] = v12x2; + out[99] = v12x3; + out[100] = v12x4; + out[101] = v12x5; + out[102] = v12x6; + out[103] = v12x7; + w12 += 8; + const int8_t v13x0 = w13[0]; + const int8_t v13x1 = w13[1]; + const int8_t v13x2 = w13[2]; + const int8_t v13x3 = w13[3]; + const int8_t v13x4 = w13[4]; + const int8_t v13x5 = w13[5]; + const int8_t v13x6 = w13[6]; + const int8_t v13x7 = w13[7]; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + out[104] = v13x0; + out[105] = v13x1; + out[106] = v13x2; + out[107] = v13x3; + out[108] = v13x4; + out[109] = v13x5; + out[110] = v13x6; + out[111] = v13x7; + w13 += 8; + const int8_t v14x0 = w14[0]; + const int8_t v14x1 = w14[1]; + const int8_t v14x2 = w14[2]; + const int8_t v14x3 = w14[3]; + const int8_t v14x4 = w14[4]; + const int8_t v14x5 = w14[5]; + const int8_t v14x6 = w14[6]; + const int8_t v14x7 = w14[7]; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + out[112] = v14x0; + out[113] = v14x1; + out[114] = v14x2; + out[115] = v14x3; + out[116] = v14x4; + out[117] = v14x5; + out[118] = v14x6; + out[119] = v14x7; + w14 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + const int8_t v8x0 = 0 < k ? w8[0] : izp; + const int8_t v8x1 = 1 < k ? w8[1] : izp; + const int8_t v8x2 = 2 < k ? w8[2] : izp; + const int8_t v8x3 = 3 < k ? w8[3] : izp; + const int8_t v8x4 = 4 < k ? w8[4] : izp; + const int8_t v8x5 = 5 < k ? w8[5] : izp; + const int8_t v8x6 = 6 < k ? w8[6] : izp; + const int8_t v8x7 = 7 < k ? w8[7] : izp; + ksum8 += (uint32_t) v8x0; + ksum8 += (uint32_t) v8x1; + ksum8 += (uint32_t) v8x2; + ksum8 += (uint32_t) v8x3; + ksum8 += (uint32_t) v8x4; + ksum8 += (uint32_t) v8x5; + ksum8 += (uint32_t) v8x6; + ksum8 += (uint32_t) v8x7; + if (0 < k) { + out[64] = v8x0; + } + if (1 < k) { + out[65] = v8x1; + } + if (2 < k) { + out[66] = v8x2; + } + if (3 < k) { + out[67] = v8x3; + } + if (4 < k) { + out[68] = v8x4; + } + if (5 < k) { + out[69] = v8x5; + } + if (6 < k) { + out[70] = v8x6; + } + if (7 < k) { + out[71] = v8x7; + } + w8 += 8; + const int8_t v9x0 = 0 < k ? w9[0] : izp; + const int8_t v9x1 = 1 < k ? w9[1] : izp; + const int8_t v9x2 = 2 < k ? w9[2] : izp; + const int8_t v9x3 = 3 < k ? w9[3] : izp; + const int8_t v9x4 = 4 < k ? w9[4] : izp; + const int8_t v9x5 = 5 < k ? w9[5] : izp; + const int8_t v9x6 = 6 < k ? w9[6] : izp; + const int8_t v9x7 = 7 < k ? w9[7] : izp; + ksum9 += (uint32_t) v9x0; + ksum9 += (uint32_t) v9x1; + ksum9 += (uint32_t) v9x2; + ksum9 += (uint32_t) v9x3; + ksum9 += (uint32_t) v9x4; + ksum9 += (uint32_t) v9x5; + ksum9 += (uint32_t) v9x6; + ksum9 += (uint32_t) v9x7; + if (0 < k) { + out[72] = v9x0; + } + if (1 < k) { + out[73] = v9x1; + } + if (2 < k) { + out[74] = v9x2; + } + if (3 < k) { + out[75] = v9x3; + } + if (4 < k) { + out[76] = v9x4; + } + if (5 < k) { + out[77] = v9x5; + } + if (6 < k) { + out[78] = v9x6; + } + if (7 < k) { + out[79] = v9x7; + } + w9 += 8; + const int8_t v10x0 = 0 < k ? w10[0] : izp; + const int8_t v10x1 = 1 < k ? w10[1] : izp; + const int8_t v10x2 = 2 < k ? w10[2] : izp; + const int8_t v10x3 = 3 < k ? w10[3] : izp; + const int8_t v10x4 = 4 < k ? w10[4] : izp; + const int8_t v10x5 = 5 < k ? w10[5] : izp; + const int8_t v10x6 = 6 < k ? w10[6] : izp; + const int8_t v10x7 = 7 < k ? w10[7] : izp; + ksum10 += (uint32_t) v10x0; + ksum10 += (uint32_t) v10x1; + ksum10 += (uint32_t) v10x2; + ksum10 += (uint32_t) v10x3; + ksum10 += (uint32_t) v10x4; + ksum10 += (uint32_t) v10x5; + ksum10 += (uint32_t) v10x6; + ksum10 += (uint32_t) v10x7; + if (0 < k) { + out[80] = v10x0; + } + if (1 < k) { + out[81] = v10x1; + } + if (2 < k) { + out[82] = v10x2; + } + if (3 < k) { + out[83] = v10x3; + } + if (4 < k) { + out[84] = v10x4; + } + if (5 < k) { + out[85] = v10x5; + } + if (6 < k) { + out[86] = v10x6; + } + if (7 < k) { + out[87] = v10x7; + } + w10 += 8; + const int8_t v11x0 = 0 < k ? w11[0] : izp; + const int8_t v11x1 = 1 < k ? w11[1] : izp; + const int8_t v11x2 = 2 < k ? w11[2] : izp; + const int8_t v11x3 = 3 < k ? w11[3] : izp; + const int8_t v11x4 = 4 < k ? w11[4] : izp; + const int8_t v11x5 = 5 < k ? w11[5] : izp; + const int8_t v11x6 = 6 < k ? w11[6] : izp; + const int8_t v11x7 = 7 < k ? w11[7] : izp; + ksum11 += (uint32_t) v11x0; + ksum11 += (uint32_t) v11x1; + ksum11 += (uint32_t) v11x2; + ksum11 += (uint32_t) v11x3; + ksum11 += (uint32_t) v11x4; + ksum11 += (uint32_t) v11x5; + ksum11 += (uint32_t) v11x6; + ksum11 += (uint32_t) v11x7; + if (0 < k) { + out[88] = v11x0; + } + if (1 < k) { + out[89] = v11x1; + } + if (2 < k) { + out[90] = v11x2; + } + if (3 < k) { + out[91] = v11x3; + } + if (4 < k) { + out[92] = v11x4; + } + if (5 < k) { + out[93] = v11x5; + } + if (6 < k) { + out[94] = v11x6; + } + if (7 < k) { + out[95] = v11x7; + } + w11 += 8; + const int8_t v12x0 = 0 < k ? w12[0] : izp; + const int8_t v12x1 = 1 < k ? w12[1] : izp; + const int8_t v12x2 = 2 < k ? w12[2] : izp; + const int8_t v12x3 = 3 < k ? w12[3] : izp; + const int8_t v12x4 = 4 < k ? w12[4] : izp; + const int8_t v12x5 = 5 < k ? w12[5] : izp; + const int8_t v12x6 = 6 < k ? w12[6] : izp; + const int8_t v12x7 = 7 < k ? w12[7] : izp; + ksum12 += (uint32_t) v12x0; + ksum12 += (uint32_t) v12x1; + ksum12 += (uint32_t) v12x2; + ksum12 += (uint32_t) v12x3; + ksum12 += (uint32_t) v12x4; + ksum12 += (uint32_t) v12x5; + ksum12 += (uint32_t) v12x6; + ksum12 += (uint32_t) v12x7; + if (0 < k) { + out[96] = v12x0; + } + if (1 < k) { + out[97] = v12x1; + } + if (2 < k) { + out[98] = v12x2; + } + if (3 < k) { + out[99] = v12x3; + } + if (4 < k) { + out[100] = v12x4; + } + if (5 < k) { + out[101] = v12x5; + } + if (6 < k) { + out[102] = v12x6; + } + if (7 < k) { + out[103] = v12x7; + } + w12 += 8; + const int8_t v13x0 = 0 < k ? w13[0] : izp; + const int8_t v13x1 = 1 < k ? w13[1] : izp; + const int8_t v13x2 = 2 < k ? w13[2] : izp; + const int8_t v13x3 = 3 < k ? w13[3] : izp; + const int8_t v13x4 = 4 < k ? w13[4] : izp; + const int8_t v13x5 = 5 < k ? w13[5] : izp; + const int8_t v13x6 = 6 < k ? w13[6] : izp; + const int8_t v13x7 = 7 < k ? w13[7] : izp; + ksum13 += (uint32_t) v13x0; + ksum13 += (uint32_t) v13x1; + ksum13 += (uint32_t) v13x2; + ksum13 += (uint32_t) v13x3; + ksum13 += (uint32_t) v13x4; + ksum13 += (uint32_t) v13x5; + ksum13 += (uint32_t) v13x6; + ksum13 += (uint32_t) v13x7; + if (0 < k) { + out[104] = v13x0; + } + if (1 < k) { + out[105] = v13x1; + } + if (2 < k) { + out[106] = v13x2; + } + if (3 < k) { + out[107] = v13x3; + } + if (4 < k) { + out[108] = v13x4; + } + if (5 < k) { + out[109] = v13x5; + } + if (6 < k) { + out[110] = v13x6; + } + if (7 < k) { + out[111] = v13x7; + } + w13 += 8; + const int8_t v14x0 = 0 < k ? w14[0] : izp; + const int8_t v14x1 = 1 < k ? w14[1] : izp; + const int8_t v14x2 = 2 < k ? w14[2] : izp; + const int8_t v14x3 = 3 < k ? w14[3] : izp; + const int8_t v14x4 = 4 < k ? w14[4] : izp; + const int8_t v14x5 = 5 < k ? w14[5] : izp; + const int8_t v14x6 = 6 < k ? w14[6] : izp; + const int8_t v14x7 = 7 < k ? w14[7] : izp; + ksum14 += (uint32_t) v14x0; + ksum14 += (uint32_t) v14x1; + ksum14 += (uint32_t) v14x2; + ksum14 += (uint32_t) v14x3; + ksum14 += (uint32_t) v14x4; + ksum14 += (uint32_t) v14x5; + ksum14 += (uint32_t) v14x6; + ksum14 += (uint32_t) v14x7; + if (0 < k) { + out[112] = v14x0; + } + if (1 < k) { + out[113] = v14x1; + } + if (2 < k) { + out[114] = v14x2; + } + if (3 < k) { + out[115] = v14x3; + } + if (4 < k) { + out[116] = v14x4; + } + if (5 < k) { + out[117] = v14x5; + } + if (6 < k) { + out[118] = v14x6; + } + if (7 < k) { + out[119] = v14x7; + } + w14 += 8; + out += 128; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + packed_b[8] -= ksum8 * izp; + packed_b[9] -= ksum9 * izp; + packed_b[10] -= ksum10 * izp; + packed_b[11] -= ksum11 * izp; + packed_b[12] -= ksum12 * izp; + packed_b[13] -= ksum13 * izp; + packed_b[14] -= ksum14 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c new file mode 100644 index 000000000000..45c63d2b2689 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c @@ -0,0 +1,444 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnniint8.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if defined(__x86_64__) + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const int8_t izp = params ? ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +#endif // defined(__x86_64__) +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c new file mode 100644 index 000000000000..d954e4de632d --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c @@ -0,0 +1,1175 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-scalar.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const uint32_t izp = params ? (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + ((int32_t*) out)[0] = b[0]; + ((int32_t*) out)[1] = b[1]; + ((int32_t*) out)[2] = b[2]; + ((int32_t*) out)[3] = b[3]; + ((int32_t*) out)[4] = b[4]; + ((int32_t*) out)[5] = b[5]; + ((int32_t*) out)[6] = b[6]; + ((int32_t*) out)[7] = b[7]; + b += 8; + } else { + ((int32_t*) out)[0] = 0; + ((int32_t*) out)[1] = 0; + ((int32_t*) out)[2] = 0; + ((int32_t*) out)[3] = 0; + ((int32_t*) out)[4] = 0; + ((int32_t*) out)[5] = 0; + ((int32_t*) out)[6] = 0; + ((int32_t*) out)[7] = 0; + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + uint32_t ksum7 = 0; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + const int8_t v7x0 = w7[0]; + const int8_t v7x1 = w7[1]; + const int8_t v7x2 = w7[2]; + const int8_t v7x3 = w7[3]; + const int8_t v7x4 = w7[4]; + const int8_t v7x5 = w7[5]; + const int8_t v7x6 = w7[6]; + const int8_t v7x7 = w7[7]; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + out[56] = v7x0; + out[57] = v7x1; + out[58] = v7x2; + out[59] = v7x3; + out[60] = v7x4; + out[61] = v7x5; + out[62] = v7x6; + out[63] = v7x7; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + const int8_t v7x0 = 0 < k ? w7[0] : izp; + const int8_t v7x1 = 1 < k ? w7[1] : izp; + const int8_t v7x2 = 2 < k ? w7[2] : izp; + const int8_t v7x3 = 3 < k ? w7[3] : izp; + const int8_t v7x4 = 4 < k ? w7[4] : izp; + const int8_t v7x5 = 5 < k ? w7[5] : izp; + const int8_t v7x6 = 6 < k ? w7[6] : izp; + const int8_t v7x7 = 7 < k ? w7[7] : izp; + ksum7 += (uint32_t) v7x0; + ksum7 += (uint32_t) v7x1; + ksum7 += (uint32_t) v7x2; + ksum7 += (uint32_t) v7x3; + ksum7 += (uint32_t) v7x4; + ksum7 += (uint32_t) v7x5; + ksum7 += (uint32_t) v7x6; + ksum7 += (uint32_t) v7x7; + if (0 < k) { + out[56] = v7x0; + } + if (1 < k) { + out[57] = v7x1; + } + if (2 < k) { + out[58] = v7x2; + } + if (3 < k) { + out[59] = v7x3; + } + if (4 < k) { + out[60] = v7x4; + } + if (5 < k) { + out[61] = v7x5; + } + if (6 < k) { + out[62] = v7x6; + } + if (7 < k) { + out[63] = v7x7; + } + w7 += 8; + out += 64; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + packed_b[7] -= ksum7 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(int32_t); + + // NR remainder has less than 8 rows so last row is not loaded + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + + uint32_t ksum0 = 0; + uint32_t ksum1 = 0; + uint32_t ksum2 = 0; + uint32_t ksum3 = 0; + uint32_t ksum4 = 0; + uint32_t ksum5 = 0; + uint32_t ksum6 = 0; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + const int8_t v0x0 = w0[0]; + const int8_t v0x1 = w0[1]; + const int8_t v0x2 = w0[2]; + const int8_t v0x3 = w0[3]; + const int8_t v0x4 = w0[4]; + const int8_t v0x5 = w0[5]; + const int8_t v0x6 = w0[6]; + const int8_t v0x7 = w0[7]; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + out[0] = v0x0; + out[1] = v0x1; + out[2] = v0x2; + out[3] = v0x3; + out[4] = v0x4; + out[5] = v0x5; + out[6] = v0x6; + out[7] = v0x7; + w0 += 8; + const int8_t v1x0 = w1[0]; + const int8_t v1x1 = w1[1]; + const int8_t v1x2 = w1[2]; + const int8_t v1x3 = w1[3]; + const int8_t v1x4 = w1[4]; + const int8_t v1x5 = w1[5]; + const int8_t v1x6 = w1[6]; + const int8_t v1x7 = w1[7]; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + out[8] = v1x0; + out[9] = v1x1; + out[10] = v1x2; + out[11] = v1x3; + out[12] = v1x4; + out[13] = v1x5; + out[14] = v1x6; + out[15] = v1x7; + w1 += 8; + const int8_t v2x0 = w2[0]; + const int8_t v2x1 = w2[1]; + const int8_t v2x2 = w2[2]; + const int8_t v2x3 = w2[3]; + const int8_t v2x4 = w2[4]; + const int8_t v2x5 = w2[5]; + const int8_t v2x6 = w2[6]; + const int8_t v2x7 = w2[7]; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + out[16] = v2x0; + out[17] = v2x1; + out[18] = v2x2; + out[19] = v2x3; + out[20] = v2x4; + out[21] = v2x5; + out[22] = v2x6; + out[23] = v2x7; + w2 += 8; + const int8_t v3x0 = w3[0]; + const int8_t v3x1 = w3[1]; + const int8_t v3x2 = w3[2]; + const int8_t v3x3 = w3[3]; + const int8_t v3x4 = w3[4]; + const int8_t v3x5 = w3[5]; + const int8_t v3x6 = w3[6]; + const int8_t v3x7 = w3[7]; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + out[24] = v3x0; + out[25] = v3x1; + out[26] = v3x2; + out[27] = v3x3; + out[28] = v3x4; + out[29] = v3x5; + out[30] = v3x6; + out[31] = v3x7; + w3 += 8; + const int8_t v4x0 = w4[0]; + const int8_t v4x1 = w4[1]; + const int8_t v4x2 = w4[2]; + const int8_t v4x3 = w4[3]; + const int8_t v4x4 = w4[4]; + const int8_t v4x5 = w4[5]; + const int8_t v4x6 = w4[6]; + const int8_t v4x7 = w4[7]; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + out[32] = v4x0; + out[33] = v4x1; + out[34] = v4x2; + out[35] = v4x3; + out[36] = v4x4; + out[37] = v4x5; + out[38] = v4x6; + out[39] = v4x7; + w4 += 8; + const int8_t v5x0 = w5[0]; + const int8_t v5x1 = w5[1]; + const int8_t v5x2 = w5[2]; + const int8_t v5x3 = w5[3]; + const int8_t v5x4 = w5[4]; + const int8_t v5x5 = w5[5]; + const int8_t v5x6 = w5[6]; + const int8_t v5x7 = w5[7]; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + out[40] = v5x0; + out[41] = v5x1; + out[42] = v5x2; + out[43] = v5x3; + out[44] = v5x4; + out[45] = v5x5; + out[46] = v5x6; + out[47] = v5x7; + w5 += 8; + const int8_t v6x0 = w6[0]; + const int8_t v6x1 = w6[1]; + const int8_t v6x2 = w6[2]; + const int8_t v6x3 = w6[3]; + const int8_t v6x4 = w6[4]; + const int8_t v6x5 = w6[5]; + const int8_t v6x6 = w6[6]; + const int8_t v6x7 = w6[7]; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + out[48] = v6x0; + out[49] = v6x1; + out[50] = v6x2; + out[51] = v6x3; + out[52] = v6x4; + out[53] = v6x5; + out[54] = v6x6; + out[55] = v6x7; + w6 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + const int8_t v0x0 = 0 < k ? w0[0] : izp; + const int8_t v0x1 = 1 < k ? w0[1] : izp; + const int8_t v0x2 = 2 < k ? w0[2] : izp; + const int8_t v0x3 = 3 < k ? w0[3] : izp; + const int8_t v0x4 = 4 < k ? w0[4] : izp; + const int8_t v0x5 = 5 < k ? w0[5] : izp; + const int8_t v0x6 = 6 < k ? w0[6] : izp; + const int8_t v0x7 = 7 < k ? w0[7] : izp; + ksum0 += (uint32_t) v0x0; + ksum0 += (uint32_t) v0x1; + ksum0 += (uint32_t) v0x2; + ksum0 += (uint32_t) v0x3; + ksum0 += (uint32_t) v0x4; + ksum0 += (uint32_t) v0x5; + ksum0 += (uint32_t) v0x6; + ksum0 += (uint32_t) v0x7; + if (0 < k) { + out[0] = v0x0; + } + if (1 < k) { + out[1] = v0x1; + } + if (2 < k) { + out[2] = v0x2; + } + if (3 < k) { + out[3] = v0x3; + } + if (4 < k) { + out[4] = v0x4; + } + if (5 < k) { + out[5] = v0x5; + } + if (6 < k) { + out[6] = v0x6; + } + if (7 < k) { + out[7] = v0x7; + } + w0 += 8; + const int8_t v1x0 = 0 < k ? w1[0] : izp; + const int8_t v1x1 = 1 < k ? w1[1] : izp; + const int8_t v1x2 = 2 < k ? w1[2] : izp; + const int8_t v1x3 = 3 < k ? w1[3] : izp; + const int8_t v1x4 = 4 < k ? w1[4] : izp; + const int8_t v1x5 = 5 < k ? w1[5] : izp; + const int8_t v1x6 = 6 < k ? w1[6] : izp; + const int8_t v1x7 = 7 < k ? w1[7] : izp; + ksum1 += (uint32_t) v1x0; + ksum1 += (uint32_t) v1x1; + ksum1 += (uint32_t) v1x2; + ksum1 += (uint32_t) v1x3; + ksum1 += (uint32_t) v1x4; + ksum1 += (uint32_t) v1x5; + ksum1 += (uint32_t) v1x6; + ksum1 += (uint32_t) v1x7; + if (0 < k) { + out[8] = v1x0; + } + if (1 < k) { + out[9] = v1x1; + } + if (2 < k) { + out[10] = v1x2; + } + if (3 < k) { + out[11] = v1x3; + } + if (4 < k) { + out[12] = v1x4; + } + if (5 < k) { + out[13] = v1x5; + } + if (6 < k) { + out[14] = v1x6; + } + if (7 < k) { + out[15] = v1x7; + } + w1 += 8; + const int8_t v2x0 = 0 < k ? w2[0] : izp; + const int8_t v2x1 = 1 < k ? w2[1] : izp; + const int8_t v2x2 = 2 < k ? w2[2] : izp; + const int8_t v2x3 = 3 < k ? w2[3] : izp; + const int8_t v2x4 = 4 < k ? w2[4] : izp; + const int8_t v2x5 = 5 < k ? w2[5] : izp; + const int8_t v2x6 = 6 < k ? w2[6] : izp; + const int8_t v2x7 = 7 < k ? w2[7] : izp; + ksum2 += (uint32_t) v2x0; + ksum2 += (uint32_t) v2x1; + ksum2 += (uint32_t) v2x2; + ksum2 += (uint32_t) v2x3; + ksum2 += (uint32_t) v2x4; + ksum2 += (uint32_t) v2x5; + ksum2 += (uint32_t) v2x6; + ksum2 += (uint32_t) v2x7; + if (0 < k) { + out[16] = v2x0; + } + if (1 < k) { + out[17] = v2x1; + } + if (2 < k) { + out[18] = v2x2; + } + if (3 < k) { + out[19] = v2x3; + } + if (4 < k) { + out[20] = v2x4; + } + if (5 < k) { + out[21] = v2x5; + } + if (6 < k) { + out[22] = v2x6; + } + if (7 < k) { + out[23] = v2x7; + } + w2 += 8; + const int8_t v3x0 = 0 < k ? w3[0] : izp; + const int8_t v3x1 = 1 < k ? w3[1] : izp; + const int8_t v3x2 = 2 < k ? w3[2] : izp; + const int8_t v3x3 = 3 < k ? w3[3] : izp; + const int8_t v3x4 = 4 < k ? w3[4] : izp; + const int8_t v3x5 = 5 < k ? w3[5] : izp; + const int8_t v3x6 = 6 < k ? w3[6] : izp; + const int8_t v3x7 = 7 < k ? w3[7] : izp; + ksum3 += (uint32_t) v3x0; + ksum3 += (uint32_t) v3x1; + ksum3 += (uint32_t) v3x2; + ksum3 += (uint32_t) v3x3; + ksum3 += (uint32_t) v3x4; + ksum3 += (uint32_t) v3x5; + ksum3 += (uint32_t) v3x6; + ksum3 += (uint32_t) v3x7; + if (0 < k) { + out[24] = v3x0; + } + if (1 < k) { + out[25] = v3x1; + } + if (2 < k) { + out[26] = v3x2; + } + if (3 < k) { + out[27] = v3x3; + } + if (4 < k) { + out[28] = v3x4; + } + if (5 < k) { + out[29] = v3x5; + } + if (6 < k) { + out[30] = v3x6; + } + if (7 < k) { + out[31] = v3x7; + } + w3 += 8; + const int8_t v4x0 = 0 < k ? w4[0] : izp; + const int8_t v4x1 = 1 < k ? w4[1] : izp; + const int8_t v4x2 = 2 < k ? w4[2] : izp; + const int8_t v4x3 = 3 < k ? w4[3] : izp; + const int8_t v4x4 = 4 < k ? w4[4] : izp; + const int8_t v4x5 = 5 < k ? w4[5] : izp; + const int8_t v4x6 = 6 < k ? w4[6] : izp; + const int8_t v4x7 = 7 < k ? w4[7] : izp; + ksum4 += (uint32_t) v4x0; + ksum4 += (uint32_t) v4x1; + ksum4 += (uint32_t) v4x2; + ksum4 += (uint32_t) v4x3; + ksum4 += (uint32_t) v4x4; + ksum4 += (uint32_t) v4x5; + ksum4 += (uint32_t) v4x6; + ksum4 += (uint32_t) v4x7; + if (0 < k) { + out[32] = v4x0; + } + if (1 < k) { + out[33] = v4x1; + } + if (2 < k) { + out[34] = v4x2; + } + if (3 < k) { + out[35] = v4x3; + } + if (4 < k) { + out[36] = v4x4; + } + if (5 < k) { + out[37] = v4x5; + } + if (6 < k) { + out[38] = v4x6; + } + if (7 < k) { + out[39] = v4x7; + } + w4 += 8; + const int8_t v5x0 = 0 < k ? w5[0] : izp; + const int8_t v5x1 = 1 < k ? w5[1] : izp; + const int8_t v5x2 = 2 < k ? w5[2] : izp; + const int8_t v5x3 = 3 < k ? w5[3] : izp; + const int8_t v5x4 = 4 < k ? w5[4] : izp; + const int8_t v5x5 = 5 < k ? w5[5] : izp; + const int8_t v5x6 = 6 < k ? w5[6] : izp; + const int8_t v5x7 = 7 < k ? w5[7] : izp; + ksum5 += (uint32_t) v5x0; + ksum5 += (uint32_t) v5x1; + ksum5 += (uint32_t) v5x2; + ksum5 += (uint32_t) v5x3; + ksum5 += (uint32_t) v5x4; + ksum5 += (uint32_t) v5x5; + ksum5 += (uint32_t) v5x6; + ksum5 += (uint32_t) v5x7; + if (0 < k) { + out[40] = v5x0; + } + if (1 < k) { + out[41] = v5x1; + } + if (2 < k) { + out[42] = v5x2; + } + if (3 < k) { + out[43] = v5x3; + } + if (4 < k) { + out[44] = v5x4; + } + if (5 < k) { + out[45] = v5x5; + } + if (6 < k) { + out[46] = v5x6; + } + if (7 < k) { + out[47] = v5x7; + } + w5 += 8; + const int8_t v6x0 = 0 < k ? w6[0] : izp; + const int8_t v6x1 = 1 < k ? w6[1] : izp; + const int8_t v6x2 = 2 < k ? w6[2] : izp; + const int8_t v6x3 = 3 < k ? w6[3] : izp; + const int8_t v6x4 = 4 < k ? w6[4] : izp; + const int8_t v6x5 = 5 < k ? w6[5] : izp; + const int8_t v6x6 = 6 < k ? w6[6] : izp; + const int8_t v6x7 = 7 < k ? w6[7] : izp; + ksum6 += (uint32_t) v6x0; + ksum6 += (uint32_t) v6x1; + ksum6 += (uint32_t) v6x2; + ksum6 += (uint32_t) v6x3; + ksum6 += (uint32_t) v6x4; + ksum6 += (uint32_t) v6x5; + ksum6 += (uint32_t) v6x6; + ksum6 += (uint32_t) v6x7; + if (0 < k) { + out[48] = v6x0; + } + if (1 < k) { + out[49] = v6x1; + } + if (2 < k) { + out[50] = v6x2; + } + if (3 < k) { + out[51] = v6x3; + } + if (4 < k) { + out[52] = v6x4; + } + if (5 < k) { + out[53] = v6x5; + } + if (6 < k) { + out[54] = v6x6; + } + if (7 < k) { + out[55] = v6x7; + } + w6 += 8; + out += 64; + } + + packed_b[0] -= ksum0 * izp; + packed_b[1] -= ksum1 * izp; + packed_b[2] -= ksum2 * izp; + packed_b[3] -= ksum3 * izp; + packed_b[4] -= ksum4 * izp; + packed_b[5] -= ksum5 * izp; + packed_b[6] -= ksum6 * izp; + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index d0d55e50e312..570273a66cd3 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -21,6 +21,14 @@ XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c4__scalar, 16, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x32c4__scalar, 32, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar, 64, 4, 1, 4, 1) +XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar, 8, 8, 1, 8, 1) +XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 8, 1) + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if XNN_ENABLE_AVXVNNIINT8 && XNN_ARCH_X86_64 +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnniint8, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8, 8, 8, 1, 8, 1) +#endif + #ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_QS8_UKERNEL_WITH_PARAMS diff --git a/src/subgraph.c b/src/subgraph.c index d227add16248..4bc18c1b15b8 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -13,11 +13,11 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" +#include "xnnpack/fp16.h" #include "xnnpack/hardware-config.h" #include "xnnpack/log.h" #include "xnnpack/math.h" diff --git a/src/subgraph/static-constant-pad.c b/src/subgraph/static-constant-pad.c index 5bad5e90fb4f..ab75daff4e02 100644 --- a/src/subgraph/static-constant-pad.c +++ b/src/subgraph/static-constant-pad.c @@ -9,9 +9,9 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/common.h" +#include "xnnpack/fp16.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" diff --git a/src/x8-packw/kr-avxvnniint8.c.in b/src/x8-packw/kr-avxvnniint8.c.in new file mode 100644 index 000000000000..fc6fb0a2f774 --- /dev/null +++ b/src/x8-packw/kr-avxvnniint8.c.in @@ -0,0 +1,401 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR == 8 +$assert KR == 8 +$assert TYPE in ["int8_t"] + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +$BITS = {"int8_t": 8}[TYPE] +$BTYPE = {"int8_t": "uint32_t"}[TYPE] +$WTYPE = {"int8_t": "int8_t"}[TYPE] +void xnn_qs${BITS}_packw_gemm_goi_ukernel_x${NR}c${KR}__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const ${WTYPE}* weights, + $if BITS == 8: + const int32_t* bias, + $else: + const ${WTYPE}* bias, + const void* scale, + ${WTYPE}* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + +// TODO: immintrin.h only provide _mm256_insert_epi64 for __x86_64__ +#if defined(__x86_64__) + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + $if BITS == 8: + const int8_t izp = params ? ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += ${NR}; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + $if BTYPE == TYPE: + out += ${NR}; + $else: + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w${NR-1}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = *b++; + $else: + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } else { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = 0; + $else: + *((${BTYPE}*) out) = 0; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } + $if BTYPE == TYPE: + out += (${NR} - n); + $else: + out += (${NR} - n) * sizeof(${BTYPE}); + + $if NR > 2: + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + } + + $if BITS == 8: + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(const int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(const int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +#endif // defined(__x86_64__) +} diff --git a/src/xnnpack/fp16.h b/src/xnnpack/fp16.h new file mode 100644 index 000000000000..dc3b47aa8dc2 --- /dev/null +++ b/src/xnnpack/fp16.h @@ -0,0 +1,179 @@ +#ifndef THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ +#define THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ + +#include + +// This file is an excerpt from https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h, +// including only the minimal functionality we need in XNNPACK. This works around some issues +// that we haven't been able to fix upstream (https://github.com/Maratyszcza/FP16/pull/32). See also: +// - https://github.com/microsoft/onnxruntime/pull/22294/files +// - https://github.com/google/XNNPACK/issues/6989 +// We also don't need a lot of the functionality in the upstream library. + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32 = { w }; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32 = { f }; + return fp32.as_bits; +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + const uint32_t w = fp32_to_bits(f); + const float abs_f = fp32_from_bits(w & UINT32_C(0x7FFFFFFF)); + float base = (abs_f * scale_to_inf) * scale_to_zero; + + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + + +#endif // THIRD_PARTY_XNNPACK_SRC_XNNPACK_FP16_H_ diff --git a/src/xnnpack/math.h b/src/xnnpack/math.h index dab6e0378533..893e038a3126 100644 --- a/src/xnnpack/math.h +++ b/src/xnnpack/math.h @@ -18,8 +18,8 @@ #include // For _rotl. #endif -#include #include "xnnpack/common.h" +#include "xnnpack/fp16.h" // stdlib.h from Windows 10 SDK defines min & max macros. // Undefine them before defining the corresponding functions. diff --git a/src/xnnpack/quantization.h b/src/xnnpack/quantization.h index 95f832bada59..0884ba41a2b6 100644 --- a/src/xnnpack/quantization.h +++ b/src/xnnpack/quantization.h @@ -9,7 +9,6 @@ #include #include -#include #include "xnnpack/math.h" #include "xnnpack/microparams.h" diff --git a/src/xnnpack/simd/f16-scalar.h b/src/xnnpack/simd/f16-scalar.h index c24076b282be..e95425c5b9b4 100644 --- a/src/xnnpack/simd/f16-scalar.h +++ b/src/xnnpack/simd/f16-scalar.h @@ -11,8 +11,8 @@ #include #include -#include #include "xnnpack/common.h" +#include "xnnpack/fp16.h" // SIMD vector type for f16 using SCALAR. typedef uint16_t xnn_simd_f16_t; diff --git a/test/BUILD.bazel b/test/BUILD.bazel index b36d2b0c0f8e..603357df6d64 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -26,11 +26,11 @@ load( MICROKERNEL_TEST_DEPS = [ ":next_prime", ":replicable_random_device", - "@FP16", "//:aligned_allocator", "//:all_microkernels", "//:allocator", "//:common", + "//:fp16", "//:isa_checks", "//:math", "//:memory", @@ -46,12 +46,12 @@ MICROKERNEL_TEST_DEPS = [ OPERATOR_TEST_DEPS = [ ":replicable_random_device", - "@FP16", "@pthreadpool", "//:aligned_allocator", "//:allocator", "//:cache", "//:common", + "//:fp16", "//:internal", "//:math", "//:microkernel_configs", @@ -189,8 +189,8 @@ sh_test( copts = xnnpack_simd_copts_for_arch(arch), deps = [ ":replicable_random_device", - "@FP16", "//:common", + "//:fp16", "//:isa_checks", "//:microkernels_h", ], @@ -1769,7 +1769,6 @@ xnnpack_cxx_library( deps = [ ":replicable_random_device", ":subgraph_unary_tester", - "@FP16", "//:XNNPACK", "//:math", "//:node_type", @@ -1885,7 +1884,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1915,7 +1913,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1932,7 +1929,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -1953,7 +1949,6 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1974,7 +1969,6 @@ xnnpack_unit_test( shard_count = 5, deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operator_utils", @@ -1991,7 +1985,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2007,7 +2000,6 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2026,7 +2018,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2046,7 +2037,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2068,7 +2058,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2086,7 +2075,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2104,7 +2092,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2121,7 +2108,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2138,7 +2124,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operator_utils", @@ -2155,7 +2140,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2185,7 +2169,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2201,7 +2184,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -2230,7 +2212,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:node_type", "//:operators", @@ -2348,7 +2329,6 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", - "@FP16", "//:XNNPACK", "//:allocation_type", "//:allocator", diff --git a/test/abs.cc b/test/abs.cc index ad1e84012931..7aaabe561396 100644 --- a/test/abs.cc +++ b/test/abs.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/average-pooling-2d.cc b/test/average-pooling-2d.cc index 4602f2aded50..8e9d2b58c067 100644 --- a/test/average-pooling-2d.cc +++ b/test/average-pooling-2d.cc @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/avgpool-microkernel-tester.h b/test/avgpool-microkernel-tester.h index 35db8fe1bb40..68e4bd4fb395 100644 --- a/test/avgpool-microkernel-tester.h +++ b/test/avgpool-microkernel-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/bankers-rounding.cc b/test/bankers-rounding.cc index 0021bdad455c..fe4ab88f1539 100644 --- a/test/bankers-rounding.cc +++ b/test/bankers-rounding.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/batch-matrix-multiply.cc b/test/batch-matrix-multiply.cc index 55fbe9eeb3e8..53bb93df1831 100644 --- a/test/batch-matrix-multiply.cc +++ b/test/batch-matrix-multiply.cc @@ -20,7 +20,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/ceiling.cc b/test/ceiling.cc index 01168b81e60f..16c1f9c4ce6c 100644 --- a/test/ceiling.cc +++ b/test/ceiling.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/clamp.cc b/test/clamp.cc index f9c87f57769f..e5db94311f05 100644 --- a/test/clamp.cc +++ b/test/clamp.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate2.cc b/test/concatenate2.cc index 06c6c3518606..245f114fa3ac 100644 --- a/test/concatenate2.cc +++ b/test/concatenate2.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate3.cc b/test/concatenate3.cc index 9a7888effff2..cb4f56a4b335 100644 --- a/test/concatenate3.cc +++ b/test/concatenate3.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate4.cc b/test/concatenate4.cc index 0d6ac91a8384..992ccf9668b9 100644 --- a/test/concatenate4.cc +++ b/test/concatenate4.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/concatenate5.cc b/test/concatenate5.cc index 8dfea193ce74..2934812315ce 100644 --- a/test/concatenate5.cc +++ b/test/concatenate5.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/conv-hwc2chw-microkernel-tester.h b/test/conv-hwc2chw-microkernel-tester.h index d354992bab4a..eaca4b861162 100644 --- a/test/conv-hwc2chw-microkernel-tester.h +++ b/test/conv-hwc2chw-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/convert-operator-tester.h b/test/convert-operator-tester.h index 78724161ed63..9a9641a32d96 100644 --- a/test/convert-operator-tester.h +++ b/test/convert-operator-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" diff --git a/test/convert.cc b/test/convert.cc index 8b7933b1b990..5c418403e185 100644 --- a/test/convert.cc +++ b/test/convert.cc @@ -12,7 +12,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/convolution-2d.cc b/test/convolution-2d.cc index a029865d3098..8e85f6cb9163 100644 --- a/test/convolution-2d.cc +++ b/test/convolution-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h index ad155b26759e..0a59b20b5b86 100644 --- a/test/convolution-operator-tester.h +++ b/test/convolution-operator-tester.h @@ -21,7 +21,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/cache.h" diff --git a/test/copy.cc b/test/copy.cc index 8683530503bd..a19cb3d34baf 100644 --- a/test/copy.cc +++ b/test/copy.cc @@ -11,7 +11,6 @@ #include // For std::unique_ptr. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/deconvolution-2d.cc b/test/deconvolution-2d.cc index 1d7d53bc314f..6e4f6cda19ce 100644 --- a/test/deconvolution-2d.cc +++ b/test/deconvolution-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" diff --git a/test/deconvolution-operator-tester.h b/test/deconvolution-operator-tester.h index 21d108ed9a48..c74e75b66500 100644 --- a/test/deconvolution-operator-tester.h +++ b/test/deconvolution-operator-tester.h @@ -23,7 +23,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" diff --git a/test/depth-to-space-2d.cc b/test/depth-to-space-2d.cc index 924f2e189990..f20ee1de3a0d 100644 --- a/test/depth-to-space-2d.cc +++ b/test/depth-to-space-2d.cc @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/depthwise-convolution-2d.cc b/test/depthwise-convolution-2d.cc index 2f008078792b..923c0e677624 100644 --- a/test/depthwise-convolution-2d.cc +++ b/test/depthwise-convolution-2d.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/dwconv-microkernel-tester.cc b/test/dwconv-microkernel-tester.cc index d29e22e7572b..942e5d284d26 100644 --- a/test/dwconv-microkernel-tester.cc +++ b/test/dwconv-microkernel-tester.cc @@ -21,7 +21,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/dwconv2d-microkernel-tester.h b/test/dwconv2d-microkernel-tester.h index 2af4605c7bcd..15a8438306c8 100644 --- a/test/dwconv2d-microkernel-tester.h +++ b/test/dwconv2d-microkernel-tester.h @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/elu.cc b/test/elu.cc index 1ad5b23fbd68..ccb49d61d68f 100644 --- a/test/elu.cc +++ b/test/elu.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split2.cc b/test/even-split2.cc index 221b5f00f09b..ab00d5f777d4 100644 --- a/test/even-split2.cc +++ b/test/even-split2.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split3.cc b/test/even-split3.cc index 99979aa0965a..cd697a325ab2 100644 --- a/test/even-split3.cc +++ b/test/even-split3.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/even-split4.cc b/test/even-split4.cc index 13f872aaa0b3..22d8ab52f825 100644 --- a/test/even-split4.cc +++ b/test/even-split4.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/f16-simd-scalar.cc b/test/f16-simd-scalar.cc index d5d100ede7a3..3fa9cda6100e 100644 --- a/test/f16-simd-scalar.cc +++ b/test/f16-simd-scalar.cc @@ -17,9 +17,9 @@ #include #include -#include #include "xnnpack/isa-checks.h" #include "xnnpack/simd/f16-scalar.h" +#include "xnnpack/fp16.h" #include "replicable_random_device.h" namespace xnnpack { diff --git a/test/f16-simd.cc.in b/test/f16-simd.cc.in index c6de215ce7f8..84ebd2d78fbd 100644 --- a/test/f16-simd.cc.in +++ b/test/f16-simd.cc.in @@ -19,9 +19,9 @@ $if ARCH_MACRO: #include #include -#include #include "xnnpack/isa-checks.h" #include "xnnpack/simd/f16-${ARCH}.h" +#include "xnnpack/fp16.h" #include "replicable_random_device.h" namespace xnnpack { diff --git a/test/floor.cc b/test/floor.cc index 0f62c175ffdd..790b668d726d 100644 --- a/test/floor.cc +++ b/test/floor.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h index a84bc876da78..c7e5695ca421 100644 --- a/test/fully-connected-operator-tester.h +++ b/test/fully-connected-operator-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" diff --git a/test/fully-connected.cc b/test/fully-connected.cc index 54624914ce08..a8a11097c10f 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/gavgpool-cw-microkernel-tester.h b/test/gavgpool-cw-microkernel-tester.h index 077c7c0775b0..fae2cb8e9d24 100644 --- a/test/gavgpool-cw-microkernel-tester.h +++ b/test/gavgpool-cw-microkernel-tester.h @@ -15,8 +15,8 @@ #include #include -#include #include "xnnpack.h" +#include "xnnpack/fp16.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" #include "replicable_random_device.h" diff --git a/test/gavgpool-microkernel-tester.h b/test/gavgpool-microkernel-tester.h index 78a6db9604c4..723bf879e948 100644 --- a/test/gavgpool-microkernel-tester.h +++ b/test/gavgpool-microkernel-tester.h @@ -19,7 +19,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 2e5bf67566e0..8d3d90b157da 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -14,8 +14,6 @@ #include #include -#include -#include #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/aligned-allocator.h" diff --git a/test/global-average-pooling-1d.cc b/test/global-average-pooling-1d.cc index 7cd48566976a..c7a472835e47 100644 --- a/test/global-average-pooling-1d.cc +++ b/test/global-average-pooling-1d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-average-pooling-2d.cc b/test/global-average-pooling-2d.cc index 9a51ac9aa160..7f83beb8fb32 100644 --- a/test/global-average-pooling-2d.cc +++ b/test/global-average-pooling-2d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-sum-pooling-1d.cc b/test/global-sum-pooling-1d.cc index fcf429581e39..45785d8fa370 100644 --- a/test/global-sum-pooling-1d.cc +++ b/test/global-sum-pooling-1d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/global-sum-pooling-2d.cc b/test/global-sum-pooling-2d.cc index cbe478483809..8d618494db59 100644 --- a/test/global-sum-pooling-2d.cc +++ b/test/global-sum-pooling-2d.cc @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/hardswish.cc b/test/hardswish.cc index 515e9e62faa6..728947595f45 100644 --- a/test/hardswish.cc +++ b/test/hardswish.cc @@ -13,7 +13,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/ibilinear-microkernel-tester.h b/test/ibilinear-microkernel-tester.h index f9fdecdbf8fd..c13e2a7f409f 100644 --- a/test/ibilinear-microkernel-tester.h +++ b/test/ibilinear-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/math.h" diff --git a/test/leaky-relu.cc b/test/leaky-relu.cc index dfba1ccce39b..0b9d2180442f 100644 --- a/test/leaky-relu.cc +++ b/test/leaky-relu.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/max-pooling-2d.cc b/test/max-pooling-2d.cc index 241ec93ab46f..2c89bb58efcb 100644 --- a/test/max-pooling-2d.cc +++ b/test/max-pooling-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" diff --git a/test/maxpool-microkernel-tester.h b/test/maxpool-microkernel-tester.h index faedb72b52b3..7587b033cd5f 100644 --- a/test/maxpool-microkernel-tester.h +++ b/test/maxpool-microkernel-tester.h @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/negate.cc b/test/negate.cc index 9ae0008b3f9c..a36ef8158441 100644 --- a/test/negate.cc +++ b/test/negate.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/packing.cc b/test/packing.cc index eeebc0946d64..4c6bd1910bdd 100644 --- a/test/packing.cc +++ b/test/packing.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack/math.h" #include "xnnpack/microkernel-utils.h" #include "xnnpack/microparams-init.h" diff --git a/test/prelu-microkernel-tester.h b/test/prelu-microkernel-tester.h index 10ec6e4e1f5a..b50bab36775e 100644 --- a/test/prelu-microkernel-tester.h +++ b/test/prelu-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/prelu.cc b/test/prelu.cc index 5f4a65daf611..698eee99a344 100644 --- a/test/prelu.cc +++ b/test/prelu.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/raddstoreexpminusmax-microkernel-tester.h b/test/raddstoreexpminusmax-microkernel-tester.h index 793e161a8ed4..4ea512f71a38 100644 --- a/test/raddstoreexpminusmax-microkernel-tester.h +++ b/test/raddstoreexpminusmax-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/rdsum-microkernel-tester.h b/test/rdsum-microkernel-tester.h index dc08fbf87008..0dc578047779 100644 --- a/test/rdsum-microkernel-tester.h +++ b/test/rdsum-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/reciprocal-square-root.cc b/test/reciprocal-square-root.cc index 8f7c6e1baae2..b5734767615a 100644 --- a/test/reciprocal-square-root.cc +++ b/test/reciprocal-square-root.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/rsum-microkernel-tester.h b/test/rsum-microkernel-tester.h index dee79bb9f0c3..9f2c94b94194 100644 --- a/test/rsum-microkernel-tester.h +++ b/test/rsum-microkernel-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/scaled-dot-product-attention.cc b/test/scaled-dot-product-attention.cc index 0d6fd61e3810..c87826087836 100644 --- a/test/scaled-dot-product-attention.cc +++ b/test/scaled-dot-product-attention.cc @@ -17,7 +17,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/sigmoid.cc b/test/sigmoid.cc index 82ffa4db6237..df1935063ef0 100644 --- a/test/sigmoid.cc +++ b/test/sigmoid.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/softmax.cc b/test/softmax.cc index b17134654a2e..458638b77a90 100644 --- a/test/softmax.cc +++ b/test/softmax.cc @@ -12,7 +12,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/space-to-depth-2d.cc b/test/space-to-depth-2d.cc index 9fc996e37774..66869423ad6e 100644 --- a/test/space-to-depth-2d.cc +++ b/test/space-to-depth-2d.cc @@ -13,7 +13,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/spmm-microkernel-tester.h b/test/spmm-microkernel-tester.h index 4ffa7cf6f89f..bd8a115cd2f4 100644 --- a/test/spmm-microkernel-tester.h +++ b/test/spmm-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" diff --git a/test/square-root.cc b/test/square-root.cc index 1ed486e0db3e..23ef94f803a8 100644 --- a/test/square-root.cc +++ b/test/square-root.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/square.cc b/test/square.cc index f183bae56039..4fbc26b979d7 100644 --- a/test/square.cc +++ b/test/square.cc @@ -11,7 +11,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-constant-pad.cc b/test/static-constant-pad.cc index c9ca1080777f..77f0d813d7f2 100644 --- a/test/static-constant-pad.cc +++ b/test/static-constant-pad.cc @@ -13,7 +13,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" diff --git a/test/static-expand-dims.cc b/test/static-expand-dims.cc index d608b5dddcd2..5fb684406785 100644 --- a/test/static-expand-dims.cc +++ b/test/static-expand-dims.cc @@ -15,7 +15,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-mean.cc b/test/static-mean.cc index 92b7a9d309d7..a55f63c4acf1 100644 --- a/test/static-mean.cc +++ b/test/static-mean.cc @@ -16,7 +16,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/common.h" diff --git a/test/static-reshape.cc b/test/static-reshape.cc index 0745cbd2273e..6107acbbc04c 100644 --- a/test/static-reshape.cc +++ b/test/static-reshape.cc @@ -15,7 +15,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-resize-bilinear-2d.cc b/test/static-resize-bilinear-2d.cc index 250c048995b2..6ea0dabddd2a 100644 --- a/test/static-resize-bilinear-2d.cc +++ b/test/static-resize-bilinear-2d.cc @@ -14,7 +14,6 @@ #include // For std::vector. #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-slice.cc b/test/static-slice.cc index 81c25ea0d0a1..e56e239a2a4e 100644 --- a/test/static-slice.cc +++ b/test/static-slice.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/static-transpose.cc b/test/static-transpose.cc index 52dc729c22e3..d5a92cd7cbed 100644 --- a/test/static-transpose.cc +++ b/test/static-transpose.cc @@ -14,7 +14,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/subgraph-fp16.cc b/test/subgraph-fp16.cc index 7653908077bf..ae25e6e6d863 100644 --- a/test/subgraph-fp16.cc +++ b/test/subgraph-fp16.cc @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/node-type.h" diff --git a/test/tanh-operator-tester.h b/test/tanh-operator-tester.h index f43e5d7e9971..755a6def164f 100644 --- a/test/tanh-operator-tester.h +++ b/test/tanh-operator-tester.h @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "replicable_random_device.h" diff --git a/test/tanh.cc b/test/tanh.cc index 2d01af0c5348..bd0cf712cb39 100644 --- a/test/tanh.cc +++ b/test/tanh.cc @@ -10,7 +10,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" diff --git a/test/unary-operator-tester.cc b/test/unary-operator-tester.cc index 74912831187f..f6d6545eb452 100644 --- a/test/unary-operator-tester.cc +++ b/test/unary-operator-tester.cc @@ -19,7 +19,6 @@ #include #include -#include #include "xnnpack.h" #include "replicable_random_device.h" diff --git a/test/vbinary-microkernel-tester.cc b/test/vbinary-microkernel-tester.cc index 00515ad57132..f4a48916660c 100644 --- a/test/vbinary-microkernel-tester.cc +++ b/test/vbinary-microkernel-tester.cc @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams-init.h" diff --git a/test/vcmul-microkernel-tester.h b/test/vcmul-microkernel-tester.h index 4c7461580b95..4d875874c38e 100644 --- a/test/vcmul-microkernel-tester.h +++ b/test/vcmul-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/isa-checks.h" #include "xnnpack/microfnptr.h" diff --git a/test/vcvt-microkernel-tester.cc b/test/vcvt-microkernel-tester.cc index e17d555edc5c..44129ffac4c0 100644 --- a/test/vcvt-microkernel-tester.cc +++ b/test/vcvt-microkernel-tester.cc @@ -20,7 +20,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" diff --git a/test/vmulcaddc-microkernel-tester.h b/test/vmulcaddc-microkernel-tester.h index c5f15d12aaf8..13afb0cb58e3 100644 --- a/test/vmulcaddc-microkernel-tester.h +++ b/test/vmulcaddc-microkernel-tester.h @@ -15,7 +15,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/aligned-allocator.h" #include "xnnpack/microfnptr.h" diff --git a/test/vunary-microkernel-tester.cc b/test/vunary-microkernel-tester.cc index ffa6e29a5fed..07a7ad6e91bc 100644 --- a/test/vunary-microkernel-tester.cc +++ b/test/vunary-microkernel-tester.cc @@ -17,7 +17,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/common.h" #include "xnnpack/microfnptr.h" diff --git a/test/vunary-microkernel-tester.h b/test/vunary-microkernel-tester.h index 8293d2769c96..a65b66465709 100644 --- a/test/vunary-microkernel-tester.h +++ b/test/vunary-microkernel-tester.h @@ -18,7 +18,6 @@ #include #include -#include #include "xnnpack.h" #include "xnnpack/microfnptr.h" #include "replicable_random_device.h"